Skip to content

Commit

Permalink
[Bugfix] Fix pickle of input when async output processing is on (vllm-project#9931)
Browse files Browse the repository at this point in the history

Signed-off-by: Wallas Santos <[email protected]>
Signed-off-by: Tyler Michael Smith <[email protected]>
  • Loading branch information
wallashss authored and tlrmchlsmth committed Nov 23, 2024
1 parent 762faba commit 6f8d56b
Show file tree
Hide file tree
Showing 2 changed files with 38 additions and 0 deletions.
26 changes: 26 additions & 0 deletions tests/basic_correctness/test_basic_correctness.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,29 @@ def test_model_with_failure(vllm_runner) -> None:
ModelInputForGPUWithSamplingMetadata)
finally:
os.remove(filename)


def test_failure_with_async_out_proc(vllm_runner) -> None:
    """Check that a failing forward pass still dumps its input to a pickle.

    The OPT model's ``forward`` is patched to raise ``ValueError`` while the
    engine reports ``use_async_output_proc``. The resulting error message
    must name the ``.pkl`` file the input was dumped to; that file is removed
    afterwards so the test leaves nothing behind.
    """
    # Stays None unless a dump path was actually parsed from the error,
    # so the cleanup below never tries to remove a file that was not found.
    filename = None
    try:
        with vllm_runner("facebook/opt-125m",
                         dtype="half",
                         enforce_eager=False,
                         gpu_memory_utilization=0.7) as vllm_model,\
             patch("vllm.model_executor.models.opt.OPTForCausalLM.forward",
                   side_effect=ValueError()):
            model_config = vllm_model.model.llm_engine.model_config
            # The scenario under test only makes sense with async output
            # processing enabled.
            assert model_config.use_async_output_proc
            with pytest.raises(ValueError) as exc_info:
                vllm_model.generate_greedy('how to make pizza?', 250)
            # Escape the dot so only a literal ".pkl" suffix is accepted.
            matches = re.search(r"input dumped to (.+)\.pkl",
                                str(exc_info.value))
            assert matches is not None

            filename = f"{matches.group(1)}.pkl"
    finally:
        # Clean up the dumped input file, if one was produced.
        if filename is not None:
            os.remove(filename)
12 changes: 12 additions & 0 deletions vllm/worker/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,6 +136,18 @@ def from_broadcasted_tensor_dict(
attn_backend, tensor_dict)
return cls(**tensor_dict)

# Exclude `async_callback` to be able to pickle this object
def __getstate__(self):
    """Return the picklable state: a copy of ``__dict__`` minus the callback.

    The instance itself is left untouched; only the returned copy lacks
    the ``async_callback`` entry.
    """
    state = self.__dict__.copy()
    # Tolerate a missing key: this runs while dumping inputs after a
    # failure, and a KeyError here would mask the original error.
    state.pop("async_callback", None)
    return state

# TODO: What happens when we unpickle this object?
# How can we update this callback to properly pass it to the engine?
def __setstate__(self, state):
    """Restore attributes from ``state`` and reset ``async_callback`` to None.

    Attributes are written straight into ``__dict__`` rather than through
    attribute assignment. Any ``async_callback`` value present in ``state``
    is overwritten with ``None``.
    """
    attrs = self.__dict__
    attrs.update(state)
    attrs["async_callback"] = None


@dataclass(frozen=True)
class ModelInputForGPUWithSamplingMetadata(ModelInputForGPU):
Expand Down

0 comments on commit 6f8d56b

Please sign in to comment.