diff --git a/vllm/spec_decode/spec_decode_worker.py b/vllm/spec_decode/spec_decode_worker.py index a402181b13db8..eb3c2e88e668c 100644 --- a/vllm/spec_decode/spec_decode_worker.py +++ b/vllm/spec_decode/spec_decode_worker.py @@ -61,6 +61,10 @@ def create_spec_worker(*args, **kwargs) -> "SpecDecodeWorker": draft_worker_config = copy.deepcopy(vllm_config) draft_worker_config.model_config = speculative_config.draft_model_config + draft_worker_config.quant_config = VllmConfig._get_quantization_config( + draft_worker_config.model_config, + vllm_config.load_config, + ) draft_worker_config.parallel_config = speculative_config.draft_parallel_config # noqa # TODO allow draft-model specific load config.