Commit

Manage restarting state for Actors during _task_done_callback
Signed-off-by: Srinath Krishnamachari <[email protected]>
srinathk10 committed Oct 10, 2024
1 parent b8d705a commit 766db7d
Showing 2 changed files with 33 additions and 19 deletions.
@@ -211,8 +211,17 @@ def _dispatch_tasks(self):
).remote(DataContext.get_current(), ctx, *input_blocks)

def _task_done_callback(actor_to_return):
# Return the actor that was running the task to the pool.
self._actor_pool.return_actor(actor_to_return)
if actor_to_return in self._actor_pool._num_tasks_in_flight:
# Return the actor that was running the task to the pool.
self._actor_pool.return_actor(actor_to_return)
else:
assert (
actor_to_return.get_location
in self._actor_pool._restarting_actors
)
# Move the actor from restarting to running state.
self._actor_pool.restarting_to_running(actor_to_return)

# Dispatch more tasks.
self._dispatch_tasks()
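The new branch handles a task that finishes on an actor which was marked restarting while the task was still running: such an actor is no longer in the in-flight map, so instead of being returned to the pool it is moved back to the running state. A minimal sketch of that branching, using a simplified stand-in pool (names are illustrative, not the real actor pool API):

```python
# Minimal sketch of the callback's branching, assuming a simplified pool that
# tracks running actors in `num_tasks_in_flight` and restarting actors keyed
# by their ready ref in `restarting_actors` (illustrative names only).
class SketchPool:
    def __init__(self):
        self.num_tasks_in_flight = {}  # running actor -> outstanding tasks
        self.restarting_actors = {}    # ready_ref -> actor marked restarting

    def return_actor(self, actor):
        # Task finished on a running actor: release one in-flight slot.
        self.num_tasks_in_flight[actor] -= 1

    def restarting_to_running(self, ready_ref):
        # Task finished on an actor previously marked restarting: it is
        # evidently back, so move it into the running bucket again.
        actor = self.restarting_actors.pop(ready_ref)
        self.num_tasks_in_flight.setdefault(actor, 0)


def on_task_done(pool, actor, ready_ref):
    if actor in pool.num_tasks_in_flight:
        pool.return_actor(actor)
    else:
        # Not in the running bucket, so it must have been marked restarting
        # while the task was executing.
        assert ready_ref in pool.restarting_actors
        pool.restarting_to_running(ready_ref)
```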

@@ -294,9 +303,11 @@ def current_processor_usage(self) -> ExecutionResources:

def pending_processor_usage(self) -> ExecutionResources:
num_pending_workers = self._actor_pool.num_pending_actors()
num_restarting_workers = self._actor_pool.num_restarting_actors()
num_non_running_workers = num_pending_workers + num_restarting_workers
return ExecutionResources(
cpu=self._ray_remote_args.get("num_cpus", 0) * num_pending_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_pending_workers,
cpu=self._ray_remote_args.get("num_cpus", 0) * num_non_running_workers,
gpu=self._ray_remote_args.get("num_gpus", 0) * num_non_running_workers,
)

def num_active_actors(self) -> int:
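The change above counts restarting workers together with pending workers when reporting pending processor usage. As a small worked example with illustrative numbers (not output from the real operator): with `num_cpus=1` per worker, two pending workers plus one restarting worker now report 3 pending CPUs instead of 2.

```python
# Illustrative arithmetic only, mirroring pending_processor_usage() above.
ray_remote_args = {"num_cpus": 1, "num_gpus": 0}
num_pending_workers = 2
num_restarting_workers = 1
num_non_running_workers = num_pending_workers + num_restarting_workers

pending_cpu = ray_remote_args.get("num_cpus", 0) * num_non_running_workers
pending_gpu = ray_remote_args.get("num_gpus", 0) * num_non_running_workers
assert (pending_cpu, pending_gpu) == (3, 0)  # previously (2, 0)
```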
@@ -354,18 +365,19 @@ def get_autoscaling_actor_pools(self) -> List[AutoscalingActorPool]:

def update_resource_usage(self) -> None:
"""Updates resources usage."""
# Walk all active actors and for each actor that's not ALIVE,
# it's a candidate to be marked as a pending actor.
actors = list(self._actor_pool._num_tasks_in_flight.keys())
for actor in actors:
actor_state = actor._get_local_state()
if (actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE):
if actor_state != gcs_pb2.ActorTableData.ActorState.ALIVE:
# If an actor is not ALIVE, it's a candidate to be marked as a
# restarting actor.
self._actor_pool.running_to_restarting(actor, actor.get_location)
else:
# If an actor is ALIVE, it's a candidate to be marked as a
# running actor, if not already the case.
self._actor_pool.restarting_to_running(actor.get_location)
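`update_resource_usage` now walks every running actor, reads its locally cached GCS state, and reclassifies it: anything not `ALIVE` moves to the restarting bucket, and an `ALIVE` actor is moved back to running if needed. A hedged helper sketch of the liveness check, assuming `_get_local_state()` returns a `gcs_pb2.ActorTableData.ActorState` value as the code above uses it:

```python
from ray.core.generated import gcs_pb2

ActorState = gcs_pb2.ActorTableData.ActorState


def is_running(actor) -> bool:
    """Illustrative helper: an actor only counts as running while ALIVE.

    Any other reported state (for example RESTARTING or DEAD) sends the
    actor to the pool's restarting bucket until it reports ALIVE again.
    """
    # _get_local_state() is the private accessor used in the diff above.
    return actor._get_local_state() == ActorState.ALIVE
```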



class _MapWorker:
"""An actor worker for MapOperator."""

@@ -522,12 +534,15 @@ def pending_to_running(self, ready_ref: ray.ObjectRef) -> bool:
self._actor_locations[actor] = ray.get(ready_ref)
return True

def running_to_restarting(self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef) -> bool:
"""Mark the actor corresponding to the provided ready future as restaring.
def running_to_restarting(
self, actor: ray.actor.ActorHandle, ready_ref: ray.ObjectRef
) -> bool:
"""Mark the actor corresponding to the provided ready future as restarting.
Args:
actor: The running actor to add as restarting to the pool.
ready_ref: The ready future for the actor that we wish to mark as restarting.
ready_ref: The ready future for the actor that we wish to mark as
restarting.
Returns:
Whether the actor was still running. This can return False if the actor had
@@ -548,8 +563,8 @@ def restarting_to_running(self, ready_ref: ray.ObjectRef) -> bool:
ready_ref: The ready future for the actor that we wish to mark as running.
Returns:
Whether the actor was still restarting. This can return False if the actor had
already been killed.
Whether the actor was still restarting. This can return False if the actor
had already been killed.
"""
if ready_ref not in self._restarting_actors:
# The actor has been removed from the pool before becoming running.
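Together with the existing `pending_to_running`, these transition methods make the pool a small state machine over three buckets. A minimal sketch of those transitions, with plain dictionaries standing in for the real bookkeeping (names and structures are illustrative):

```python
# Illustrative three-bucket lifecycle:
#   pending --ready--> running --not ALIVE--> restarting --ALIVE--> running
pending = {}     # ready_ref -> actor, created but not yet ready
running = {}     # actor -> tasks in flight
restarting = {}  # ready_ref -> actor, temporarily lost to a node failure


def pending_to_running(ready_ref) -> bool:
    if ready_ref not in pending:
        return False  # removed from the pool before becoming ready
    running[pending.pop(ready_ref)] = 0
    return True


def running_to_restarting(actor, ready_ref) -> bool:
    if actor not in running:
        return False  # already killed or already marked restarting
    running.pop(actor)
    restarting[ready_ref] = actor
    return True


def restarting_to_running(ready_ref) -> bool:
    if ready_ref not in restarting:
        return False  # already killed before becoming running again
    running.setdefault(restarting.pop(ready_ref), 0)
    return True
```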
@@ -635,7 +650,8 @@ def num_free_slots(self) -> int:
)

def kill_inactive_actor(self) -> bool:
"""Kills a single pending, restarting or idle actor, if any actors are pending/restarting/idle.
"""Kills a single pending, restarting or idle actor, if any actors are
pending/restarting/idle.
Returns whether an inactive actor was actually killed.
"""
8 changes: 3 additions & 5 deletions python/ray/data/tests/test_actor_pool_fault_tolerance.py
@@ -1,11 +1,9 @@
import asyncio
import threading
import time

import pytest

import ray
from ray.data.context import DataContext
from ray.tests.conftest import * # noqa


@@ -15,7 +13,6 @@ def test_removed_nodes_and_added_back(ray_start_cluster):
cluster = ray_start_cluster
cluster.add_node(num_cpus=0)
ray.init()
#DataContext.get_current().max_errored_blocks = -1

@ray.remote(num_cpus=0)
class Signal:
@@ -50,7 +47,7 @@ async def wait_for_nodes_restarted(self):
signal_actor = Signal.remote()

# Spin up nodes
num_nodes = 3
num_nodes = 5
nodes = []
for _ in range(num_nodes):
nodes.append(cluster.add_node(num_cpus=10, num_gpus=1))
@@ -79,7 +76,7 @@ def __call__(self, batch):
return batch

res = []
num_items = 10
num_items = 100

def run_dataset():
nonlocal res
@@ -116,6 +113,7 @@ def run_dataset():
thread.join()
assert sorted(res, key=lambda x: x["id"]) == [{"id": i} for i in range(num_items)]


if __name__ == "__main__":
import sys
