Skip to content

Commit

Permalink
Move copy down into guided decoding case
Browse files Browse the repository at this point in the history
Signed-off-by: mgoin <[email protected]>
  • Loading branch information
mgoin committed Dec 3, 2024
1 parent 9f97093 commit 975e040
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 5 deletions.
8 changes: 7 additions & 1 deletion vllm/engine/async_llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import asyncio
import copy
import time
import weakref
from functools import partial
Expand Down Expand Up @@ -533,9 +534,14 @@ async def build_guided_decoding_logits_processor_async(
those fields and adds the constructed logits processors to the
logits_processors field. Modifies sampling params in-place and returns
the modified sampling params."""
if (guided_decoding := sampling_params.guided_decoding) is None:
if sampling_params.guided_decoding is None:
return sampling_params

# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug("Building guided decoding logits processor. "
"Params: %s", guided_decoding)

Expand Down
7 changes: 6 additions & 1 deletion vllm/engine/llm_engine.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import copy
import time
from collections import Counter as collectionsCounter
from collections import deque
Expand Down Expand Up @@ -2035,7 +2036,11 @@ def _build_logits_processors(

logits_processors = []

if (guided_decoding := sampling_params.guided_decoding) is not None:
if sampling_params.guided_decoding is not None:
# Defensively copy sampling params since guided decoding logits
# processors can have different state for each request
sampling_params = copy.copy(sampling_params)
guided_decoding = sampling_params.guided_decoding

logger.debug(
"Building guided decoding logits processor in "
Expand Down
4 changes: 1 addition & 3 deletions vllm/entrypoints/llm.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
import copy
import itertools
import json
import warnings
Expand Down Expand Up @@ -1038,8 +1037,7 @@ def _validate_and_add_requests(
for i, prompt in enumerate(prompts):
self._add_request(
prompt,
params[i]
if isinstance(params, Sequence) else copy.copy(params),
params[i] if isinstance(params, Sequence) else params,
lora_request=lora_request[i] if isinstance(
lora_request, Sequence) else lora_request,
prompt_adapter_request=prompt_adapter_request,
Expand Down

0 comments on commit 975e040

Please sign in to comment.