Commit 017dc48

joerunde authored and jeejeelee committed
[Bugfix] Validate lora adapters to avoid crashing server (vllm-project#11727)
Signed-off-by: Joe Runde <[email protected]> Co-authored-by: Jee Jee Li <[email protected]>
1 parent 1c953f6 commit 017dc48

15 files changed: +459, -171 lines
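As context for the tests added below: with runtime LoRA updating enabled, a client loads an adapter by POSTing to /v1/load_lora_adapter, and with this change a broken adapter is rejected with a 4xx error instead of crashing the server. A minimal client-side sketch (the base URL, API key, and adapter path are placeholders; the server is assumed to be started with --enable-lora and VLLM_ALLOW_RUNTIME_LORA_UPDATING=True, as the test fixture below does):

import asyncio

import openai


async def main():
    # Placeholder connection details; point these at your own server
    client = openai.AsyncOpenAI(base_url="http://localhost:8000/v1",
                                api_key="EMPTY")
    try:
        # POST /v1/load_lora_adapter with the adapter's name and local path
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "my-adapter",
                              "lora_path": "/path/to/adapter"
                          })
    except openai.NotFoundError:
        print("adapter path was not found")
    except openai.BadRequestError:
        print("adapter failed validation (bad config, rank too large, ...)")


asyncio.run(main())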
@@ -0,0 +1,269 @@
import asyncio
import json
import shutil
from contextlib import suppress

import openai  # use the official client for correctness check
import pytest
import pytest_asyncio
# downloading lora to test lora requests
from huggingface_hub import snapshot_download

from ...utils import RemoteOpenAIServer

# any model with a chat template should work here
MODEL_NAME = "HuggingFaceH4/zephyr-7b-beta"
# technically this needs Mistral-7B-v0.1 as base, but we're not testing
# generation quality here
LORA_NAME = "typeof/zephyr-7b-beta-lora"


@pytest.fixture(scope="module")
def zephyr_lora_files():
    return snapshot_download(repo_id=LORA_NAME)


@pytest.fixture(scope="module")
def server_with_lora_modules_json(zephyr_lora_files):
    # Define the json format LoRA module configurations
    lora_module_1 = {
        "name": "zephyr-lora",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    lora_module_2 = {
        "name": "zephyr-lora2",
        "path": zephyr_lora_files,
        "base_model_name": MODEL_NAME
    }

    args = [
        # use half precision for speed and memory savings in CI environment
        "--dtype",
        "bfloat16",
        "--max-model-len",
        "8192",
        "--enforce-eager",
        # lora config below
        "--enable-lora",
        "--lora-modules",
        json.dumps(lora_module_1),
        json.dumps(lora_module_2),
        "--max-lora-rank",
        "64",
        "--max-cpu-loras",
        "2",
        "--max-num-seqs",
        "64",
    ]

    # Enable the /v1/load_lora_adapter endpoint
    envs = {"VLLM_ALLOW_RUNTIME_LORA_UPDATING": "True"}

    with RemoteOpenAIServer(MODEL_NAME, args, env_dict=envs) as remote_server:
        yield remote_server


@pytest_asyncio.fixture
async def client(server_with_lora_modules_json):
    async with server_with_lora_modules_json.get_async_client(
    ) as async_client:
        yield async_client


@pytest.mark.asyncio
async def test_static_lora_lineage(client: openai.AsyncOpenAI,
                                   zephyr_lora_files):
    models = await client.models.list()
    models = models.data
    served_model = models[0]
    lora_models = models[1:]
    assert served_model.id == MODEL_NAME
    assert served_model.root == MODEL_NAME
    assert served_model.parent is None
    assert all(lora_model.root == zephyr_lora_files
               for lora_model in lora_models)
    assert all(lora_model.parent == MODEL_NAME for lora_model in lora_models)
    assert lora_models[0].id == "zephyr-lora"
    assert lora_models[1].id == "zephyr-lora2"


@pytest.mark.asyncio
async def test_dynamic_lora_lineage(client: openai.AsyncOpenAI,
                                    zephyr_lora_files):

    response = await client.post("load_lora_adapter",
                                 cast_to=str,
                                 body={
                                     "lora_name": "zephyr-lora-3",
                                     "lora_path": zephyr_lora_files
                                 })
    # Ensure adapter loads before querying /models
    assert "success" in response

    models = await client.models.list()
    models = models.data
    dynamic_lora_model = models[-1]
    assert dynamic_lora_model.root == zephyr_lora_files
    assert dynamic_lora_model.parent == MODEL_NAME
    assert dynamic_lora_model.id == "zephyr-lora-3"


@pytest.mark.asyncio
async def test_dynamic_lora_not_found(client: openai.AsyncOpenAI):
    with pytest.raises(openai.NotFoundError):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "notfound",
                              "lora_path": "/not/an/adapter"
                          })


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_files(client: openai.AsyncOpenAI,
                                          tmp_path):
    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    with pytest.raises(openai.BadRequestError):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "invalid-json",
                              "lora_path": str(invalid_files)
                          })


@pytest.mark.asyncio
async def test_dynamic_lora_invalid_lora_rank(client: openai.AsyncOpenAI,
                                              tmp_path, zephyr_lora_files):
    invalid_rank = tmp_path / "invalid_rank"

    # Copy adapter from zephyr_lora_files to invalid_rank
    shutil.copytree(zephyr_lora_files, invalid_rank)

    with open(invalid_rank / "adapter_config.json") as f:
        adapter_config = json.load(f)

    # Change rank to an invalid value that exceeds --max-lora-rank
    adapter_config["r"] = 1024
    with open(invalid_rank / "adapter_config.json", "w") as f:
        json.dump(adapter_config, f)

    with pytest.raises(openai.BadRequestError,
                       match="is greater than max_lora_rank"):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": "invalid-rank",
                              "lora_path": str(invalid_rank)
                          })


@pytest.mark.asyncio
async def test_multiple_lora_adapters(client: openai.AsyncOpenAI, tmp_path,
                                      zephyr_lora_files):
    """Validate that many LoRA adapters can be dynamically registered and
    used for inference concurrently."""

    # This test file configures the server with --max-cpu-loras=2 and this test
    # will concurrently load 10 adapters, so it should flex the LRU cache
    async def load_and_run_adapter(adapter_name: str):
        await client.post("load_lora_adapter",
                          cast_to=str,
                          body={
                              "lora_name": adapter_name,
                              "lora_path": str(zephyr_lora_files)
                          })
        for _ in range(3):
            await client.completions.create(
                model=adapter_name,
                prompt=["Hello there", "Foo bar bazz buzz"],
                max_tokens=5,
            )

    lora_tasks = []
    for i in range(10):
        lora_tasks.append(
            asyncio.create_task(load_and_run_adapter(f"adapter_{i}")))

    # gather with return_exceptions=True so a failed task shows up as an
    # exception in the results instead of being silently swallowed
    results = await asyncio.gather(*lora_tasks, return_exceptions=True)

    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"


@pytest.mark.asyncio
async def test_loading_invalid_adapters_does_not_break_others(
        client: openai.AsyncOpenAI, tmp_path, zephyr_lora_files):

    invalid_files = tmp_path / "invalid_files"
    invalid_files.mkdir()
    (invalid_files / "adapter_config.json").write_text("this is not json")

    stop_good_requests_event = asyncio.Event()

    async def run_good_requests(client):
        # Run completion requests against a valid adapter until the event is set

        results = []

        while not stop_good_requests_event.is_set():
            try:
                batch = await client.completions.create(
                    model="zephyr-lora",
                    prompt=["Hello there", "Foo bar bazz buzz"],
                    max_tokens=5,
                )
                results.append(batch)
            except Exception as e:
                results.append(e)

        return results

    # Create task to run good requests
    good_task = asyncio.create_task(run_good_requests(client))

    # Run a bunch of bad adapter loads
    for _ in range(25):
        with suppress(openai.NotFoundError):
            await client.post("load_lora_adapter",
                              cast_to=str,
                              body={
                                  "lora_name": "notfound",
                                  "lora_path": "/not/an/adapter"
                              })
    for _ in range(25):
        with suppress(openai.BadRequestError):
            await client.post("load_lora_adapter",
                              cast_to=str,
                              body={
                                  "lora_name": "invalid",
                                  "lora_path": str(invalid_files)
                              })

    # Ensure all the running requests with lora adapters succeeded
    stop_good_requests_event.set()
    results = await good_task
    for r in results:
        assert not isinstance(r, Exception), f"Got exception {r}"

    # Ensure we can load another adapter and run it
    await client.post("load_lora_adapter",
                      cast_to=str,
                      body={
                          "lora_name": "valid",
                          "lora_path": zephyr_lora_files
                      })
    await client.completions.create(
        model="valid",
        prompt=["Hello there", "Foo bar bazz buzz"],
        max_tokens=5,
    )
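The new tests expect three distinct failure modes from /v1/load_lora_adapter: a missing adapter path (404, openai.NotFoundError), an unparseable adapter_config.json (400, openai.BadRequestError), and a LoRA rank above --max-lora-rank (400, with a message matching "is greater than max_lora_rank"). A conceptual sketch of that kind of up-front validation, using hypothetical helper and exception names (this is not the server code changed elsewhere in this commit):

import json
import os


class AdapterValidationError(ValueError):
    """Hypothetical error type: reject the adapter with a 400 instead of loading it."""


def validate_lora_adapter(lora_path: str, max_lora_rank: int) -> None:
    # Hypothetical helper illustrating the checks the tests above exercise
    config_path = os.path.join(lora_path, "adapter_config.json")
    if not os.path.isfile(config_path):
        # Missing adapter -> surfaces to the client as a 404
        raise FileNotFoundError(f"No adapter config found at {config_path}")

    try:
        with open(config_path) as f:
            adapter_config = json.load(f)
    except json.JSONDecodeError as e:
        # Unparseable config -> surfaces to the client as a 400
        raise AdapterValidationError(f"Invalid adapter_config.json: {e}") from e

    rank = adapter_config.get("r", 0)
    if rank > max_lora_rank:
        # Oversized rank -> surfaces to the client as a 400
        raise AdapterValidationError(
            f"LoRA rank {rank} is greater than max_lora_rank {max_lora_rank}")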

tests/entrypoints/openai/test_lora_lineage.py

-109 lines. This file was deleted.
