@@ -15,14 +15,13 @@ class GPUExecutor(ExecutorBase):
 
     def _init_executor(self) -> None:
         """Initialize the worker and load the model.
-
-        If speculative decoding is enabled, we instead create the speculative
-        worker.
         """
-        if self.speculative_config is None:
-            self._init_non_spec_worker()
-        else:
-            self._init_spec_worker()
+        assert self.parallel_config.world_size == 1, (
+            "GPUExecutor only supports single GPU.")
+
+        self.driver_worker = self._create_worker()
+        self.driver_worker.init_device()
+        self.driver_worker.load_model()
 
     def _get_worker_kwargs(
             self,
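The hunk above folds the former `_init_non_spec_worker` / `_init_spec_worker` split into a single path: assert single GPU, create the worker, initialize the device, load the model. As a rough illustration of why one path suffices once worker construction itself is polymorphic, here is a toy sketch; the class names below are illustrative stand-ins, not vLLM's.

```python
# Toy sketch only: both stand-in workers expose the same
# init_device()/load_model() surface, so the executor needs
# just one initialization path.
class PlainWorker:
    def init_device(self):
        print("plain worker: device ready")

    def load_model(self):
        print("plain worker: model loaded")


class SpecWorker(PlainWorker):
    def load_model(self):
        print("spec worker: target and draft models loaded")


class ToyExecutor:
    def __init__(self, speculative_config=None):
        self.speculative_config = speculative_config

    def _create_worker(self):
        # Mirrors the dispatch added in the next hunk, minus the
        # module/class-name indirection.
        if self.speculative_config is None:
            return PlainWorker()
        return SpecWorker()

    def _init_executor(self):
        self.driver_worker = self._create_worker()
        self.driver_worker.init_device()
        self.driver_worker.load_model()


ToyExecutor()._init_executor()                     # plain path
ToyExecutor(speculative_config={})._init_executor()  # speculative path
```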
@@ -45,66 +44,30 @@ def _get_worker_kwargs(
             distributed_init_method=distributed_init_method,
             lora_config=self.lora_config,
             vision_language_config=self.vision_language_config,
+            speculative_config=self.speculative_config,
             is_driver_worker=rank == 0,
         )
 
     def _create_worker(self,
                        local_rank: int = 0,
                        rank: int = 0,
                        distributed_init_method: Optional[str] = None):
+
+        if self.speculative_config is None:
+            worker_module_name = "vllm.worker.worker"
+            worker_class_name = "Worker"
+        else:
+            worker_module_name = "vllm.spec_decode.spec_decode_worker"
+            worker_class_name = "create_spec_worker"
+
         wrapper = WorkerWrapperBase(
-            worker_module_name="vllm.worker.worker",
-            worker_class_name="Worker",
+            worker_module_name=worker_module_name,
+            worker_class_name=worker_class_name,
         )
         wrapper.init_worker(**self._get_worker_kwargs(local_rank, rank,
                                                       distributed_init_method))
         return wrapper.worker
 
-    def _init_non_spec_worker(self):
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        self.driver_worker = self._create_worker()
-        self.driver_worker.init_device()
-        self.driver_worker.load_model()
-
-    def _init_spec_worker(self):
-        """Initialize a SpecDecodeWorker, using a draft model for proposals.
-        """
-        assert self.speculative_config is not None
-
-        from vllm.spec_decode.spec_decode_worker import SpecDecodeWorker
-
-        target_worker = self._create_worker()
-
-        draft_worker_kwargs = self._get_worker_kwargs()
-        # Override draft-model specific worker args.
-        draft_worker_kwargs.update(
-            model_config=self.speculative_config.draft_model_config,
-            parallel_config=self.speculative_config.draft_parallel_config,
-            ngram_prompt_lookup_max=self.speculative_config.
-            ngram_prompt_lookup_max,
-            ngram_prompt_lookup_min=self.speculative_config.
-            ngram_prompt_lookup_min,
-            # TODO allow draft-model specific load config.
-            #load_config=self.load_config,
-        )
-
-        spec_decode_worker = SpecDecodeWorker.create_worker(
-            scorer_worker=target_worker,
-            draft_worker_kwargs=draft_worker_kwargs,
-            disable_by_batch_size=self.speculative_config.
-            speculative_disable_by_batch_size,
-        )
-
-        assert self.parallel_config.world_size == 1, (
-            "GPUExecutor only supports single GPU.")
-
-        self.driver_worker = spec_decode_worker
-
-        # Load model handled in spec decode worker.
-        self.driver_worker.init_device()
-
     def determine_num_available_blocks(self) -> Tuple[int, int]:
         """Determine the number of available KV blocks by invoking the
         underlying worker.
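With this change, `_create_worker` picks a worker by module path and attribute name and lets `WorkerWrapperBase` resolve it: `vllm.worker.worker.Worker` for the plain path, or the `create_spec_worker` factory from `vllm.spec_decode.spec_decode_worker` when speculative decoding is configured. Below is a rough standalone sketch of that kind of name-based resolution; it is an assumption about the general pattern, not the wrapper's actual code.

```python
import importlib


def resolve_worker_factory(worker_module_name: str, worker_class_name: str):
    """Import a module by dotted path and return the named attribute.

    Sketch of the indirection used above; the real WorkerWrapperBase is also
    responsible for constructing the worker from the forwarded kwargs.
    """
    module = importlib.import_module(worker_module_name)
    return getattr(module, worker_class_name)


# Calls commented out so the sketch runs without vLLM installed:
#
# Non-speculative path resolves the plain Worker class:
# worker_cls = resolve_worker_factory("vllm.worker.worker", "Worker")
#
# Speculative path resolves create_spec_worker, a factory expected to build
# the speculative-decoding worker from the same kwargs:
# spec_factory = resolve_worker_factory(
#     "vllm.spec_decode.spec_decode_worker", "create_spec_worker")
```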