
Commit

fix test
Signed-off-by: jiang1.li <[email protected]>
bigPYJ1151 committed Nov 15, 2024
1 parent ba2575b commit 5980981
Showing 2 changed files with 8 additions and 2 deletions.
8 changes: 8 additions & 0 deletions vllm/executor/cpu_executor.py
@@ -2,6 +2,8 @@
 from functools import partial
 from typing import Any, Awaitable, List, Optional, Set, Tuple, Union
 
+import torch
+
 import vllm.envs as envs
 from vllm.config import (CacheConfig, ModelConfig, ParallelConfig,
                          SchedulerConfig)
@@ -64,6 +66,12 @@ def _init_executor(self) -> None:
         self.parallel_config = _verify_and_get_parallel_config(
             self.parallel_config)
 
+        if (self.scheduler_config.chunked_prefill_enabled
+                and self.model_config.dtype == torch.half):
+            logger.warning("Chunked-prefill on the CPU backend does not"
+                           " support fp16 for now; casting to bf16.")
+            self.model_config.dtype = torch.bfloat16
+
         # Multiprocessing-based executor does not support multi-node setting.
         # Since it only works for single node, we can use the loopback address
         # 127.0.0.1 for communication.
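
For context (not part of the commit): the new guard is a warn-and-coerce pattern for an unsupported configuration. Below is a self-contained sketch of the same logic using a stub config object; StubModelConfig and coerce_unsupported_dtype are illustrative names, not vLLM APIs.

import logging
from dataclasses import dataclass

import torch

logger = logging.getLogger(__name__)


@dataclass
class StubModelConfig:
    # Stand-in for vllm.config.ModelConfig; only the dtype field matters here.
    dtype: torch.dtype


def coerce_unsupported_dtype(model_config: StubModelConfig,
                             chunked_prefill_enabled: bool) -> None:
    # Mirrors the guard added in cpu_executor.py: fp16 is not supported
    # together with chunked prefill on the CPU backend, so fall back to bf16.
    if chunked_prefill_enabled and model_config.dtype == torch.half:
        logger.warning("Chunked-prefill on the CPU backend does not"
                       " support fp16 for now; casting to bf16.")
        model_config.dtype = torch.bfloat16


cfg = StubModelConfig(dtype=torch.half)
coerce_unsupported_dtype(cfg, chunked_prefill_enabled=True)
assert cfg.dtype == torch.bfloat16  # the fp16 request was coerced to bf16
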
2 changes: 0 additions & 2 deletions vllm/worker/cpu_model_runner.py
@@ -429,8 +429,6 @@ def __init__(
         **kwargs,
     ):
         ModelRunnerBase.__init__(self, vllm_config)
-        # Currently, CPU worker doesn't support chunked prefill.
-        assert self.scheduler_config.chunked_prefill_enabled is False
         model_config = self.model_config
         cache_config = self.cache_config
 
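
With the assertion gone, chunked prefill can be enabled on the CPU backend end to end, and an fp16 request is downgraded to bf16 by the executor guard above instead of raising. A minimal usage sketch, assuming a CPU build of vLLM; the model name is only an example:

from vllm import LLM, SamplingParams

# dtype="half" together with chunked prefill previously failed the
# assertion in cpu_model_runner.py; it now warns and runs in bfloat16.
llm = LLM(model="facebook/opt-125m",    # example model
          dtype="half",                 # fp16 requested; cast to bf16 on CPU
          enable_chunked_prefill=True)

outputs = llm.generate(["Hello, my name is"],
                       SamplingParams(max_tokens=16))
print(outputs[0].outputs[0].text)
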
