diff --git a/python/cog/server/worker.py b/python/cog/server/worker.py index a91ca85d92..0af1838989 100644 --- a/python/cog/server/worker.py +++ b/python/cog/server/worker.py @@ -12,6 +12,7 @@ import types import uuid import warnings +import weakref from concurrent.futures import Future, ThreadPoolExecutor from enum import Enum, auto, unique from multiprocessing.connection import Connection @@ -100,6 +101,10 @@ class PredictionState: class Worker: + @property + def uses_concurrency(self) -> bool: + return self._max_concurrency > 1 + def __init__( self, child: "_ChildWorker", events: Connection, max_concurrency: int = 1 ) -> None: @@ -294,7 +299,7 @@ def _consume_events_inner(self) -> None: with self._predictions_lock: predict_state = self._predictions_in_flight.get(ev.tag) if predict_state and not predict_state.cancel_sent: - self._child.send_cancel() + self._child.send_cancel_signal() self._events.send(Envelope(event=Cancel(), tag=ev.tag)) predict_state.cancel_sent = True else: @@ -434,7 +439,7 @@ async def _runner() -> None: redirector, ) - def send_cancel(self) -> None: + def send_cancel_signal(self) -> None: if self.is_alive() and self.pid: os.kill(self.pid, signal.SIGUSR1) @@ -546,7 +551,9 @@ def _loop( while True: e = cast(Envelope, self._events.recv()) if isinstance(e.event, Cancel): - continue # Ignored in sync predictors. + # for sync predictors, this is handled via SIGUSR1 signals from + # the parent via send_cancel_signal + continue elif isinstance(e.event, Shutdown): break elif isinstance(e.event, PredictionInput): @@ -563,17 +570,27 @@ async def _aloop( assert isinstance(self._events, LockedConnection) self._events = AsyncConnection(self._events.connection) - task = None - async with asyncio.TaskGroup() as tg: + tasks = weakref.WeakValueDictionary[str | None, asyncio.Task[Any]]() while True: e = cast(Envelope, await self._events.recv()) - if isinstance(e.event, Cancel) and task and self._cancelable: + if isinstance(e.event, Cancel): + # NOTE: We don't check the _cancelable flag here, instead we rely + # on the presence of the value in the weakmap to determine if + # a prediction is actively being processed. + task = tasks.get(e.tag) + if not task: + print( + "Got cancel event for unrecognized prediction", + file=sys.stderr, + ) + continue + task.cancel() elif isinstance(e.event, Shutdown): break elif isinstance(e.event, PredictionInput): - task = tg.create_task( + tasks[e.tag] = tg.create_task( self._apredict(e.tag, e.event.payload, predict, redirector) ) else: diff --git a/python/tests/server/test_worker.py b/python/tests/server/test_worker.py index 6681ed36ae..da40d81287 100644 --- a/python/tests/server/test_worker.py +++ b/python/tests/server/test_worker.py @@ -155,11 +155,24 @@ SLEEP_FIXTURES = [ WorkerConfig("sleep"), WorkerConfig("sleep_async", min_python=(3, 11), is_async=True), + WorkerConfig( + "sleep_async", + min_python=(3, 11), + is_async=True, + max_concurrency=10, + ), ] SLEEP_NO_SETUP_FIXTURES = [ WorkerConfig("sleep", setup=False), WorkerConfig("sleep_async", min_python=(3, 11), setup=False, is_async=True), + WorkerConfig( + "sleep_async", + min_python=(3, 11), + setup=False, + is_async=True, + max_concurrency=10, + ), ] @@ -503,30 +516,40 @@ def test_predict_logging(worker, expected_stdout, expected_stderr): @uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) -def test_cancel_is_safe(worker): +def test_cancel_is_safe(worker: Worker): """ Calls to cancel at any time should not result in unexpected things happening or the cancelation of unexpected predictions. """ + tag = None + if worker.uses_concurrency: + tag = "p1" + for _ in range(50): - worker.cancel() + worker.cancel(tag) result = _process(worker, worker.setup) assert not result.done.error for _ in range(50): - worker.cancel() + worker.cancel(tag) result1 = _process( - worker, lambda: worker.predict({"sleep": 0.5}), swallow_exceptions=True + worker, + lambda: worker.predict({"sleep": 0.5}, tag), + swallow_exceptions=True, + tag=tag, ) for _ in range(50): - worker.cancel() + worker.cancel(tag) result2 = _process( - worker, lambda: worker.predict({"sleep": 0.1}), swallow_exceptions=True + worker, + lambda: worker.predict({"sleep": 0.1}, tag), + swallow_exceptions=True, + tag=tag, ) assert not result1.exception @@ -537,68 +560,113 @@ def test_cancel_is_safe(worker): @uses_worker_configs(SLEEP_NO_SETUP_FIXTURES) -def test_cancel_idempotency(worker): +def test_cancel_idempotency(worker: Worker): """ Multiple calls to cancel within the same prediction, while not necessary or recommended, should still only result in a single cancelled prediction, and should not affect subsequent predictions. """ - def cancel_a_bunch(_): - for _ in range(100): - worker.cancel() + tag = None + if worker.uses_concurrency: + tag = "p1" result = _process(worker, worker.setup) assert not result.done.error - fut = worker.predict({"sleep": 0.5}) + fut = worker.predict({"sleep": 0.5}, tag) # We call cancel a WHOLE BUNCH to make sure that we don't propagate any # of those cancelations to subsequent predictions, regardless of the # internal implementation of exceptions raised inside signal handlers. for _ in range(5): time.sleep(0.05) for _ in range(100): - worker.cancel() + worker.cancel(tag) result1 = fut.result() assert result1.canceled - result2 = _process(worker, lambda: worker.predict({"sleep": 0.1})) + tag = None + if worker.uses_concurrency: + tag = "p2" + result2 = _process(worker, lambda: worker.predict({"sleep": 0.1}, tag)) assert not result2.done.canceled assert result2.output == "done in 0.1 seconds" -@uses_worker_configs(SLEEP_FIXTURES) -def test_cancel_multiple_predictions(worker): +@uses_worker_configs( + [ + WorkerConfig("sleep"), + WorkerConfig("sleep_async", min_python=(3, 11), is_async=True), + WorkerConfig( + "sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5 + ), + ] +) +def test_cancel_multiple_predictions(worker: Worker): """ Multiple predictions cancelled in a row shouldn't be a problem. This test is mainly ensuring that the _allow_cancel latch in Worker is correctly reset every time a prediction starts. """ dones: list[Done] = [] - for _ in range(5): - fut = worker.predict({"sleep": 1}) + for i in range(5): + tag = None + if worker._max_concurrency > 1: + tag = f"p{i}" + fut = worker.predict({"sleep": 0.2}, tag) time.sleep(0.1) - worker.cancel() + worker.cancel(tag) dones.append(fut.result()) + assert dones == [Done(canceled=True)] * 5 - assert not worker.predict({"sleep": 0}).result().canceled + assert not worker.predict({"sleep": 0}, "p6").result().canceled + + +@uses_worker_configs( + [ + WorkerConfig( + "sleep_async", min_python=(3, 11), is_async=True, max_concurrency=5 + ), + ] +) +def test_cancel_some_predictions_async_with_concurrency(worker: Worker): + """ + Multiple predictions cancelled in a row shouldn't be a problem. This test + is mainly ensuring that the _allow_cancel latch in Worker is correctly + reset every time a prediction starts. + """ + fut1 = worker.predict({"sleep": 0.2}, "p1") + fut2 = worker.predict({"sleep": 0.2}, "p2") + fut3 = worker.predict({"sleep": 0.2}, "p3") + + time.sleep(0.1) + + worker.cancel("p2") + + assert not fut1.result().canceled + assert fut2.result().canceled + assert not fut3.result().canceled @uses_worker_configs(SLEEP_FIXTURES) -def test_graceful_shutdown(worker): +def test_graceful_shutdown(worker: Worker): """ On shutdown, the worker should finish running the current prediction, and then exit. """ + tag = None + if worker.uses_concurrency: + tag = "p1" + saw_first_event = threading.Event() # When we see the first event, we'll start the shutdown process. - worker.subscribe(lambda event: saw_first_event.set()) + worker.subscribe(lambda event: saw_first_event.set(), tag=tag) - fut = worker.predict({"sleep": 1}) + fut = worker.predict({"sleep": 1}, tag) saw_first_event.wait(timeout=1) worker.shutdown(timeout=2) @@ -644,7 +712,7 @@ def start(self): def is_alive(self): return self.alive - def send_cancel(self): + def send_cancel_signal(self): pass def terminate(self):