From 295c4730a85ce419e5b46e256240d69ad1cce619 Mon Sep 17 00:00:00 2001 From: Kevin Lin <42618777+kevin314@users.noreply.github.com> Date: Thu, 12 Sep 2024 00:45:24 -0500 Subject: [PATCH] [Misc] Raise error when using encoder/decoder model with cpu backend (#8355) --- vllm/utils.py | 4 ++++ vllm/worker/cpu_model_runner.py | 6 +++++- 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/vllm/utils.py b/vllm/utils.py index a22081ebe8df0..aba243071b69a 100644 --- a/vllm/utils.py +++ b/vllm/utils.py @@ -82,6 +82,9 @@ "currently supported with encoder/" "decoder models.") +STR_NOT_IMPL_ENC_DEC_CPU = ("CPU is not currently supported with " + "encoder/decoder models.") + # Efficiently import all enc/dec error strings # rather than having to import all of the above STR_NOT_IMPL_ENC_DEC_ERR_STRS = { @@ -97,6 +100,7 @@ "STR_NOT_IMPL_ENC_DEC_CUDA_GRAPH": STR_NOT_IMPL_ENC_DEC_CUDAGRAPH, "STR_NOT_IMPL_ENC_DEC_BACKEND": STR_NOT_IMPL_ENC_DEC_BACKEND, "STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER": STR_NOT_IMPL_ENC_DEC_PROMPT_ADAPTER, + "STR_NOT_IMPL_ENC_DEC_CPU": STR_NOT_IMPL_ENC_DEC_CPU } # Constants related to forcing the attention backend selection diff --git a/vllm/worker/cpu_model_runner.py b/vllm/worker/cpu_model_runner.py index 7205b1a7beb8d..7b2caf4973589 100644 --- a/vllm/worker/cpu_model_runner.py +++ b/vllm/worker/cpu_model_runner.py @@ -15,7 +15,7 @@ from vllm.multimodal import (MULTIMODAL_REGISTRY, BatchedTensorInputs, MultiModalInputs) from vllm.sequence import IntermediateTensors, SequenceGroupMetadata -from vllm.utils import make_tensor_with_pad +from vllm.utils import STR_NOT_IMPL_ENC_DEC_ERR_STRS, make_tensor_with_pad from vllm.worker.model_runner_base import ( ModelRunnerBase, ModelRunnerInputBase, _add_attn_metadata_broadcastable_dict, @@ -121,6 +121,10 @@ def __init__( # Lazy initialization. self.model: nn.Module # Set after init_Model + if self.model_config.is_encoder_decoder_model: + raise NotImplementedError( + STR_NOT_IMPL_ENC_DEC_ERR_STRS['STR_NOT_IMPL_ENC_DEC_CPU']) + def load_model(self) -> None: self.model = get_model(model_config=self.model_config, load_config=self.load_config,