refactor: Test separate types
Signed-off-by: Weixin Deng <[email protected]>
dengwxn committed Oct 18, 2024
1 parent a0f1381 commit 6e71c93
Showing 10 changed files with 67 additions and 38 deletions.
4 changes: 3 additions & 1 deletion python/ray/dag/collective_node.py
@@ -13,7 +13,9 @@
 from ray.experimental.channel.torch_tensor_nccl_channel import _init_nccl_group
 from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType
 from ray.util.annotations import DeveloperAPI
-from ray.util.collective.types import _CollectiveOp, ReduceOp
+# [TYPE]
+# from ray.util.collective.types import _CollectiveOp, ReduceOp
+from ray.experimental.util.types import _CollectiveOp, ReduceOp


 class _CollectiveOperation:
4 changes: 3 additions & 1 deletion python/ray/dag/tests/experimental/test_torch_tensor_dag.py
@@ -22,7 +22,9 @@
 from ray.experimental.channel.nccl_group import _NcclGroup
 from ray.experimental.channel.torch_tensor_type import TorchTensorType
 from ray.tests.conftest import *  # noqa
-from ray.util.collective.types import ReduceOp
+# [TYPE]
+# from ray.util.collective.types import ReduceOp
+from ray.experimental.util.types import ReduceOp

 logger = logging.getLogger(__name__)

4 changes: 3 additions & 1 deletion python/ray/experimental/channel/gpu_communicator.py
@@ -3,7 +3,9 @@

 import ray
 from ray.util.annotations import DeveloperAPI
-from ray.util.collective.types import ReduceOp
+# [TYPE]
+# from ray.util.collective.types import ReduceOp
+from ray.experimental.util.types import ReduceOp

 if TYPE_CHECKING:
     import torch
4 changes: 3 additions & 1 deletion python/ray/experimental/channel/nccl_group.py
@@ -8,7 +8,9 @@
     GPUCommunicator,
     TorchTensorAllocator,
 )
-from ray.util.collective.types import ReduceOp
+# [TYPE]
+# from ray.util.collective.types import ReduceOp
+from ray.experimental.util.types import ReduceOp

 if TYPE_CHECKING:
     import cupy as cp
4 changes: 3 additions & 1 deletion python/ray/experimental/collective/allreduce.py
@@ -9,7 +9,9 @@
     PARENT_CLASS_NODE_KEY,
 )
 from ray.experimental.channel.torch_tensor_type import GPUCommunicator, TorchTensorType
-from ray.util.collective.types import ReduceOp
+# [TYPE]
+# from ray.util.collective.types import ReduceOp
+from ray.experimental.util.types import ReduceOp

 logger = logging.getLogger(__name__)

Empty file.
19 changes: 19 additions & 0 deletions python/ray/experimental/util/types.py
@@ -0,0 +1,19 @@
+from enum import Enum
+
+from ray.util.annotations import PublicAPI
+
+
+class _CollectiveOp(Enum):
+    pass
+
+
+@PublicAPI
+class ReduceOp(_CollectiveOp):
+    SUM = 0
+    PRODUCT = 1
+    MAX = 2
+    MIN = 3
+    AVG = 4
+
+    def __str__(self):
+        return f"{self.name.lower()}"
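
For reference, the relocated enum behaves like a standard Python Enum once this new module lands; the following is a minimal usage sketch (illustrative only, assuming the import path introduced above is available):

    from ray.experimental.util.types import ReduceOp

    op = ReduceOp.SUM
    print(op)        # prints "sum" via the lower-casing __str__
    print(op.value)  # 0
    assert isinstance(op, ReduceOp)  # ReduceOp subclasses the private _CollectiveOp marker
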
52 changes: 26 additions & 26 deletions python/ray/util/collective/collective.py
@@ -6,7 +6,7 @@
 import numpy as np

 import ray
-from ray.util.annotations import DeveloperAPI
+from ray.util.annotations import PublicAPI
 from ray.util.collective import types

 _NCCL_AVAILABLE = True
@@ -30,17 +30,17 @@
 _GLOO_AVAILABLE = False


-@DeveloperAPI
+@PublicAPI
 def nccl_available():
     return _NCCL_AVAILABLE


-@DeveloperAPI
+@PublicAPI
 def gloo_available():
     return _GLOO_AVAILABLE


-@DeveloperAPI
+@PublicAPI
 class GroupManager(object):
     """Use this class to manage the collective groups we created so far.
@@ -116,13 +116,13 @@ def destroy_collective_group(self, group_name):
 _group_mgr = GroupManager()


-@DeveloperAPI
+@PublicAPI
 def is_group_initialized(group_name):
     """Check if the group is initialized in this process by the group name."""
     return _group_mgr.is_group_exist(group_name)


-@DeveloperAPI
+@PublicAPI
 def init_collective_group(
     world_size: int, rank: int, backend=types.Backend.NCCL, group_name: str = "default"
 ):
@@ -154,7 +154,7 @@ def init_collective_group(
     _group_mgr.create_collective_group(backend, world_size, rank, group_name)


-@DeveloperAPI
+@PublicAPI
 def create_collective_group(
     actors,
     world_size: int,
@@ -220,15 +220,15 @@ def create_collective_group(


 # TODO (we need a declarative destroy() API here.)
-@DeveloperAPI
+@PublicAPI
 def destroy_collective_group(group_name: str = "default") -> None:
     """Destroy a collective group given its group name."""
     _check_inside_actor()
     global _group_mgr
     _group_mgr.destroy_collective_group(group_name)


-@DeveloperAPI
+@PublicAPI
 def get_rank(group_name: str = "default") -> int:
     """Return the rank of this process in the given group.
@@ -247,7 +247,7 @@ def get_rank(group_name: str = "default") -> int:
     return g.rank


-@DeveloperAPI
+@PublicAPI
 def get_collective_group_size(group_name: str = "default") -> int:
     """Return the size of the collective group with the given name.
@@ -265,7 +265,7 @@ def get_collective_group_size(group_name: str = "default") -> int:
     return g.world_size


-@DeveloperAPI
+@PublicAPI
 def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
     """Collective allreduce the tensor across the group.
@@ -284,7 +284,7 @@ def allreduce(tensor, group_name: str = "default", op=types.ReduceOp.SUM):
     g.allreduce([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def allreduce_multigpu(
     tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM
 ):
@@ -307,7 +307,7 @@ def allreduce_multigpu(
     g.allreduce(tensor_list, opts)


-@DeveloperAPI
+@PublicAPI
 def barrier(group_name: str = "default"):
     """Barrier all processes in the collective group.
@@ -321,7 +321,7 @@ def barrier(group_name: str = "default"):
     g.barrier()


-@DeveloperAPI
+@PublicAPI
 def reduce(
     tensor, dst_rank: int = 0, group_name: str = "default", op=types.ReduceOp.SUM
 ):
@@ -348,7 +348,7 @@ def reduce(
     g.reduce([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def reduce_multigpu(
     tensor_list: list,
     dst_rank: int = 0,
@@ -385,7 +385,7 @@ def reduce_multigpu(
     g.reduce(tensor_list, opts)


-@DeveloperAPI
+@PublicAPI
 def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
     """Broadcast the tensor from a source process to all others.
@@ -408,7 +408,7 @@ def broadcast(tensor, src_rank: int = 0, group_name: str = "default"):
     g.broadcast([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def broadcast_multigpu(
     tensor_list, src_rank: int = 0, src_tensor: int = 0, group_name: str = "default"
 ):
@@ -437,7 +437,7 @@ def broadcast_multigpu(
     g.broadcast(tensor_list, opts)


-@DeveloperAPI
+@PublicAPI
 def allgather(tensor_list: list, tensor, group_name: str = "default"):
     """Allgather tensors from each process of the group into a list.
@@ -463,7 +463,7 @@ def allgather(tensor_list: list, tensor, group_name: str = "default"):
     g.allgather([tensor_list], [tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def allgather_multigpu(
     output_tensor_lists: list, input_tensor_list: list, group_name: str = "default"
 ):
@@ -488,7 +488,7 @@ def allgather_multigpu(
     g.allgather(output_tensor_lists, input_tensor_list, opts)


-@DeveloperAPI
+@PublicAPI
 def reducescatter(
     tensor, tensor_list: list, group_name: str = "default", op=types.ReduceOp.SUM
 ):
@@ -519,7 +519,7 @@ def reducescatter(
     g.reducescatter([tensor], [tensor_list], opts)


-@DeveloperAPI
+@PublicAPI
 def reducescatter_multigpu(
     output_tensor_list,
     input_tensor_lists,
@@ -549,7 +549,7 @@ def reducescatter_multigpu(
     g.reducescatter(output_tensor_list, input_tensor_lists, opts)


-@DeveloperAPI
+@PublicAPI
 def send(tensor, dst_rank: int, group_name: str = "default"):
     """Send a tensor to a remote process synchronously.
@@ -571,7 +571,7 @@ def send(tensor, dst_rank: int, group_name: str = "default"):
     g.send([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def send_multigpu(
     tensor,
     dst_rank: int,
@@ -614,7 +614,7 @@ def send_multigpu(
     g.send([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def recv(tensor, src_rank: int, group_name: str = "default"):
     """Receive a tensor from a remote process synchronously.
@@ -636,7 +636,7 @@ def recv(tensor, src_rank: int, group_name: str = "default"):
     g.recv([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def recv_multigpu(
     tensor,
     src_rank: int,
@@ -677,7 +677,7 @@ def recv_multigpu(
     g.recv([tensor], opts)


-@DeveloperAPI
+@PublicAPI
 def synchronize(gpu_id: int):
     """Synchronize the current process to a give device.
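
The functions re-annotated above keep their existing signatures; for orientation, here is a rough usage sketch of this now-public collective API from inside Ray actors (the worker class, tensor shape, and the "nccl" backend string are illustrative assumptions, not part of this commit):

    import ray
    import ray.util.collective as col
    import torch

    @ray.remote(num_gpus=1)
    class Worker:
        def setup(self, world_size: int, rank: int):
            # Backend choice mirrors init_collective_group's default of
            # types.Backend.NCCL shown in the diff above.
            col.init_collective_group(world_size, rank, backend="nccl", group_name="default")

        def compute(self):
            t = torch.ones(4, device="cuda")
            col.allreduce(t, group_name="default")  # op defaults to ReduceOp.SUM
            return t.cpu()

    # Driver-side sketch: two workers, one group, one allreduce.
    # workers = [Worker.remote() for _ in range(2)]
    # ray.get([w.setup.remote(2, i) for i, w in enumerate(workers)])
    # results = ray.get([w.compute.remote() for w in workers])
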
6 changes: 3 additions & 3 deletions python/ray/util/collective/const.py
@@ -7,10 +7,10 @@
 import os
 from enum import Enum, auto

-from ray.util.annotations import DeveloperAPI
+from ray.util.annotations import PublicAPI


-@DeveloperAPI
+@PublicAPI
 def get_store_name(group_name):
     """Generate the unique name for the NCCLUniqueID store (named actor).
@@ -25,7 +25,7 @@ def get_store_name(group_name):
     return hexlified_name


-@DeveloperAPI
+@PublicAPI
 class ENV(Enum):
     """ray.util.collective environment variables."""

8 changes: 4 additions & 4 deletions python/ray/util/collective/types.py
@@ -3,7 +3,7 @@
 from dataclasses import dataclass
 from datetime import timedelta

-from ray.util.annotations import PublicAPI, DeveloperAPI
+from ray.util.annotations import PublicAPI

 _NUMPY_AVAILABLE = True
 _TORCH_AVAILABLE = True
@@ -20,17 +20,17 @@
 _CUPY_AVAILABLE = False


-@DeveloperAPI
+@PublicAPI
 def cupy_available():
     return _CUPY_AVAILABLE


-@DeveloperAPI
+@PublicAPI
 def torch_available():
     return _TORCH_AVAILABLE


-@DeveloperAPI
+@PublicAPI
 class Backend(object):
     """A class to represent different backends."""

