From 975e04043cdf3e4dcf10328827e0d62ab64c8b6c Mon Sep 17 00:00:00 2001 From: mgoin Date: Tue, 3 Dec 2024 02:33:25 +0000 Subject: [PATCH] Move copy down into guided decoding case Signed-off-by: mgoin --- vllm/engine/async_llm_engine.py | 8 +++++++- vllm/engine/llm_engine.py | 7 ++++++- vllm/entrypoints/llm.py | 4 +--- 3 files changed, 14 insertions(+), 5 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index e991f3efed5e1..c845d1b89a8e3 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,4 +1,5 @@ import asyncio +import copy import time import weakref from functools import partial @@ -533,9 +534,14 @@ async def build_guided_decoding_logits_processor_async( those fields and adds the constructed logits processors to the logits_processors field. Modifies sampling params in-place and returns the modified sampling params.""" - if (guided_decoding := sampling_params.guided_decoding) is None: + if sampling_params.guided_decoding is None: return sampling_params + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + logger.debug("Building guided decoding logits processor. " "Params: %s", guided_decoding) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 21afec3c006ab..4613372cfa039 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import copy import time from collections import Counter as collectionsCounter from collections import deque @@ -2035,7 +2036,11 @@ def _build_logits_processors( logits_processors = [] - if (guided_decoding := sampling_params.guided_decoding) is not None: + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding logger.debug( "Building guided decoding logits processor in " diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index 64dee5f0fa199..a25c401b4ea10 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -1,4 +1,3 @@ -import copy import itertools import json import warnings @@ -1038,8 +1037,7 @@ def _validate_and_add_requests( for i, prompt in enumerate(prompts): self._add_request( prompt, - params[i] - if isinstance(params, Sequence) else copy.copy(params), + params[i] if isinstance(params, Sequence) else params, lora_request=lora_request[i] if isinstance( lora_request, Sequence) else lora_request, prompt_adapter_request=prompt_adapter_request,