Skip to content

Commit

Permalink
[bugfix] fix cpu tests (#10585)
Browse files Browse the repository at this point in the history
Signed-off-by: youkaichao <[email protected]>
  • Loading branch information
youkaichao authored Nov 23, 2024
1 parent d345f40 commit d559979
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 10 deletions.
4 changes: 3 additions & 1 deletion vllm/worker/cpu_embedding_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import torch

+ from vllm.forward_context import set_forward_context
from vllm.model_executor.pooling_metadata import PoolingMetadata
from vllm.multimodal import MultiModalKwargs
from vllm.pooling_params import PoolingParams
Expand Down Expand Up @@ -64,7 +65,8 @@ def execute_model(
intermediate_tensors,
}

- hidden_states = model_executable(**execute_model_kwargs)
+ with set_forward_context(model_input.attn_metadata, self.vllm_config):
+     hidden_states = model_executable(**execute_model_kwargs)

# Only perform pooling in the driver worker.
if not self.is_driver_worker:
Expand Down
4 changes: 3 additions & 1 deletion vllm/worker/cpu_enc_dec_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import torch

from vllm.attention import AttentionMetadata
+ from vllm.forward_context import set_forward_context
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.sampler import SamplerOutput
from vllm.multimodal import MultiModalKwargs
Expand Down Expand Up @@ -303,7 +304,8 @@ def execute_model(
intermediate_tensors,
}

- hidden_states = model_executable(**execute_model_kwargs)
+ with set_forward_context(model_input.attn_metadata, self.vllm_config):
+     hidden_states = model_executable(**execute_model_kwargs)

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
Expand Down
18 changes: 10 additions & 8 deletions vllm/worker/cpu_model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@

from vllm.attention import AttentionMetadata, get_attn_backend
from vllm.config import VllmConfig
+ from vllm.forward_context import set_forward_context
from vllm.logger import init_logger
from vllm.model_executor import SamplingMetadata
from vllm.model_executor.layers.rotary_embedding import MRotaryEmbedding
Expand Down Expand Up @@ -487,14 +488,15 @@ def execute_model(
multimodal_kwargs = MultiModalKwargs.as_kwargs(
model_input.multi_modal_kwargs, device=self.device)

- hidden_states = model_executable(
-     input_ids=model_input.input_tokens,
-     positions=model_input.input_positions,
-     kv_caches=kv_caches,
-     attn_metadata=model_input.attn_metadata,
-     intermediate_tensors=intermediate_tensors,
-     **multimodal_kwargs,
- )
+ with set_forward_context(model_input.attn_metadata, self.vllm_config):
+     hidden_states = model_executable(
+         input_ids=model_input.input_tokens,
+         positions=model_input.input_positions,
+         kv_caches=kv_caches,
+         attn_metadata=model_input.attn_metadata,
+         intermediate_tensors=intermediate_tensors,
+         **multimodal_kwargs,
+     )

# Compute the logits.
logits = self.model.compute_logits(hidden_states,
Expand Down

0 comments on commit d559979

Please sign in to comment.