Skip to content

Commit

Permalink
[Frontend] Reduce frequency of client cancellation checking (vllm-pro…
Browse files Browse the repository at this point in the history
  • Loading branch information
njhill authored Oct 21, 2024
1 parent 5241aa1 commit 9d9186b
Showing 1 changed file with 38 additions and 19 deletions.
57 changes: 38 additions & 19 deletions vllm/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,10 +13,11 @@
import sys
import tempfile
import threading
import time
import uuid
import warnings
import weakref
from asyncio import FIRST_COMPLETED, ensure_future
from asyncio import FIRST_COMPLETED, AbstractEventLoop, Future, Task
from collections.abc import Mapping
from functools import lru_cache, partial, wraps
from platform import uname
Expand Down Expand Up @@ -437,6 +438,12 @@ def _async_wrapper(*args: P.args, **kwargs: P.kwargs) -> asyncio.Future:
return _async_wrapper


def _next_task(iterator: AsyncGenerator[T, None],
loop: AbstractEventLoop) -> Task:
# Can use anext() in python >= 3.10
return loop.create_task(iterator.__anext__()) # type: ignore[arg-type]


async def iterate_with_cancellation(
iterator: AsyncGenerator[T, None],
is_cancelled: Callable[[], Awaitable[bool]],
Expand All @@ -445,19 +452,27 @@ async def iterate_with_cancellation(
at least once per second to check for client cancellation.
"""

# Can use anext() in python >= 3.10
awaits = [ensure_future(iterator.__anext__())]
loop = asyncio.get_running_loop()

awaits: List[Future[T]] = [_next_task(iterator, loop)]
next_cancel_check: float = 0
while True:
done, pending = await asyncio.wait(awaits, timeout=1)
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
done, pending = await asyncio.wait(awaits, timeout=1.5)

# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
with contextlib.suppress(BaseException):
awaits[0].cancel()
await iterator.aclose()
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1

if done:
try:
item = await awaits[0]
awaits[0] = ensure_future(iterator.__anext__())
awaits[0] = _next_task(iterator, loop)
yield item
except StopAsyncIteration:
# we are done
Expand All @@ -478,25 +493,29 @@ async def merge_async_iterators(
to check for client cancellation.
"""

# Can use anext() in python >= 3.10
awaits = {
ensure_future(pair[1].__anext__()): pair
for pair in enumerate(iterators)
}
timeout = None if is_cancelled is None else 1
loop = asyncio.get_running_loop()

awaits = {_next_task(pair[1], loop): pair for pair in enumerate(iterators)}
timeout = None if is_cancelled is None else 1.5
next_cancel_check: float = 0
try:
while awaits:
done, pending = await asyncio.wait(awaits.keys(),
return_when=FIRST_COMPLETED,
timeout=timeout)
if is_cancelled is not None and await is_cancelled():
raise asyncio.CancelledError("client cancelled")
if is_cancelled is not None:
# Check for cancellation at most once per second
time_now = time.time()
if time_now >= next_cancel_check:
if await is_cancelled():
raise asyncio.CancelledError("client cancelled")
next_cancel_check = time_now + 1
for d in done:
pair = awaits.pop(d)
try:
item = await d
i, it = pair
awaits[ensure_future(it.__anext__())] = pair
awaits[_next_task(it, loop)] = pair
yield i, item
except StopAsyncIteration:
pass
Expand Down

0 comments on commit 9d9186b

Please sign in to comment.