Skip to content

Commit

Permalink
⚡ guard concurrent adapter loads
Browse files Browse the repository at this point in the history
Signed-off-by: Joe Runde <[email protected]>
  • Loading branch information
joerunde committed Oct 11, 2024
1 parent 37b4444 commit dc550a1
Show file tree
Hide file tree
Showing 2 changed files with 56 additions and 25 deletions.
55 changes: 30 additions & 25 deletions src/vllm_tgis_adapter/grpc/adapters.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ class AdapterStore:
cache_path: str # Path to local store of adapters to load from
adapters: dict[str, AdapterMetadata]
next_unique_id: int = 1
load_locks: dict[str, asyncio.Lock] = dataclasses.field(default_factory=dict)


async def validate_adapters(
Expand Down Expand Up @@ -78,31 +79,35 @@ async def validate_adapters(
if not adapter_id or not adapter_store:
return {}

# If not already cached, we need to validate that files exist and
# grab the type out of the adapter_config.json file
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
_reject_bad_adapter_id(adapter_id)
local_adapter_path = str(Path(adapter_store.cache_path) / adapter_id)

loop = asyncio.get_running_loop()
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(max_workers=2)

# Increment the unique adapter id counter here in async land where we don't
# need to deal with thread-safety
unique_id = adapter_store.next_unique_id
adapter_store.next_unique_id += 1

adapter_metadata = await loop.run_in_executor(
global_thread_pool,
_load_adapter_metadata,
adapter_id,
local_adapter_path,
unique_id,
)

# Add to cache
adapter_store.adapters[adapter_id] = adapter_metadata
# Guard against concurrent access for the same adapter
async with adapter_store.load_locks.setdefault(adapter_id, asyncio.Lock()):
# If not already cached, we need to validate that files exist and
# grab the type out of the adapter_config.json file
if (adapter_metadata := adapter_store.adapters.get(adapter_id)) is None:
_reject_bad_adapter_id(adapter_id)
local_adapter_path = str(Path(adapter_store.cache_path) / adapter_id)

loop = asyncio.get_running_loop()
if global_thread_pool is None:
global_thread_pool = concurrent.futures.ThreadPoolExecutor(
max_workers=2
)

# Increment the unique adapter id counter here in async land where we don't
# need to deal with thread-safety
unique_id = adapter_store.next_unique_id
adapter_store.next_unique_id += 1

adapter_metadata = await loop.run_in_executor(
global_thread_pool,
_load_adapter_metadata,
adapter_id,
local_adapter_path,
unique_id,
)

# Add to cache
adapter_store.adapters[adapter_id] = adapter_metadata

# Build the proper vllm request object
if adapter_metadata.adapter_type == "LORA":
Expand Down
26 changes: 26 additions & 0 deletions tests/test_adapters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import asyncio
from pathlib import Path

import pytest
Expand Down Expand Up @@ -112,3 +113,28 @@ async def test_store_handles_multiple_adapters():
adapters_1["lora_request"].lora_int_id
< adapters_2["prompt_adapter_request"].prompt_adapter_id
)


@pytest.mark.asyncio
async def test_cache_handles_concurrent_loads():
    """Concurrent requests for the same new adapter must not hammer the
    filesystem: only one load should occur and one unique ID be issued."""
    store = AdapterStore(cache_path=FIXTURES_DIR, adapters={})
    # A caikit-style adapter that requires conversion exercises the worst case.
    request = BatchedGenerationRequest(
        adapter_id="bloom_sentiment_1",
    )

    # Launch many simultaneous validations of the same uncached adapter.
    pending = [
        asyncio.create_task(validate_adapters(request, adapter_store=store))
        for _ in range(1000)
    ]

    # Wait for every validation to finish.
    await asyncio.gather(*pending)

    # Exactly one load happened, so exactly one unique ID was handed out
    # (counter starts at 1 and is incremented once, to 2).
    assert store.next_unique_id == 2

0 comments on commit dc550a1

Please sign in to comment.