Commit a5b2d55

Addressed review comments
Signed-off-by: Srinath Krishnamachari <[email protected]>
srinathk10 committed Oct 18, 2024
1 parent 1e26cc2 commit a5b2d55
Showing 4 changed files with 90 additions and 62 deletions.
@@ -399,10 +399,10 @@ def num_active_tasks(self) -> int:
         """
         return len(self.get_active_tasks())
 
-    def num_active_actors(self) -> int:
-        """Return the number of active actors.
+    def num_alive_actors(self) -> int:
+        """Return the number of alive actors.
 
-        This method is used to display active actor info in the progress bar.
+        This method is used to display alive actor info in the progress bar.
         """
         return 0
 
@@ -413,6 +413,13 @@ def num_pending_actors(self) -> int:
         """
         return 0
 
+    def num_restarting_actors(self) -> int:
+        """Return the number of restarting actors.
+
+        This method is used to display restarting actor info in the progress bar.
+        """
+        return 0
+
     def throttling_disabled(self) -> bool:
         """Whether to disable resource throttling for this operator.
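
For orientation, a minimal sketch (not part of this commit) of the contract these hooks establish: non-actor operators keep the defaults above, while an actor-based operator is expected to forward its pool's counts. The class and attribute names below are illustrative assumptions, not Ray code.

```python
class ToyActorOperator:
    """Duck-typed stand-in for an actor-based operator (illustrative only)."""

    def __init__(self, actor_pool):
        # Assumed to expose num_alive_actors / num_pending_actors /
        # num_restarting_actors, like the actor pool shown further below.
        self._actor_pool = actor_pool

    def num_alive_actors(self) -> int:
        return self._actor_pool.num_alive_actors()

    def num_pending_actors(self) -> int:
        return self._actor_pool.num_pending_actors()

    def num_restarting_actors(self) -> int:
        return self._actor_pool.num_restarting_actors()
```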
@@ -276,8 +276,6 @@ def progress_str(self) -> str:
             return locality_string(
                 self._actor_pool._locality_hits,
                 self._actor_pool._locality_misses,
-                self._actor_pool.num_pending_actors(),
-                self._actor_pool.num_restarting_actors(),
             )
         return "[locality off]"
 
@@ -307,12 +305,12 @@ def pending_processor_usage(self) -> ExecutionResources:
             gpu=self._ray_remote_args.get("num_gpus", 0) * num_pending_workers,
         )
 
-    def num_active_actors(self) -> int:
-        """Return the number of active actors.
+    def num_alive_actors(self) -> int:
+        """Return the number of alive actors.
 
-        This method is used to display active actor info in the progress bar.
+        This method is used to display alive actor info in the progress bar.
         """
-        return self._actor_pool.num_running_actors()
+        return self._actor_pool.num_alive_actors()
 
     def num_pending_actors(self) -> int:
         """Return the number of pending actors.
@@ -321,6 +319,13 @@ def num_pending_actors(self) -> int:
         """
         return self._actor_pool.num_pending_actors()
 
+    def num_restarting_actors(self) -> int:
+        """Return the number of restarting actors.
+
+        This method is used to display restarting actor info in the progress bar.
+        """
+        return self._actor_pool.num_restarting_actors()
+
     def incremental_resource_usage(self) -> ExecutionResources:
         # Submitting tasks to existing actors doesn't require additional
         # CPU/GPU resources.
@@ -336,8 +341,8 @@ def _extra_metrics(self) -> Dict[str, Any]:
         if self._actor_locality_enabled:
             res["locality_hits"] = self._actor_pool._locality_hits
             res["locality_misses"] = self._actor_pool._locality_misses
-            res["pending_actors"] = self._actor_pool.num_pending_actors()
-            res["restarting_actors"] = self._actor_pool.num_restarting_actors()
+        res["pending_actors"] = self._actor_pool.num_pending_actors()
+        res["restarting_actors"] = self._actor_pool.num_restarting_actors()
         return res
 
     @staticmethod
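
A sketch of the effect of this reordering (values made up; any other metrics the operator reports are omitted): pending_actors and restarting_actors are now emitted even when actor locality is disabled, while the locality counters stay conditional.

```python
# Illustrative expected shapes of the _extra_metrics() result after this change.
with_locality_enabled = {
    "locality_hits": 10,
    "locality_misses": 3,
    "pending_actors": 2,
    "restarting_actors": 1,
}
with_locality_disabled = {
    "pending_actors": 2,
    "restarting_actors": 1,
}
```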
@@ -369,7 +374,13 @@ def update_resource_usage(self) -> None:
             if actor_state is None:
                 # actor._get_local_state can return None if the state is Unknown
                 continue
-            self._actor_pool.update_running_actor_state(actor, actor_state)
+            elif actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE:
+                # The actors can be either ALIVE or RESTARTING here because they will
+                # be restarted indefinitely until execution finishes.
+                assert actor_state == gcs_pb2.ActorTableData.ActorState.RESTARTING
+                self._actor_pool.update_running_actor_state(actor, True)
+            else:
+                self._actor_pool.update_running_actor_state(actor, False)
 
 
 class _MapWorker:
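
A standalone sketch of the state mapping encoded by the branch above. The helper name is assumed; the commit inlines this logic rather than defining such a function.

```python
from typing import Optional

from ray.core.generated import gcs_pb2


def to_is_restarting(actor_state: Optional[int]) -> Optional[bool]:
    """Map a GCS actor state to the boolean tracked by the actor pool.

    Returns None when the locally cached state is unknown, True when the
    actor is RESTARTING, and False when it is ALIVE.
    """
    if actor_state is None:
        return None
    if actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE:
        # Only ALIVE and RESTARTING are expected while execution is running.
        assert actor_state == gcs_pb2.ActorTableData.ActorState.RESTARTING
        return True
    return False
```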
@@ -417,8 +428,8 @@ class _ActorState:
     # Node id of each ready actor
     actor_location: str
 
-    # Actor state
-    actor_state: gcs_pb2.ActorTableData.ActorState
+    # Is Actor state restarting or alive
+    is_restarting: bool
 
 
 class _ActorPool(AutoscalingActorPool):
@@ -474,22 +485,21 @@ def num_running_actors(self) -> int:
     def num_restarting_actors(self) -> int:
         """Restarting actors are all the running actors not in ALIVE state."""
         return sum(
-            running_actor_state.actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE
-            for running_actor_state in self._running_actors.values()
+            actor_state.is_restarting for actor_state in self._running_actors.values()
         )
 
     def num_active_actors(self) -> int:
         """Active actors are all the running actors with inflight tasks."""
         return sum(
-            1 if running_actor_state.num_tasks_in_flight > 0 else 0
-            for running_actor_state in self._running_actors.values()
+            1 if actor_state.num_tasks_in_flight > 0 else 0
+            for actor_state in self._running_actors.values()
         )
 
     def num_alive_actors(self) -> int:
         """Alive actors are all the running actors in ALIVE state."""
         return sum(
-            running_actor_state.actor_state == gcs_pb2.ActorTableData.ActorState.ALIVE
-            for running_actor_state in self._running_actors.values()
+            not actor_state.is_restarting
+            for actor_state in self._running_actors.values()
         )
 
     def num_pending_actors(self) -> int:
@@ -500,8 +510,8 @@ def max_tasks_in_flight_per_actor(self) -> int:
 
     def current_in_flight_tasks(self) -> int:
         return sum(
-            running_actor_state.num_tasks_in_flight
-            for running_actor_state in self._running_actors.values()
+            actor_state.num_tasks_in_flight
+            for actor_state in self._running_actors.values()
         )
 
     def scale_up(self, num_actors: int) -> int:
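
To make the counting semantics concrete, a small self-contained sketch; WorkerState is a made-up stand-in for _ActorState and the counts are illustrative.

```python
from dataclasses import dataclass


@dataclass
class WorkerState:
    num_tasks_in_flight: int
    is_restarting: bool


states = [
    WorkerState(num_tasks_in_flight=2, is_restarting=False),  # alive and active
    WorkerState(num_tasks_in_flight=0, is_restarting=False),  # alive and idle
    WorkerState(num_tasks_in_flight=1, is_restarting=True),   # restarting
]

assert sum(s.is_restarting for s in states) == 1       # cf. num_restarting_actors
assert sum(not s.is_restarting for s in states) == 2   # cf. num_alive_actors
assert sum(1 if s.num_tasks_in_flight > 0 else 0 for s in states) == 2  # cf. num_active_actors
assert sum(s.num_tasks_in_flight for s in states) == 3  # cf. current_in_flight_tasks
```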
@@ -520,18 +530,16 @@ def scale_down(self, num_actors: int) -> int:
     # === End of overriding methods of AutoscalingActorPool ===
 
     def update_running_actor_state(
-        self,
-        actor: ray.actor.ActorHandle,
-        actor_state: gcs_pb2.ActorTableData.ActorState,
+        self, actor: ray.actor.ActorHandle, is_restarting: bool
     ):
         """Update running actor state.
 
         Args:
             actor: The running actor that needs state update.
-            actor_state: Updated actor state for the running actor.
+            is_restarting: Whether running actor is restarting or alive.
         """
         assert actor in self._running_actors
-        self._running_actors[actor].actor_state = actor_state
+        self._running_actors[actor].is_restarting = is_restarting
 
     def add_pending_actor(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef):
         """Adds a pending actor to the pool.
@@ -566,7 +574,7 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool:
         self._running_actors[actor] = _ActorState(
             num_tasks_in_flight=0,
             actor_location=ray.get(ready_ref),
-            actor_state=gcs_pb2.ActorTableData.ActorState.ALIVE,
+            is_restarting=False,
         )
         return True
 
@@ -590,33 +598,32 @@ def pick_actor(
         else:
             preferred_loc = None
 
-        def penalty_key(actor):
-            """Returns the key that should be minimized for the best actor.
-
-            We prioritize actors with argument locality, and those that are not busy,
-            in that order.
-            """
-            busyness = self._running_actors[actor].num_tasks_in_flight
-            requires_remote_fetch = (
-                self._running_actors[actor].actor_location != preferred_loc
-            )
-            return requires_remote_fetch, busyness
-
         # Filter out actors that are invalid, i.e. actors with number of tasks in
         # flight >= _max_tasks_in_flight or actor_state is not ALIVE.
         valid_actors = [
             actor
             for actor in self._running_actors
             if self._running_actors[actor].num_tasks_in_flight
             < self._max_tasks_in_flight
-            and self._running_actors[actor].actor_state
-            == gcs_pb2.ActorTableData.ActorState.ALIVE
+            and not self._running_actors[actor].is_restarting
         ]
 
         if not valid_actors:
             # All actors are at capacity or actor state is not ALIVE.
             return None
 
+        def penalty_key(actor):
+            """Returns the key that should be minimized for the best actor.
+
+            We prioritize actors with argument locality, and those that are not busy,
+            in that order.
+            """
+            busyness = self._running_actors[actor].num_tasks_in_flight
+            requires_remote_fetch = (
+                self._running_actors[actor].actor_location != preferred_loc
+            )
+            return requires_remote_fetch, busyness
+
         # Pick the best valid actor based on the penalty key
         actor = min(valid_actors, key=penalty_key)
 
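
A small standalone sketch of the selection rule that penalty_key encodes: among valid (non-restarting, below-capacity) actors, min() with a (requires_remote_fetch, busyness) tuple prefers actors on the preferred node first and, among those, the least busy one. The data below is made up for illustration.

```python
# actor id -> (tasks in flight, node the actor runs on)
running = {
    "a1": (3, "node-1"),
    "a2": (1, "node-2"),
    "a3": (0, "node-2"),
}
preferred_loc = "node-2"


def penalty_key(name):
    tasks_in_flight, location = running[name]
    requires_remote_fetch = location != preferred_loc
    # False sorts before True, so local actors win; ties break on busyness.
    return requires_remote_fetch, tasks_in_flight


assert min(running, key=penalty_key) == "a3"  # idle actor on the preferred node
```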
24 changes: 16 additions & 8 deletions python/ray/data/_internal/execution/streaming_executor_state.py
@@ -256,6 +256,20 @@ def refresh_progress_bar(self, resource_manager: ResourceManager) -> None:
             self.progress_bar.set_description(self.summary_str(resource_manager))
             self.progress_bar.refresh()
 
+    def actor_info_progress_str(self) -> str:
+        # Alive/Pending/Restarting actors
+        alive = self.op.num_alive_actors()
+        pending = self.op.num_pending_actors()
+        restarting = self.op.num_restarting_actors()
+        total = alive + pending + restarting
+        if total == alive:
+            return f"; Actors: {total}"
+        else:
+            return (
+                f"; Actors: {total} (alive {alive}, restarting {restarting}, "
+                f"pending {pending})"
+            )
+
     def summary_str(self, resource_manager: ResourceManager) -> str:
         # Active tasks
         active = self.op.num_active_tasks()
@@ -266,14 +280,8 @@ def summary_str(self, resource_manager: ResourceManager) -> str:
         ):
             desc += " [backpressured]"
 
-        # Active/pending actors
-        active = self.op.num_active_actors()
-        pending = self.op.num_pending_actors()
-        if active or pending:
-            actor_str = f"; Actors: {active}"
-            if pending > 0:
-                actor_str += f", (pending: {pending})"
-            desc += actor_str
+        # Actors info
+        desc += self.actor_info_progress_str()
 
         # Queued blocks
         queued = self.num_queued() + self.op.internal_queue_size()
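
To show the resulting progress strings concretely, here is a standalone restatement of the formatting above (counts are made up):

```python
def actor_info_progress_str(alive: int, pending: int, restarting: int) -> str:
    total = alive + pending + restarting
    if total == alive:
        return f"; Actors: {total}"
    return (
        f"; Actors: {total} (alive {alive}, restarting {restarting}, "
        f"pending {pending})"
    )


assert actor_info_progress_str(alive=4, pending=0, restarting=0) == "; Actors: 4"
assert actor_info_progress_str(alive=4, pending=2, restarting=1) == (
    "; Actors: 7 (alive 4, restarting 1, pending 2)"
)
```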
30 changes: 18 additions & 12 deletions python/ray/data/tests/test_actor_pool_map_operator.py
@@ -9,7 +9,6 @@
 import ray
 from ray._private.test_utils import wait_for_condition
 from ray.actor import ActorHandle
-from ray.core.generated import gcs_pb2
 from ray.data._internal.compute import ActorPoolStrategy
 from ray.data._internal.execution.operators.actor_pool_map_operator import _ActorPool
 from ray.data._internal.execution.util import make_ref_bundles
@@ -135,9 +134,7 @@ def test_restarting_to_alive(self):
         actor = self._add_ready_actor(pool)
 
         # Mark the actor as restarting and test pick_actor fails
-        pool.update_running_actor_state(
-            actor, gcs_pb2.ActorTableData.ActorState.RESTARTING
-        )
+        pool.update_running_actor_state(actor, True)
         assert pool.pick_actor() is None
         assert pool.current_size() == 1
         assert pool.num_pending_actors() == 0
@@ -149,7 +146,7 @@
         assert pool.num_free_slots() == 1
 
         # Mark the actor as alive and test pick_actor succeeds
-        pool.update_running_actor_state(actor, gcs_pb2.ActorTableData.ActorState.ALIVE)
+        pool.update_running_actor_state(actor, False)
         picked_actor = pool.pick_actor()
         assert picked_actor == actor
         assert pool.current_size() == 1
@@ -163,6 +160,14 @@ def test_restarting_to_alive(self):
 
         # Return the actor
         pool.return_actor(picked_actor)
+        assert pool.current_size() == 1
+        assert pool.num_pending_actors() == 0
+        assert pool.num_running_actors() == 1
+        assert pool.num_restarting_actors() == 0
+        assert pool.num_alive_actors() == 1
+        assert pool.num_active_actors() == 0
+        assert pool.num_idle_actors() == 1
+        assert pool.num_free_slots() == 1
 
     def test_repeated_picking(self):
         # Test that we can repeatedly pick the same actor.
@@ -594,14 +599,18 @@ def __call__(self, x):
     ).take_all()
 
 
-def test_actor_pool_fault_tolerance_e2e(ray_start_cluster):
+def test_actor_pool_fault_tolerance_e2e(ray_start_cluster, restore_data_context):
     """Test that a dataset with actor pools can finish, when
     all nodes in the cluster are removed and added back."""
     ray.shutdown()
     cluster = ray_start_cluster
     cluster.add_node(num_cpus=0)
     ray.init()
 
+    # Ensure block size is small enough to pass resource limits
+    context = ray.data.DataContext.get_current()
+    context.target_max_block_size = 1
+
     @ray.remote(num_cpus=0)
     class Signal:
         def __init__(self):
@@ -635,10 +644,10 @@ async def wait_for_nodes_restarted(self):
     signal_actor = Signal.remote()
 
     # Spin up nodes
-    num_nodes = 1
+    num_nodes = 4
     nodes = []
     for _ in range(num_nodes):
-        nodes.append(cluster.add_node(num_cpus=10))
+        nodes.append(cluster.add_node(num_cpus=10, num_gpus=1))
     cluster.wait_for_nodes()
 
     class MyUDF:
@@ -656,9 +665,6 @@ def __call__(self, batch):
                 # actors are running tasks when removing nodes.
                 ray.get(self._signal_actor.wait_for_nodes_removed.remote())
 
-                # Wait for the driver to add nodes.
-                ray.get(self._signal_actor.wait_for_nodes_restarted.remote())
-
                 self._signal_sent = True
 
             return batch
@@ -693,7 +699,7 @@ def run_dataset():
 
     # Add back all the nodes
     for _ in range(num_nodes):
-        nodes.append(cluster.add_node(num_cpus=10))
+        nodes.append(cluster.add_node(num_cpus=10, num_gpus=1))
     cluster.wait_for_nodes()
     ray.get(signal_actor.notify_nodes_restarted.remote())
 
