-
-
Notifications
You must be signed in to change notification settings - Fork 130
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
api: add client timeouts for the ZeroMQ server (#897)
- Loading branch information
Showing
6 changed files
with
181 additions
and
30 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,115 @@ | ||
import asyncio | ||
import tempfile | ||
import unittest | ||
import unittest.mock | ||
import uuid | ||
|
||
import pytest | ||
import pytest_asyncio | ||
|
||
from aphrodite.endpoints.openai.rpc.client import (AsyncEngineRPCClient, | ||
RPCClientClosedError) | ||
from aphrodite.endpoints.openai.rpc.server import AsyncEngineRPCServer | ||
from aphrodite.engine.async_aphrodite import AsyncAphrodite | ||
|
||
|
||
@pytest.fixture(scope="function") | ||
def tmp_socket(): | ||
with tempfile.TemporaryDirectory() as td: | ||
yield f"ipc://{td}/{uuid.uuid4()}" | ||
|
||
|
||
@pytest_asyncio.fixture(scope="function") | ||
async def dummy_server(tmp_socket, monkeypatch): | ||
dummy_engine = unittest.mock.AsyncMock() | ||
|
||
def dummy_engine_builder(*args, **kwargs): | ||
return dummy_engine | ||
|
||
with monkeypatch.context() as m: | ||
m.setattr(AsyncAphrodite, "from_engine_args", dummy_engine_builder) | ||
server = AsyncEngineRPCServer(None, rpc_path=tmp_socket) | ||
loop = asyncio.get_running_loop() | ||
server_task = loop.create_task(server.run_server_loop()) | ||
try: | ||
yield server | ||
finally: | ||
server_task.cancel() | ||
server.cleanup() | ||
|
||
|
||
@pytest_asyncio.fixture(scope="function") | ||
async def client(tmp_socket): | ||
client = AsyncEngineRPCClient(rpc_path=tmp_socket) | ||
# Sanity check: the server is connected | ||
await client._wait_for_server_rpc() | ||
try: | ||
yield client | ||
finally: | ||
client.close() | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_data_methods_use_timeouts( | ||
monkeypatch, dummy_server, client: AsyncEngineRPCClient | ||
): | ||
with monkeypatch.context() as m: | ||
# Make the server _not_ reply with a model config | ||
m.setattr(dummy_server, "get_config", lambda x: None) | ||
m.setattr(client, "_data_timeout", 10) | ||
# And ensure the task completes anyway | ||
# (client.setup() invokes server.get_config()) | ||
client_task = asyncio.get_running_loop().create_task(client.setup()) | ||
with pytest.raises(TimeoutError, match="Server didn't reply within"): | ||
await asyncio.wait_for(client_task, timeout=0.05) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_aborts_use_timeouts( | ||
monkeypatch, dummy_server, client: AsyncEngineRPCClient | ||
): | ||
with monkeypatch.context() as m: | ||
# Hang all abort requests | ||
m.setattr(dummy_server, "abort", lambda x: None) | ||
m.setattr(client, "_data_timeout", 10) | ||
# Ensure the client doesn't hang | ||
client_task = asyncio.get_running_loop().create_task( | ||
client.abort("test request id") | ||
) | ||
with pytest.raises(TimeoutError, match="Server didn't reply within"): | ||
await asyncio.wait_for(client_task, timeout=0.05) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_data_methods_reraise_exceptions( | ||
monkeypatch, dummy_server, client: AsyncEngineRPCClient | ||
): | ||
with monkeypatch.context() as m: | ||
# Make the server raise some random exception | ||
exception = RuntimeError("Client test exception") | ||
|
||
def raiser(): | ||
raise exception | ||
|
||
m.setattr(dummy_server.engine, "get_model_config", raiser) | ||
m.setattr(client, "_data_timeout", 10) | ||
client_task = asyncio.get_running_loop().create_task(client.setup()) | ||
# And ensure the task completes, raising the exception | ||
with pytest.raises(RuntimeError, match=str(exception)): | ||
await asyncio.wait_for(client_task, timeout=0.05) | ||
|
||
|
||
@pytest.mark.asyncio | ||
async def test_client_errors_after_closing( | ||
monkeypatch, dummy_server, client: AsyncEngineRPCClient | ||
): | ||
client.close() | ||
# Healthchecks and generate requests will fail with explicit errors | ||
with pytest.raises(RPCClientClosedError): | ||
await client.check_health() | ||
with pytest.raises(RPCClientClosedError): | ||
async for _ in client.generate(None, None, None): | ||
pass | ||
# But no-ops like aborting will pass | ||
await client.abort("test-request-id") | ||
await client.do_log_stats() |