Skip to content

Commit

Permalink
Restore retire workers API (#8939)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Nov 27, 2024
1 parent 0660dab commit 03a45a8
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 17 deletions.
26 changes: 13 additions & 13 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7445,7 +7445,7 @@ async def retire_workers(
close_workers: bool = False,
remove: bool = True,
stimulus_id: str | None = None,
) -> list[str]: ...
) -> dict[str, Any]: ...

@overload
async def retire_workers(
Expand All @@ -7455,7 +7455,7 @@ async def retire_workers(
close_workers: bool = False,
remove: bool = True,
stimulus_id: str | None = None,
) -> list[str]: ...
) -> dict[str, Any]: ...

@overload
async def retire_workers(
Expand All @@ -7471,7 +7471,7 @@ async def retire_workers(
minimum: int | None = None,
target: int | None = None,
attribute: str = "address",
) -> list[str]: ...
) -> dict[str, Any]: ...

@log_errors
async def retire_workers(
Expand All @@ -7483,7 +7483,7 @@ async def retire_workers(
remove: bool = True,
stimulus_id: str | None = None,
**kwargs: Any,
) -> list[str]:
) -> dict[str, Any]:
"""Gracefully retire workers from cluster. Any key that is in memory exclusively
on the retired workers is replicated somewhere else.
Expand Down Expand Up @@ -7565,7 +7565,7 @@ async def retire_workers(
self.workers[address] for address in self.workers_to_close(**kwargs)
}
if not wss:
return []
return {}

stop_amm = False
amm: ActiveMemoryManagerExtension | None = self.extensions.get("amm")
Expand Down Expand Up @@ -7613,13 +7613,13 @@ async def retire_workers(
# time (depending on interval settings)
amm.run_once()

workers_info_ok = []
workers_info_abort = []
for addr, result in await asyncio.gather(*coros):
workers_info_ok = {}
workers_info_abort = {}
for addr, result, info in await asyncio.gather(*coros):
if result == "OK":
workers_info_ok.append(addr)
workers_info_ok[addr] = info
else:
workers_info_abort.append(addr)
workers_info_abort[addr] = info

finally:
if stop_amm:
Expand Down Expand Up @@ -7653,7 +7653,7 @@ async def _track_retire_worker(
close: bool,
remove: bool,
stimulus_id: str,
) -> tuple[str, Literal["OK", "no-recipients"]]:
) -> tuple[str, Literal["OK", "no-recipients"], dict]:
while not policy.done():
# Sleep 0.01s when there are 4 tasks or less
# Sleep 0.5s when there are 200 or more
Expand All @@ -7675,7 +7675,7 @@ async def _track_retire_worker(
f"Could not retire worker {ws.address!r}: unique data could not be "
f"moved to any other worker ({stimulus_id=!r})"
)
return ws.address, "no-recipients"
return ws.address, "no-recipients", ws.identity()

logger.debug(
f"All unique keys on worker {ws.address!r} have been replicated elsewhere"
Expand All @@ -7689,7 +7689,7 @@ async def _track_retire_worker(
self.close_worker(ws.address)

logger.info(f"Retired worker {ws.address!r} ({stimulus_id=!r})")
return ws.address, "OK"
return ws.address, "OK", ws.identity()

def add_keys(
self, worker: str, keys: Collection[Key] = (), stimulus_id: str | None = None
Expand Down
2 changes: 1 addition & 1 deletion distributed/tests/test_active_memory_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -1250,7 +1250,7 @@ async def test_RetireWorker_with_actor_proxy(c, s, a, b):
assert "y" in b.data

out = await c.retire_workers([b.address])
assert out == (b.address,)
assert b.address in out
assert "x" in a.state.actors
assert "y" in a.data

Expand Down
19 changes: 16 additions & 3 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -4389,7 +4389,11 @@ async def test_scatter_type(c, s, a, b):
async def test_retire_workers_2(c, s, a, b):
[x] = await c.scatter([1], workers=a.address)

await s.retire_workers(workers=[a.address])
info = await s.retire_workers(workers=[a.address])
assert info
assert info[a.address]
assert "name" in info[a.address]
assert a.address not in s.workers
assert b.data == {x.key: 1}

assert {ws.address for ws in s.tasks[x.key].who_has} == {b.address}
Expand All @@ -4402,7 +4406,8 @@ async def test_retire_workers_2(c, s, a, b):
async def test_retire_many_workers(c, s, *workers):
futures = await c.scatter(list(range(100)))

await s.retire_workers(workers=[w.address for w in workers[:7]])
info = await s.retire_workers(workers=[w.address for w in workers[:7]])
assert len(info) == 7

results = await c.gather(futures)
assert results == list(range(100))
Expand Down Expand Up @@ -4764,7 +4769,15 @@ def test_recreate_task_sync(c):
@gen_cluster(client=True)
async def test_retire_workers(c, s, a, b):
assert set(s.workers) == {a.address, b.address}
await c.retire_workers(workers=[a.address], close_workers=True)
info = await c.retire_workers(workers=[a.address], close_workers=True)

# Deployment tooling sometimes relies on this information being returned.
# This currently mirrors WorkerState.identity() but may be slimmed down in
# the future.
assert info
assert info[a.address]
assert "name" in info[a.address]

assert set(s.workers) == {b.address}

while a.status != Status.closed:
Expand Down

0 comments on commit 03a45a8

Please sign in to comment.