From b45f0d79469f583736052b80bfc8b3bab29f50d8 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 3 Dec 2024 01:53:36 +0800 Subject: [PATCH 01/13] [Misc][LoRA] Move the implementation of lora bias to punica.py (#10829) Signed-off-by: Jee Jee Li --- tests/lora/test_llama_tp.py | 60 +++++++-------- vllm/lora/fully_sharded_layers.py | 41 +++-------- vllm/lora/layers.py | 113 +++-------------------------- vllm/lora/punica.py | 117 +++++++++++++++++++++++++++--- 4 files changed, 156 insertions(+), 175 deletions(-) diff --git a/tests/lora/test_llama_tp.py b/tests/lora/test_llama_tp.py index aae6310a2a213..d3ca7f878191a 100644 --- a/tests/lora/test_llama_tp.py +++ b/tests/lora/test_llama_tp.py @@ -55,15 +55,7 @@ def do_sample(llm: vllm.LLM, lora_path: str, lora_id: int) -> List[str]: return generated_texts -@fork_new_process_for_each_test -def test_llama_lora(sql_lora_files): - - llm = vllm.LLM(MODEL_PATH, - enable_lora=True, - max_num_seqs=16, - max_loras=4, - tensor_parallel_size=1) - +def generate_and_test(llm, sql_lora_files): print("lora adapter created") assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT @@ -79,6 +71,17 @@ def test_llama_lora(sql_lora_files): print("removing lora") +@fork_new_process_for_each_test +def test_llama_lora(sql_lora_files): + + llm = vllm.LLM(MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=1) + generate_and_test(llm, sql_lora_files) + + @fork_new_process_for_each_test def test_llama_lora_warmup(sql_lora_files): """Test that the LLM initialization works with a warmup LORA path and @@ -118,20 +121,7 @@ def test_llama_lora_tp4(sql_lora_files): max_loras=4, tensor_parallel_size=4, ) - - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT - - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT - - print("removing lora") + generate_and_test(llm, sql_lora_files) @multi_gpu_test(num_gpus=4) @@ -146,16 +136,20 @@ def test_llama_lora_tp4_fully_sharded_loras(sql_lora_files): tensor_parallel_size=4, fully_sharded_loras=True, ) - print("lora adapter created") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - - print("lora 1") - assert do_sample(llm, sql_lora_files, lora_id=1) == EXPECTED_LORA_OUTPUT + generate_and_test(llm, sql_lora_files) - print("no lora") - assert do_sample(llm, sql_lora_files, lora_id=0) == EXPECTED_NO_LORA_OUTPUT - print("lora 2") - assert do_sample(llm, sql_lora_files, lora_id=2) == EXPECTED_LORA_OUTPUT +@multi_gpu_test(num_gpus=4) +@fork_new_process_for_each_test +def test_llama_lora_tp4_fully_sharded_enable_bias(sql_lora_files): - print("removing lora") + llm = vllm.LLM( + MODEL_PATH, + enable_lora=True, + max_num_seqs=16, + max_loras=4, + tensor_parallel_size=4, + fully_sharded_loras=True, + enable_lora_bias=True, + ) + generate_and_test(llm, sql_lora_files) diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index f5c2eced9d2bb..5f2d32defe030 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -73,6 +73,7 @@ def apply(self, x: torch.Tensor, self.punica_wrapper.add_expand(output, buffer, self.lora_b_stacked, + self.bias_stacked, add_input=True) # now have column partitioned output @@ -131,27 +132,14 @@ def _mcp_apply(x, bias, layer: QKVParallelLinearWithLora): layer.lora_a_stacked[idx], 1.0) buffers = tensor_model_parallel_all_gather(buffers) - left_offset = 0 - for idx in range(n): - shard_size = layer.lora_b_stacked[idx].shape[2] - - if layer.bias_stacked is not None: - bias = layer.bias_stacked[idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[layer.punica_wrapper.token_lora_indices] - bias[layer.punica_wrapper.token_lora_indices == -1] = 0 - output[:, left_offset:left_offset + shard_size] += bias - - layer.punica_wrapper.add_expand_slice( - output, - buffers[idx], - layer.lora_b_stacked[idx], - left_offset, - shard_size, - add_input=True, - ) - left_offset += shard_size + layer.punica_wrapper.add_expand_packed_nslice( + output, + buffers, + layer.lora_b_stacked, + layer.bias_stacked, + 1.0, + layer.output_slices, + ) output = output.view(*out_orig_shape) # now have column partitioned and packed output @@ -234,6 +222,7 @@ def apply(self, x: torch.Tensor, self.punica_wrapper.add_expand(output, buffer, self.lora_b_stacked, + self.bias_all, add_input=True) # now have column partitioned output output = output.view(*out_orig_shape) @@ -350,15 +339,9 @@ def apply(self, x: torch.Tensor) -> torch.Tensor: # reduced before being used shard_size = self.lora_b_stacked.shape[2] start_idx = self.tp_rank * shard_size - - if self.bias_stacked is not None: - bias = self.bias_stacked.view(-1, self.bias_stacked.shape[-1]) - bias = bias[self.punica_wrapper.token_lora_indices] - bias[self.punica_wrapper.token_lora_indices == -1] = 0 - output += bias - self.punica_wrapper.add_expand_slice(output, buffer, - self.lora_b_stacked, start_idx, + self.lora_b_stacked, + self.bias_stacked, start_idx, shard_size) output = output.view(*out_orig_shape) return output diff --git a/vllm/lora/layers.py b/vllm/lora/layers.py index 3701988ff692f..73748b5ce511e 100644 --- a/vllm/lora/layers.py +++ b/vllm/lora/layers.py @@ -67,63 +67,6 @@ def dec(*args, **kwargs): return dec -def apply_bias( - indices: torch.Tensor, - output: torch.Tensor, - bias_stacked: torch.Tensor, -): - """Applies bias to output - - Input shapes: - bias_stacked: (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, output_dim) - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) - bias_stacked = bias_stacked[indices] - bias_stacked[indices == -1] = 0 - output += bias_stacked - - return output.view_as(org_output) - - -def apply_bias_packed_nslice( - indices: torch.Tensor, - output: torch.Tensor, - output_slices: Tuple[int, ...], - bias_stacked: Tuple[torch.Tensor, torch.Tensor, torch.Tensor], -): - """Applies bias to output - - Input shapes: - bias_stacked: 3 element tuple of (num_loras, output_dim) - indices: (batch_size) - output: (batch_size, q_slice_size + 2*kv_slice_size) - output_slices: n-1 element tuple of (slice_size...), - where n is number of slices - """ - org_output = output - output = output.view(-1, output.shape[-1]) - indices = indices.view(-1) - - offset_left = 0 - for slice_idx, slice in enumerate(output_slices): - bias = bias_stacked[slice_idx] - if bias is not None: - bias = bias.view(-1, bias.shape[-1]) - bias = bias[indices] - bias[indices == -1] = 0 - output[:, offset_left:offset_left + slice] += bias - - offset_left += slice - - return output.view_as(org_output) - - @dataclass class LoRAMapping(AdapterMapping): is_prefill: bool = False @@ -311,6 +254,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: self.punica_wrapper.add_expand(full_output, full_lora_a_embeddings, self.lora_b_stacked, + bias_all=None, add_input=True) return full_output.view_as(full_output_org) @@ -399,15 +343,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): @@ -576,15 +514,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): @@ -687,8 +619,8 @@ def create_lora_weights( ) for _ in range(n_slices)) else: self.bias_stacked = None - self.output_dim = self.lora_b_stacked[0].shape[2] + self.output_slices = (self.output_dim, self.output_dim) def reset_lora(self, index: int): self.lora_a_stacked[0][index] = 0 @@ -772,17 +704,9 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - (self.output_dim, self.output_dim), - self.bias_stacked, - ) self.punica_wrapper.add_lora_packed_nslice( - output, x, self.lora_a_stacked, self.lora_b_stacked, 1.0, - (self.output_dim, self.output_dim)) + output, x, self.lora_a_stacked, self.lora_b_stacked, + self.bias_stacked, 1.0, (self.output_dim, self.output_dim)) return output @classmethod @@ -1129,17 +1053,10 @@ def set_lora( def apply(self, x: torch.Tensor, bias: Optional[torch.Tensor]) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x, bias) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias_packed_nslice( - self.indices, - output, - self.output_slices, - self.bias_stacked, - ) self.punica_wrapper.add_lora_packed_nslice(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0, + self.lora_b_stacked, + self.bias_stacked, 1.0, self.output_slices) return output @@ -1264,15 +1181,9 @@ def set_lora( def apply(self, x: torch.Tensor) -> torch.Tensor: output = self.base_layer.quant_method.apply(self.base_layer, x) - if self.bias_stacked is not None: - self.indices = self.punica_wrapper.token_lora_indices - output = apply_bias( - self.indices, - output, - self.bias_stacked, - ) self.punica_wrapper.add_lora(output, x, self.lora_a_stacked, - self.lora_b_stacked, 1.0) + self.lora_b_stacked, self.bias_stacked, + 1.0) return output def forward(self, input_): diff --git a/vllm/lora/punica.py b/vllm/lora/punica.py index 082041f390750..3f775b7ba363e 100644 --- a/vllm/lora/punica.py +++ b/vllm/lora/punica.py @@ -450,6 +450,62 @@ def expand_slice_decode( bgmv_expand_slice(x, w_t_all, y, self.token_lora_indices, y_offset, y_slice_size, add_input) + def apply_bias( + self, + indices: torch.Tensor, + output: torch.Tensor, + bias_stacked: torch.Tensor, + ): + """Applies bias to output + + Input shapes: + bias_stacked: (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, output_dim) + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + bias_stacked = bias_stacked.view(-1, bias_stacked.shape[-1]) + bias_stacked = bias_stacked[indices] + bias_stacked[indices == -1] = 0 + output += bias_stacked + + return output.view_as(org_output) + + def apply_bias_packed_nslice( + self, + indices: torch.Tensor, + output: torch.Tensor, + output_slices: Tuple[int, ...], + bias_stacked: Tuple[Optional[torch.Tensor], ...], + ): + """Applies bias to output + + Input shapes: + bias_stacked: 3 element tuple of (num_loras, output_dim) + indices: (batch_size) + output: (batch_size, q_slice_size + 2*kv_slice_size) + output_slices: n-1 element tuple of (slice_size...), + where n is number of slices + """ + org_output = output + output = output.view(-1, output.shape[-1]) + indices = indices.view(-1) + + offset_left = 0 + for slice_idx, slice in enumerate(output_slices): + bias = bias_stacked[slice_idx] + if bias is not None: + bias = bias.view(-1, bias.shape[-1]) + bias = bias[indices] + bias[indices == -1] = 0 + output[:, offset_left:offset_left + slice] += bias + offset_left += slice + + return output.view_as(org_output) + def add_shrink( self, y: torch.Tensor, @@ -474,16 +530,19 @@ def add_expand( y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], add_input: bool = True, ): """ - Perform the ` y+=x@w_t_all` computation, which is suitable for the + Perform the ` y+=x@w_t_all+bias` computation, which is suitable for the GEMM of lora'b. When `is_prefill` is true, it indicates that it is currently the prefill stage, and the `expand_prefill` function should be called. Otherwise, it is the decode stage, and the expand_decode function should be called. """ + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) expand_fun: Callable = (self.expand_prefill if self.is_prefill else self.expand_decode) @@ -493,23 +552,54 @@ def add_expand_slice(self, y: torch.Tensor, x: torch.Tensor, w_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], y_offset: Optional[int], y_slice_size: Optional[int], add_input: bool = True): """ Similar to `add_expand` """ + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) expand_slice_fun: Callable = (self.expand_slice_prefill if self.is_prefill else self.expand_slice_decode) expand_slice_fun(y, x, w_t_all, y_offset, y_slice_size, add_input) + def add_expand_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, + lora_b_stacked: Tuple[torch.Tensor, ...], + bias_stacked: Optional[Tuple[torch.Tensor, + ...]], + scale: float, + output_slices: Tuple[int, ...]) -> None: + """ + Similar to `add_expand` + """ + y_org = y + y = y.view(-1, y.shape[-1]) + offset_left = 0 + if bias_stacked is not None: + self.apply_bias_packed_nslice(self.token_lora_indices, y, + output_slices, bias_stacked) + for slice_idx in range(len(lora_b_stacked)): + self.add_expand_slice(y, + x[slice_idx], + lora_b_stacked[slice_idx], + None, + offset_left, + output_slices[slice_idx], + add_input=True) + offset_left += output_slices[slice_idx] + + y = y.view_as(y_org) + def add_lora(self, y: torch.Tensor, x: torch.Tensor, wa_t_all: torch.Tensor, wb_t_all: torch.Tensor, + bias_all: Optional[torch.Tensor], scale: float, y_offset: Optional[int] = None, y_slice_size: Optional[int] = None, @@ -522,12 +612,13 @@ def add_lora(self, @ wa_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) @ wb_t_all[indices[i], layer_idx, :, :].transpose(-1, -2) * scale - ).squeeze(0) + ).squeeze(0)+bias[i] Args: y (torch.Tensor): Output tensor. Will be changed in-place. x (torch.Tensor): Input tensor wa_t_all (torch.Tensor): lora_a's weight wb_t_all (torch.Tensor): lora_b's weight + bias_all: (torch.Tensor): lora's bias scale (float): Scaling factor. y_offset (Optional[int], optional): Offset to apply to the starting column of y. @@ -544,27 +635,26 @@ def add_lora(self, buffer = torch.zeros((x.size(0), r), dtype=torch.float32, device=x.device) - + if bias_all is not None: + y = self.apply_bias(self.token_lora_indices, y, bias_all) self.add_shrink(buffer, x, wa_t_all, scale) if y_offset is None and y_slice_size is None: - self.add_expand(y, buffer, wb_t_all, add_input=True) + self.add_expand(y, buffer, wb_t_all, bias_all=None, add_input=True) else: self.add_expand_slice(y, buffer, wb_t_all, + None, y_offset, y_slice_size, add_input=True) y = y.view_as(y_org) def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, - lora_a_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - lora_b_stacked: Tuple[torch.Tensor, - torch.Tensor, - torch.Tensor], - scale: float, + lora_a_stacked: Tuple[torch.Tensor, ...], + lora_b_stacked: Tuple[torch.Tensor, ...], + bias_all: Tuple[Optional[torch.Tensor], + ...], scale: float, output_slices: Tuple[int, ...]) -> None: """ Applies lora to each input. Similar to add_lora, This method is @@ -575,10 +665,13 @@ def add_lora_packed_nslice(self, y: torch.Tensor, x: torch.Tensor, x = x.view(-1, x.shape[-1]) y = y.view(-1, y.shape[-1]) offset_left = 0 + if bias_all is not None: + y = self.apply_bias_packed_nslice(self.token_lora_indices, y, + output_slices, bias_all) # TODO fuse these kernels for slice_idx in range(len(output_slices)): self.add_lora(y, x, lora_a_stacked[slice_idx], - lora_b_stacked[slice_idx], scale, offset_left, + lora_b_stacked[slice_idx], None, scale, offset_left, output_slices[slice_idx]) offset_left += output_slices[slice_idx] From 519cc6ca12dc89eec35bc2579494e399da33c31a Mon Sep 17 00:00:00 2001 From: Yan Ma Date: Tue, 3 Dec 2024 01:53:55 +0800 Subject: [PATCH 02/13] [Misc][XPU] Avoid torch compile for XPU platform (#10747) Signed-off-by: yan ma Co-authored-by: youkaichao --- .buildkite/run-xpu-test.sh | 6 ++++-- vllm/plugins/__init__.py | 4 ++++ 2 files changed, 8 insertions(+), 2 deletions(-) diff --git a/.buildkite/run-xpu-test.sh b/.buildkite/run-xpu-test.sh index faeac8e2ded36..50f58f7d70430 100644 --- a/.buildkite/run-xpu-test.sh +++ b/.buildkite/run-xpu-test.sh @@ -12,5 +12,7 @@ remove_docker_container() { docker rm -f xpu-test || true; } trap remove_docker_container EXIT remove_docker_container -# Run the image and launch offline inference -docker run --network host --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path --entrypoint="" xpu-test python3 examples/offline_inference.py +# Run the image and test offline inference/tensor parallel +docker run -it -d --name xpu-test --device /dev/dri -v /dev/dri/by-path:/dev/dri/by-path xpu-test /bin/bash +docker exec xpu-test bash -c "python3 examples/offline_inference.py" +docker exec xpu-test bash -c "python3 examples/offline_inference_cli.py -tp 2" diff --git a/vllm/plugins/__init__.py b/vllm/plugins/__init__.py index 3c64726ca3344..81ee9975cdc4a 100644 --- a/vllm/plugins/__init__.py +++ b/vllm/plugins/__init__.py @@ -4,6 +4,7 @@ import torch import vllm.envs as envs +from vllm.platforms import current_platform logger = logging.getLogger(__name__) @@ -25,6 +26,9 @@ def load_general_plugins(): os.environ['TORCHINDUCTOR_COMPILE_THREADS'] = '1' # see https://github.com/vllm-project/vllm/issues/10619 torch._inductor.config.compile_threads = 1 + if current_platform.is_xpu(): + # see https://github.com/pytorch/pytorch/blob/8cada5cbe5450e17c26fb8b358116785324537b2/torch/_dynamo/config.py#L158 # noqa + os.environ['TORCH_COMPILE_DISABLE'] = 'True' global plugins_loaded if plugins_loaded: return From 9b14d978aa8c286b738f107fab4626273f4fc088 Mon Sep 17 00:00:00 2001 From: Jani Monoses Date: Mon, 2 Dec 2024 20:52:19 +0200 Subject: [PATCH 03/13] Fix openvino on GPU (#10793) --- vllm/worker/openvino_worker.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/vllm/worker/openvino_worker.py b/vllm/worker/openvino_worker.py index 205f8a337ce6c..0bf522d5333ed 100644 --- a/vllm/worker/openvino_worker.py +++ b/vllm/worker/openvino_worker.py @@ -489,7 +489,7 @@ def model_profile_run(): block_size = cache_config.block_size seq_num_blocks = (seq_len + block_size - 1) // block_size - seq_data, dummy_multi_modal_data = input_registry \ + dummy_data = input_registry \ .dummy_data_for_profiling(model_config, seq_len, mm_registry) @@ -498,11 +498,11 @@ def model_profile_run(): seq = SequenceGroupMetadata( request_id=str(group_id), is_prompt=True, - seq_data={group_id: seq_data}, + seq_data={group_id: dummy_data.seq_data}, sampling_params=sampling_params, block_tables=block_tables, lora_request=None, - multi_modal_data=dummy_multi_modal_data) + multi_modal_data=dummy_data.multi_modal_data) seqs.append(seq) self.model_runner.block_size = tmp_cache_config.block_size From 4c05edb33ae4ae279421ddf981816d070e8ec37a Mon Sep 17 00:00:00 2001 From: Isotr0py Date: Tue, 3 Dec 2024 07:06:09 +0800 Subject: [PATCH 04/13] [Model] Add TP and BNB quantization support to LlavaMultiModalProjector (#10834) Signed-off-by: Isotr0py <2037008807@qq.com> Co-authored-by: Cyrus Leung --- vllm/model_executor/model_loader/loader.py | 14 +++++++-- vllm/model_executor/models/llava.py | 35 ++++++++++++++-------- 2 files changed, 34 insertions(+), 15 deletions(-) diff --git a/vllm/model_executor/model_loader/loader.py b/vllm/model_executor/model_loader/loader.py index 0e12bc5691538..b4921cc80797f 100644 --- a/vllm/model_executor/model_loader/loader.py +++ b/vllm/model_executor/model_loader/loader.py @@ -1120,7 +1120,14 @@ def _load_weights(self, model_config: ModelConfig, model_config.revision, pre_quant, load_8bit)) - model.load_weights(qweight_iterator) + weights_to_load = {name for name, _ in model.named_parameters()} + loaded_weights = model.load_weights(qweight_iterator) + # Some models may have weights loading tracker unimplemented. + if loaded_weights is not None: + weights_not_loaded = weights_to_load - loaded_weights + if weights_not_loaded: + raise ValueError("Following weights were not initialized from " + f"checkpoint: {weights_not_loaded}") torch.cuda.empty_cache() @@ -1152,9 +1159,10 @@ def _load_weights(self, model_config: ModelConfig, shard_name, weight_name) break + # Models like Clip/Siglip may skip some layers in initialization, + # causing unused quant_param_name in state_dict. if quant_param_name not in param_dict: - raise ValueError( - f"Parameter {quant_param_name} not found in the model.") + continue if quant_param_name not in stacked_quant_state_dict: stacked_quant_state_dict[quant_param_name] = {} diff --git a/vllm/model_executor/models/llava.py b/vllm/model_executor/models/llava.py index db7fa82ceb9b7..d375c1c9da2a9 100644 --- a/vllm/model_executor/models/llava.py +++ b/vllm/model_executor/models/llava.py @@ -13,6 +13,8 @@ from vllm.inputs import (INPUT_REGISTRY, DecoderOnlyInputs, DummyData, InputContext) from vllm.model_executor.layers.activation import get_act_fn +from vllm.model_executor.layers.linear import (ColumnParallelLinear, + RowParallelLinear) from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.sampler import SamplerOutput, get_sampler from vllm.model_executor.sampling_metadata import SamplingMetadata @@ -59,25 +61,32 @@ class LlavaImageEmbeddingInputs(TypedDict): LlavaImageInputs = Union[LlavaImagePixelInputs, LlavaImageEmbeddingInputs] -# TODO(xwjiang): Run benchmark and decide if TP. class LlavaMultiModalProjector(nn.Module): - def __init__(self, vision_hidden_size: int, text_hidden_size: int, - projector_hidden_act: str): + def __init__(self, + vision_hidden_size: int, + text_hidden_size: int, + projector_hidden_act: str, + quant_config: Optional[QuantizationConfig] = None, + prefix: str = ""): super().__init__() - self.linear_1 = nn.Linear(vision_hidden_size, - text_hidden_size, - bias=True) + self.linear_1 = ColumnParallelLinear(vision_hidden_size, + text_hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_1") self.act = get_act_fn(projector_hidden_act) - self.linear_2 = nn.Linear(text_hidden_size, - text_hidden_size, - bias=True) + self.linear_2 = RowParallelLinear(text_hidden_size, + text_hidden_size, + bias=True, + quant_config=quant_config, + prefix=f"{prefix}.linear_2") def forward(self, image_features: torch.Tensor) -> torch.Tensor: - hidden_states = self.linear_1(image_features) + hidden_states, _ = self.linear_1(image_features) hidden_states = self.act(hidden_states) - hidden_states = self.linear_2(hidden_states) + hidden_states, _ = self.linear_2(hidden_states) return hidden_states @@ -325,7 +334,9 @@ def __init__(self, *, vllm_config: VllmConfig, prefix: str = "") -> None: self.multi_modal_projector = LlavaMultiModalProjector( vision_hidden_size=config.vision_config.hidden_size, text_hidden_size=config.text_config.hidden_size, - projector_hidden_act=config.projector_hidden_act) + projector_hidden_act=config.projector_hidden_act, + quant_config=quant_config, + prefix=maybe_prefix(prefix, "multi_modal_projector")) self.language_model = init_vllm_registered_model( vllm_config=vllm_config, From 4433195ab75e2bb367303ba5f34c97521c5677ce Mon Sep 17 00:00:00 2001 From: Michael Goin Date: Mon, 2 Dec 2024 21:26:15 -0500 Subject: [PATCH 05/13] [Bugfix] Prevent benchmark_throughput.py from using duplicated random prompts (#10753) --- benchmarks/benchmark_throughput.py | 47 +++++++++++++++++++----------- 1 file changed, 30 insertions(+), 17 deletions(-) diff --git a/benchmarks/benchmark_throughput.py b/benchmarks/benchmark_throughput.py index 159cf055737ce..1e5967bd9bf8b 100644 --- a/benchmarks/benchmark_throughput.py +++ b/benchmarks/benchmark_throughput.py @@ -294,23 +294,36 @@ def main(args: argparse.Namespace): tokenizer = AutoTokenizer.from_pretrained( args.tokenizer, trust_remote_code=args.trust_remote_code) if args.dataset is None: - # Synthesize a prompt with the given input length. - # As tokenizer may add additional tokens like BOS, we need to try - # different lengths to get the desired input length. - for i in range(-10, 10): - prompt = "hi " * (args.input_len + i) - tokenized_prompt = tokenizer(prompt).input_ids - if len(tokenized_prompt) == args.input_len: - break - else: - raise ValueError( - f"Failed to synthesize a prompt with {args.input_len} tokens.") - requests = [ - SampleRequest(prompt=prompt, - prompt_len=args.input_len, - expected_output_len=args.output_len) - for _ in range(args.num_prompts) - ] + vocab_size = tokenizer.vocab_size + requests = [] + for _ in range(args.num_prompts): + # Synthesize a prompt with the given input length. + candidate_ids = [ + random.randint(0, vocab_size - 1) + for _ in range(args.input_len) + ] + # As tokenizer may add additional tokens like BOS, we need to try + # different lengths to get the desired input length. + for _ in range(5): # Max attempts to correct + candidate_prompt = tokenizer.decode(candidate_ids) + tokenized_len = len(tokenizer.encode(candidate_prompt)) + + if tokenized_len == args.input_len: + break + + # Adjust length based on difference + diff = args.input_len - tokenized_len + if diff > 0: + candidate_ids.extend([ + random.randint(100, vocab_size - 100) + for _ in range(diff) + ]) + else: + candidate_ids = candidate_ids[:diff] + requests.append( + SampleRequest(prompt=candidate_prompt, + prompt_len=args.input_len, + expected_output_len=args.output_len)) else: requests = sample_requests(tokenizer, args) From d746268e92dc97d3a816c70637e20073eeac5103 Mon Sep 17 00:00:00 2001 From: zixuanzhang226 Date: Mon, 2 Dec 2024 19:06:41 -0800 Subject: [PATCH 06/13] [Model] support bitsandbytes quantization with minicpm model (#10842) Signed-off-by: Ubuntu --- vllm/model_executor/models/minicpm.py | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/vllm/model_executor/models/minicpm.py b/vllm/model_executor/models/minicpm.py index 6254d26c7060d..5a0f202364f26 100644 --- a/vllm/model_executor/models/minicpm.py +++ b/vllm/model_executor/models/minicpm.py @@ -534,6 +534,16 @@ class MiniCPMForCausalLM(nn.Module, SupportsLoRA, SupportsPP): } embedding_padding_modules = ["lm_head"] + # BitandBytes specific attributes + bitsandbytes_stacked_params_mapping = { + # shard_name, weight_name, index + "q_proj": ("qkv_proj", 0), + "k_proj": ("qkv_proj", 1), + "v_proj": ("qkv_proj", 2), + "gate_proj": ("gate_up_proj", 0), + "up_proj": ("gate_up_proj", 1), + } + def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config From a4cf2561599448d4a5c3de4d79c73ca37cb8d647 Mon Sep 17 00:00:00 2001 From: Jee Jee Li Date: Tue, 3 Dec 2024 12:10:29 +0800 Subject: [PATCH 07/13] [Bugfix] Fix QKVParallelLinearWithShardedLora bias bug (#10844) Signed-off-by: Jee Jee Li --- .buildkite/test-pipeline.yaml | 1 - vllm/lora/fully_sharded_layers.py | 9 +-------- 2 files changed, 1 insertion(+), 9 deletions(-) diff --git a/.buildkite/test-pipeline.yaml b/.buildkite/test-pipeline.yaml index f5591f1098534..455f02a2062f1 100644 --- a/.buildkite/test-pipeline.yaml +++ b/.buildkite/test-pipeline.yaml @@ -481,7 +481,6 @@ steps: - label: LoRA TP Test (Distributed) num_gpus: 4 - soft_fail: true source_file_dependencies: - vllm/lora - tests/lora diff --git a/vllm/lora/fully_sharded_layers.py b/vllm/lora/fully_sharded_layers.py index 5f2d32defe030..e25e453201f01 100644 --- a/vllm/lora/fully_sharded_layers.py +++ b/vllm/lora/fully_sharded_layers.py @@ -77,13 +77,6 @@ def apply(self, x: torch.Tensor, add_input=True) # now have column partitioned output - if self.bias_stacked is not None: - self.bias_stacked = self.bias_stacked.view( - -1, self.bias_stacked.shape[-1]) - self.bias_stacked = self.bias_stacked[ - self.punica_wrapper.token_lora_indices] - output += self.bias_stacked - output = output.view(*out_orig_shape) return output @@ -222,7 +215,7 @@ def apply(self, x: torch.Tensor, self.punica_wrapper.add_expand(output, buffer, self.lora_b_stacked, - self.bias_all, + self.bias_stacked, add_input=True) # now have column partitioned output output = output.view(*out_orig_shape) From 21fe7b481a3a84dc9ebe2497ec89a17002ad52c5 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 2 Dec 2024 20:53:23 -0800 Subject: [PATCH 08/13] [core][distributed] add pynccl broadcast (#10843) Signed-off-by: youkaichao --- tests/distributed/test_pynccl.py | 45 ++++++++++++++++++- .../device_communicators/pynccl.py | 19 ++++++++ .../device_communicators/pynccl_wrapper.py | 16 +++++++ 3 files changed, 78 insertions(+), 2 deletions(-) diff --git a/tests/distributed/test_pynccl.py b/tests/distributed/test_pynccl.py index fb24d6bc2c100..4e27babf12cc3 100644 --- a/tests/distributed/test_pynccl.py +++ b/tests/distributed/test_pynccl.py @@ -61,6 +61,7 @@ def worker_fn(): dtype=torch.float32).cuda(pynccl_comm.rank) with pynccl_comm.change_state(enable=True): tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == pynccl_comm.world_size @@ -86,10 +87,12 @@ def multiple_allreduce_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = pynccl_comm.all_reduce(tensor) tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 4 else: tensor = pynccl_comm.all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 2 @@ -112,10 +115,12 @@ def multiple_allreduce_with_vllm_worker_fn(): if torch.distributed.get_rank() in [0, 1]: tensor = tensor_model_parallel_all_reduce(tensor) tensor = tensor_model_parallel_all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 4 else: tensor = tensor_model_parallel_all_reduce(tensor) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 2 @@ -141,9 +146,9 @@ def worker_fn_with_cudagraph(): graph, stream=pynccl_comm.stream), pynccl_comm.change_state( enable=True): a_out = pynccl_comm.all_reduce(a) - pynccl_comm.stream.synchronize() + torch.cuda.synchronize() graph.replay() - pynccl_comm.stream.synchronize() + torch.cuda.synchronize() assert a_out.mean().cpu().item() == pynccl_comm.world_size**1 @@ -170,6 +175,7 @@ def all_gather_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.all_gather(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -207,6 +213,7 @@ def reduce_scatter_worker_fn(): with pynccl_comm.change_state(enable=True): pynccl_comm.reduce_scatter(result, tensor) + torch.cuda.synchronize() torch.testing.assert_close(result, expected, rtol=1e-5, atol=1e-8) @@ -241,6 +248,7 @@ def send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + torch.cuda.synchronize() result = tensor.mean().cpu().item() assert result == 1 @@ -280,6 +288,7 @@ def multiple_send_recv_worker_fn(): pynccl_comm.recv(tensor, src=(pynccl_comm.rank - 1) % pynccl_comm.world_size) + torch.cuda.synchronize() result = tensor.mean().cpu().item() if torch.distributed.get_rank() in [0, 2]: assert result == 1 @@ -293,6 +302,38 @@ def test_pynccl_multiple_send_recv(): distributed_run(multiple_send_recv_worker_fn, 4) +@pytest.mark.skipif(torch.cuda.device_count() < 4, + reason="Need at least 4 GPUs to run the test.") +def test_pynccl_broadcast(): + distributed_run(broadcast_worker_fn, 4) + + +@worker_fn_wrapper +def broadcast_worker_fn(): + # Test broadcast for every root rank. + # Essentially this is an all-gather operation. + pynccl_comm = PyNcclCommunicator(get_world_group().cpu_group, + device=get_world_group().device) + recv_tensors = [ + torch.empty(16, + 1024, + 1024, + dtype=torch.float32, + device=pynccl_comm.device) + for i in range(pynccl_comm.world_size) + ] + recv_tensors[pynccl_comm.rank] = torch.ones( + 16, 1024, 1024, dtype=torch.float32, + device=pynccl_comm.device) * pynccl_comm.rank + + for i in range(pynccl_comm.world_size): + pynccl_comm.broadcast(recv_tensors[i], src=i) + # the broadcast op might be launched in a different stream + # need to synchronize to make sure the tensor is ready + torch.cuda.synchronize() + assert torch.all(recv_tensors[i] == i).cpu().item() + + def test_ncclGetUniqueId(): lib = NCCLLibrary() unique_id = lib.ncclGetUniqueId() diff --git a/vllm/distributed/device_communicators/pynccl.py b/vllm/distributed/device_communicators/pynccl.py index d4e3f81747038..a6800f93f167b 100644 --- a/vllm/distributed/device_communicators/pynccl.py +++ b/vllm/distributed/device_communicators/pynccl.py @@ -197,6 +197,25 @@ def recv(self, tensor: torch.Tensor, src: int, stream=None): ncclDataTypeEnum.from_torch(tensor.dtype), src, self.comm, cudaStream_t(stream.cuda_stream)) + def broadcast(self, tensor: torch.Tensor, src: int, stream=None): + if self.disabled: + return + assert tensor.device == self.device, ( + f"this nccl communicator is created to work on {self.device}, " + f"but the input tensor is on {tensor.device}") + if stream is None: + stream = self.stream + if src == self.rank: + sendbuff = buffer_type(tensor.data_ptr()) + # NCCL requires the sender also to have a receive buffer + recvbuff = buffer_type(tensor.data_ptr()) + else: + sendbuff = buffer_type() + recvbuff = buffer_type(tensor.data_ptr()) + self.nccl.ncclBroadcast(sendbuff, recvbuff, tensor.numel(), + ncclDataTypeEnum.from_torch(tensor.dtype), src, + self.comm, cudaStream_t(stream.cuda_stream)) + @contextmanager def change_state(self, enable: Optional[bool] = None, diff --git a/vllm/distributed/device_communicators/pynccl_wrapper.py b/vllm/distributed/device_communicators/pynccl_wrapper.py index ff88f72470b27..7dea61b6a09f1 100644 --- a/vllm/distributed/device_communicators/pynccl_wrapper.py +++ b/vllm/distributed/device_communicators/pynccl_wrapper.py @@ -189,6 +189,15 @@ class NCCLLibrary: ncclComm_t, cudaStream_t ]), + # ncclResult_t ncclBroadcast( + # const void* sendbuff, void* recvbuff, size_t count, + # ncclDataType_t datatype, int root, ncclComm_t comm, + # cudaStream_t stream); + Function("ncclBroadcast", ncclResult_t, [ + buffer_type, buffer_type, ctypes.c_size_t, ncclDataType_t, + ctypes.c_int, ncclComm_t, cudaStream_t + ]), + # be cautious! this is a collective call, it will block until all # processes in the communicator have called this function. # because Python object destruction can happen in random order, @@ -312,6 +321,13 @@ def ncclRecv(self, recvbuff: buffer_type, count: int, datatype: int, self.NCCL_CHECK(self._funcs["ncclRecv"](recvbuff, count, datatype, src, comm, stream)) + def ncclBroadcast(self, sendbuff: buffer_type, recvbuff: buffer_type, + count: int, datatype: int, root: int, comm: ncclComm_t, + stream: cudaStream_t) -> None: + self.NCCL_CHECK(self._funcs["ncclBroadcast"](sendbuff, recvbuff, count, + datatype, root, comm, + stream)) + def ncclCommDestroy(self, comm: ncclComm_t) -> None: self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm)) From dc5ce861bf0e10fc002384859b93b1eebbd70933 Mon Sep 17 00:00:00 2001 From: youkaichao Date: Mon, 2 Dec 2024 22:19:02 -0800 Subject: [PATCH 09/13] [torch.compile] remove compilation_context and simplify code (#10838) Signed-off-by: youkaichao --- tests/compile/piecewise/test_simple.py | 9 +- tests/compile/piecewise/test_toy_llama.py | 33 ++++---- .../decoder_only/language/test_jamba.py | 5 +- .../decoder_only/language/test_mamba.py | 5 +- .../test_encoder_decoder_model_runner.py | 4 +- tests/worker/test_model_runner.py | 5 +- vllm/compilation/backends.py | 4 - vllm/compilation/compile_context.py | 23 ----- vllm/config.py | 83 +++++++++++++++++-- vllm/model_executor/models/jamba.py | 6 +- vllm/model_executor/models/mamba.py | 6 +- vllm/v1/worker/gpu_model_runner.py | 14 ++-- vllm/worker/enc_dec_model_runner.py | 6 +- vllm/worker/model_runner.py | 68 ++------------- 14 files changed, 128 insertions(+), 143 deletions(-) delete mode 100644 vllm/compilation/compile_context.py diff --git a/tests/compile/piecewise/test_simple.py b/tests/compile/piecewise/test_simple.py index 7ef502abee345..aa11524812cdd 100644 --- a/tests/compile/piecewise/test_simple.py +++ b/tests/compile/piecewise/test_simple.py @@ -7,7 +7,6 @@ from torch import nn from torch.library import Library -from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, @@ -81,6 +80,7 @@ def test_simple_piecewise_compile(): use_cudagraph=True, splitting_ops=["silly.attention"], cudagraph_copy_inputs=True, + cudagraph_capture_sizes=[1, 2], )) with set_current_vllm_config(vllm_config): model = SillyModel(vllm_config=vllm_config, prefix='') @@ -96,11 +96,10 @@ def test_simple_piecewise_compile(): 6, # num_cudagraph_sizes * num_piecewise_capturable_graphs_seen ): - with set_compile_context([1, 2]): - model(inputs) + model(inputs) - model(torch.randn(2).cuda()) - model(torch.randn(1).cuda()) + model(torch.randn(2).cuda()) + model(torch.randn(1).cuda()) input = torch.zeros(2).cuda() global global_counter diff --git a/tests/compile/piecewise/test_toy_llama.py b/tests/compile/piecewise/test_toy_llama.py index dbd5a3bbffeab..07c10a3a18c55 100644 --- a/tests/compile/piecewise/test_toy_llama.py +++ b/tests/compile/piecewise/test_toy_llama.py @@ -13,7 +13,6 @@ from torch import nn from torch.library import Library -from vllm.compilation.compile_context import set_compile_context from vllm.compilation.counter import compilation_counter from vllm.compilation.decorators import support_torch_compile from vllm.config import (CompilationConfig, CompilationLevel, VllmConfig, @@ -256,6 +255,7 @@ def run_model(llama_config, compilation_config = CompilationConfig( level=CompilationLevel.PIECEWISE, use_cudagraph=True, + cudagraph_capture_sizes=[1, 2], ) if split_attn: compilation_config.splitting_ops = ["silly.attention"] @@ -273,10 +273,9 @@ def run_model(llama_config, input_ids = torch.randint(0, llama_config.vocab_size, (B, )).cuda() positions = torch.arange(B).cuda() - with set_compile_context([1, 2]): - model(input_ids, positions) - model(input_ids[:2], positions[:2]) - model(input_ids[:1], positions[:1]) + model(input_ids, positions) + model(input_ids[:2], positions[:2]) + model(input_ids[:1], positions[:1]) input_ids[:2].zero_() output = model(input_ids[:2], positions[:2]) @@ -379,10 +378,13 @@ def benchmark(): level=CompilationLevel.PIECEWISE, use_cudagraph=True, splitting_ops=["silly.attention"], + cudagraph_capture_sizes=cudagraph_sizes, ) else: compilation_config = CompilationConfig( - level=CompilationLevel.PIECEWISE, ) + level=CompilationLevel.PIECEWISE, + cudagraph_capture_sizes=cudagraph_sizes, + ) vllm_config = VllmConfig(compilation_config=compilation_config) with set_current_vllm_config(vllm_config): @@ -396,17 +398,16 @@ def benchmark(): graphs = {} - with set_compile_context(cudagraph_sizes): - model(input_ids, positions) - for b in cudagraph_sizes[::-1]: - if not piecewise: - graph = torch.cuda.CUDAGraph() - with torch.cuda.graph(graph, pool=pool): - output = model(input_ids[:b], positions[:b]) - graphs[b] = (graph, output) - else: + model(input_ids, positions) + for b in cudagraph_sizes[::-1]: + if not piecewise: + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph, pool=pool): output = model(input_ids[:b], positions[:b]) - graphs[b] = (model, output) + graphs[b] = (graph, output) + else: + output = model(input_ids[:b], positions[:b]) + graphs[b] = (model, output) for b in cudagraph_sizes: if piecewise: # noqa is for `Function definition does not bind loop variable` diff --git a/tests/models/decoder_only/language/test_jamba.py b/tests/models/decoder_only/language/test_jamba.py index 87a05b3011393..cae25ae9fa2c8 100644 --- a/tests/models/decoder_only/language/test_jamba.py +++ b/tests/models/decoder_only/language/test_jamba.py @@ -1,8 +1,8 @@ import pytest from tests.utils import multi_gpu_test +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal @@ -189,7 +189,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/tests/models/decoder_only/language/test_mamba.py b/tests/models/decoder_only/language/test_mamba.py index 01e208347bff4..35018c3c14dee 100644 --- a/tests/models/decoder_only/language/test_mamba.py +++ b/tests/models/decoder_only/language/test_mamba.py @@ -5,8 +5,8 @@ import pytest from transformers import AutoModelForCausalLM, AutoTokenizer +from vllm.config import VllmConfig from vllm.sampling_params import SamplingParams -from vllm.worker.model_runner import _get_graph_batch_size from ...utils import check_outputs_equal @@ -200,7 +200,8 @@ def test_mamba_cache_cg_padding( # This test is for verifying that mamba cache is padded to CG captured # batch size. If it's not, a torch RuntimeError will be raised because # tensor dimensions aren't compatible - while len(example_prompts) == _get_graph_batch_size(len(example_prompts)): + while len(example_prompts) == VllmConfig.get_graph_batch_size( + len(example_prompts)): example_prompts.append(example_prompts[0]) try: diff --git a/tests/worker/test_encoder_decoder_model_runner.py b/tests/worker/test_encoder_decoder_model_runner.py index 9e166ae64dbfb..5289c91f201cd 100644 --- a/tests/worker/test_encoder_decoder_model_runner.py +++ b/tests/worker/test_encoder_decoder_model_runner.py @@ -4,12 +4,12 @@ import pytest import torch +from vllm.config import VllmConfig from vllm.engine.arg_utils import EngineArgs from vllm.platforms import current_platform from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import make_tensor_with_pad from vllm.worker.enc_dec_model_runner import EncoderDecoderModelRunner -from vllm.worker.model_runner import _get_graph_batch_size BATCH_SIZES = [1, 4, 16, 64, 256] @@ -548,7 +548,7 @@ def test_prepare_decode_cuda_graph(batch_size, multiple_seqs_per_seq_group): # With CUDA Graph capture and replay enabled, the decoder and encoder # input sequences will be padded. Create the expected padded tensors # accordingly. - graph_batch_size = _get_graph_batch_size(expanded_batch_size) + graph_batch_size = VllmConfig.get_graph_batch_size(expanded_batch_size) cuda_graph_pad_size = graph_batch_size - expanded_batch_size padded_seq_lens = seq_lens + list(itertools.repeat(1, cuda_graph_pad_size)) padded_encoder_seq_lens = encoder_seq_lens + list( diff --git a/tests/worker/test_model_runner.py b/tests/worker/test_model_runner.py index 433a9b30ba57a..4055524f3e0c7 100644 --- a/tests/worker/test_model_runner.py +++ b/tests/worker/test_model_runner.py @@ -3,13 +3,14 @@ import pytest import torch +from vllm.config import VllmConfig from vllm.distributed.parallel_state import (ensure_model_parallel_initialized, init_distributed_environment) from vllm.engine.arg_utils import EngineArgs from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import SamplingParams, SequenceData, SequenceGroupMetadata from vllm.utils import get_open_port -from vllm.worker.model_runner import ModelRunner, _get_graph_batch_size +from vllm.worker.model_runner import ModelRunner def _create_model_runner(model: str, *args, **kwargs) -> ModelRunner: @@ -176,7 +177,7 @@ def test_prepare_decode_cuda_graph(batch_size): model_input.attn_metadata, model_input.attn_metadata.slot_mapping) assert len(slot_mapping) == len(input_tokens) - expected_bs = _get_graph_batch_size(len(seq_group_metadata_list)) + expected_bs = VllmConfig.get_graph_batch_size(len(seq_group_metadata_list)) # Verify input metadata is correct for prompts. device = model_runner.device assert attn_metadata.num_prefills == 0 diff --git a/vllm/compilation/backends.py b/vllm/compilation/backends.py index 464bc2af8fd6d..d49a83fe3981f 100644 --- a/vllm/compilation/backends.py +++ b/vllm/compilation/backends.py @@ -242,10 +242,6 @@ def __call__(self, graph: fx.GraphModule, example_inputs) -> Callable: assert not self._called, "VllmBackend can only be called once" self.graph = graph - # config is updated now, because only here can - # we get the sizes to capture for cudagraph - # from compilation context - self.compilation_configs.init_during_runtime() self.configure_post_pass() self.split_gm, self.piecewise_graphs = split_graph( diff --git a/vllm/compilation/compile_context.py b/vllm/compilation/compile_context.py deleted file mode 100644 index 29db3d4c637b9..0000000000000 --- a/vllm/compilation/compile_context.py +++ /dev/null @@ -1,23 +0,0 @@ -from contextlib import contextmanager -from typing import Any - -_compile_context: Any = None - - -def get_compile_context() -> Any: - """Get the current compile context.""" - return _compile_context - - -@contextmanager -def set_compile_context(context: Any): - """A context manager that stores the current compile context, - usually it is a list of sizes to specialize. - """ - global _compile_context - prev_context = _compile_context - _compile_context = context - try: - yield - finally: - _compile_context = prev_context diff --git a/vllm/config.py b/vllm/config.py index 5f50d65ec87e1..326340d3fa655 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -2357,15 +2357,10 @@ def init_backend(self) -> Union[str, Callable]: from vllm.compilation.backends import VllmBackend return VllmBackend(self) - def init_during_runtime(self): + def init_with_cudagraph_sizes(self, sizes_to_specialize: List[int]): """To complete the initialization of config, - we need to know the compile context, which is only available - during the first run of the model. - """ - from vllm.compilation.compile_context import get_compile_context - context = get_compile_context() - context = copy.deepcopy(context) if context is not None else [] - sizes_to_specialize: List[int] = context + we need to know the cudagraph sizes.""" + if self.cudagraph_capture_sizes is None: self.capture_sizes = sizes_to_specialize else: @@ -2386,6 +2381,21 @@ def init_during_runtime(self): self.inductor_compile_sizes = [] self.compile_sizes = self.inductor_compile_sizes + # sort to make sure cudagraph capture sizes are in descending order + self.capture_sizes.sort(reverse=True) + + +_BATCH_SIZE_ALIGNMENT = 8 +# all the token sizes that **can** be captured by cudagraph. +# they can be arbitrarily large. +# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. +# the actual sizes to capture will be determined by the model, +# depending on the model's max_num_seqs. +# NOTE: get_graph_batch_size needs to be updated if this list is changed. +_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ + _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) +] + @dataclass class VllmConfig: @@ -2413,6 +2423,41 @@ class VllmConfig: kv_transfer_config: KVTransferConfig = field(default=None, init=True) # type: ignore + @staticmethod + def get_graph_batch_size(batch_size: int) -> int: + """Returns the padded batch size given actual batch size. + + Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, + 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... + """ + if batch_size <= 2: + return batch_size + elif batch_size <= 4: + return 4 + else: + return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // + _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) + + @staticmethod + def get_max_graph_batch_size(max_num_seqs: int) -> int: + """ + max_num_seqs: Maximum number of sequences in a batch. + _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. + + pad the max_num_seqs if necessary by calling get_graph_batch_size, + which will deal with some edge cases like 1, 2, 4. + + if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded + size. if not, it means the padded size is larger than the largest size + in _BATCH_SIZES_TO_CAPTURE, return the largest size in + _BATCH_SIZES_TO_CAPTURE. + """ + padded_size = VllmConfig.get_graph_batch_size(max_num_seqs) + if padded_size in _BATCH_SIZES_TO_CAPTURE: + return padded_size + assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] + return _BATCH_SIZES_TO_CAPTURE[-1] + @staticmethod def _get_quantization_config( model_config: ModelConfig, @@ -2496,6 +2541,28 @@ def __post_init__(self): self.compilation_config.pass_config.enable_reshape = False self.compilation_config.level = CompilationLevel.PIECEWISE + if not envs.VLLM_USE_V1: + max_batchsize_to_capture = 0 + if self.scheduler_config is not None and \ + self.model_config is not None and \ + not self.model_config.enforce_eager: + max_batchsize_to_capture = \ + self.get_max_graph_batch_size( + self.scheduler_config.max_num_seqs) + batch_size_capture_list = [ + size for size in _BATCH_SIZES_TO_CAPTURE + if size <= max_batchsize_to_capture + ] + else: + batch_size_capture_list = [] + if self.model_config is not None and \ + not self.model_config.enforce_eager: + batch_size_capture_list = [1, 2, 4 + ] + [i for i in range(8, 513, 8)] + + self.compilation_config.init_with_cudagraph_sizes( + batch_size_capture_list) + if self.cache_config is not None and \ self.cache_config.cpu_offload_gb > 0 and \ self.compilation_config.level != CompilationLevel.NO_COMPILATION: diff --git a/vllm/model_executor/models/jamba.py b/vllm/model_executor/models/jamba.py index 099ca7e12b288..5d5e8ae1ee532 100644 --- a/vllm/model_executor/models/jamba.py +++ b/vllm/model_executor/models/jamba.py @@ -7,7 +7,7 @@ from vllm.attention.backends.abstract import AttentionMetadata from vllm.attention.layer import Attention -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.fused_moe import FusedMoE from vllm.model_executor.layers.layernorm import RMSNorm @@ -25,8 +25,6 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) from .interfaces import HasInnerState, SupportsLoRA from .utils import maybe_prefix @@ -404,7 +402,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( + max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) diff --git a/vllm/model_executor/models/mamba.py b/vllm/model_executor/models/mamba.py index ac0d265a961f0..b32032e411b0a 100644 --- a/vllm/model_executor/models/mamba.py +++ b/vllm/model_executor/models/mamba.py @@ -6,7 +6,7 @@ from transformers import MambaConfig from vllm.attention.backends.abstract import AttentionMetadata -from vllm.config import CacheConfig, VllmConfig +from vllm.config import _BATCH_SIZES_TO_CAPTURE, CacheConfig, VllmConfig from vllm.distributed import get_tensor_model_parallel_world_size from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.logits_processor import LogitsProcessor @@ -23,8 +23,6 @@ MambaCacheParams) from vllm.model_executor.sampling_metadata import SamplingMetadata from vllm.sequence import IntermediateTensors -from vllm.worker.model_runner import (_BATCH_SIZES_TO_CAPTURE, - _get_graph_batch_size) from .utils import maybe_prefix @@ -187,7 +185,7 @@ def forward(self, inputs_embeds: Optional[torch.Tensor] = None, **kwargs): if self.mamba_cache is None: - max_batch_size = (_get_graph_batch_size( + max_batch_size = (VllmConfig.get_graph_batch_size( self.scheduler_config.max_num_seqs) if self.scheduler_config else max(_BATCH_SIZES_TO_CAPTURE) + 2) self.mamba_cache = MambaCacheManager( diff --git a/vllm/v1/worker/gpu_model_runner.py b/vllm/v1/worker/gpu_model_runner.py index 1fa47f553dfd6..4692762493f00 100644 --- a/vllm/v1/worker/gpu_model_runner.py +++ b/vllm/v1/worker/gpu_model_runner.py @@ -8,7 +8,6 @@ import torch.distributed import torch.nn as nn -from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.distributed.parallel_state import graph_capture from vllm.forward_context import set_forward_context @@ -100,7 +99,11 @@ def __init__( == CompilationLevel.PIECEWISE and not self.model_config.enforce_eager) # TODO(woosuk): Provide an option to tune the max cudagraph batch size. - self.cudagraph_batch_sizes = [1, 2, 4] + [i for i in range(8, 513, 8)] + # The convention is different. + # self.cudagraph_batch_sizes sorts in ascending order. + # The batch sizes in the config are in descending order. + self.cudagraph_batch_sizes = list( + reversed(self.vllm_config.compilation_config.capture_sizes)) self.positions = torch.zeros(self.max_num_tokens, dtype=torch.int64, device=self.device) @@ -548,10 +551,9 @@ def profile_run(self) -> None: torch.tensor([], dtype=torch.float32, device=self.device) for _ in range(self.num_attn_layers) ] - with set_compile_context(self.cudagraph_batch_sizes): - # Trigger compilation for general shape. - hidden_states = self._dummy_run(self.model, self.max_num_tokens, - dummy_kv_caches) + # Trigger compilation for general shape. + hidden_states = self._dummy_run(self.model, self.max_num_tokens, + dummy_kv_caches) logits = self.model.compute_logits(hidden_states, None) logits = logits[:self.max_num_tokens] # TODO(woosuk): Consider the memory usage of the sampler. diff --git a/vllm/worker/enc_dec_model_runner.py b/vllm/worker/enc_dec_model_runner.py index ae18c79c980c8..5697fbbaa2041 100644 --- a/vllm/worker/enc_dec_model_runner.py +++ b/vllm/worker/enc_dec_model_runner.py @@ -25,8 +25,7 @@ from vllm.utils import STR_NOT_IMPL_ENC_DEC_BACKEND, make_tensor_with_pad from vllm.worker.model_runner import (GPUModelRunnerBase, ModelInputForGPUBuilder, - ModelInputForGPUWithSamplingMetadata, - _get_graph_batch_size) + ModelInputForGPUWithSamplingMetadata) from vllm.worker.model_runner_base import ( _add_attn_metadata_broadcastable_dict, _add_sampling_metadata_broadcastable_dict) @@ -465,7 +464,8 @@ def _prepare_encoder_model_input_tensors( # We will be using CUDA graph replay for this decode. max_len_of_block_table = self.get_max_block_per_batch() batch_size = len(encoder_seq_lens) - graph_batch_size = _get_graph_batch_size(batch_size) + graph_batch_size = self.vllm_config.get_graph_batch_size( + batch_size) assert graph_batch_size >= batch_size cuda_graph_pad_size = graph_batch_size - batch_size # extend the cross_block_tables and encoder_seq_lens to match diff --git a/vllm/worker/model_runner.py b/vllm/worker/model_runner.py index c9f06eef3f907..4388b3c1ee164 100644 --- a/vllm/worker/model_runner.py +++ b/vllm/worker/model_runner.py @@ -18,7 +18,6 @@ from vllm.attention import AttentionMetadata, get_attn_backend from vllm.attention.backends.abstract import AttentionState from vllm.attention.backends.utils import CommonAttentionState -from vllm.compilation.compile_context import set_compile_context from vllm.config import CompilationLevel, VllmConfig from vllm.core.scheduler import SchedulerOutputs from vllm.distributed import get_kv_transfer_group, get_pp_group @@ -63,16 +62,7 @@ logger = init_logger(__name__) LORA_WARMUP_RANK = 8 -_BATCH_SIZE_ALIGNMENT = 8 -# all the token sizes that **can** be captured by cudagraph. -# they can be arbitrarily large. -# currently it includes: 1, 2, 4, 8, 16, 24, 32, 40, ..., 8192. -# the actual sizes to capture will be determined by the model, -# depending on the model's max_num_seqs. -# NOTE: _get_graph_batch_size needs to be updated if this list is changed. -_BATCH_SIZES_TO_CAPTURE = [1, 2, 4] + [ - _BATCH_SIZE_ALIGNMENT * i for i in range(1, 1025) -] + _NUM_WARMUP_ITERS = 2 TModelInputForGPU = TypeVar('TModelInputForGPU', bound="ModelInputForGPU") @@ -763,7 +753,6 @@ def _use_captured_graph(self, max_decode_seq_len: int, max_encoder_seq_len: int = 0) -> bool: return (decode_only and not self.runner.model_config.enforce_eager - and batch_size <= _BATCH_SIZES_TO_CAPTURE[-1] and max_decode_seq_len <= self.runner.max_seq_len_to_capture and max_encoder_seq_len <= self.runner.max_seq_len_to_capture and batch_size <= self.runner.max_batchsize_to_capture) @@ -811,7 +800,7 @@ def _get_cuda_graph_pad_size(self, max_encoder_seq_len): return -1 - graph_batch_size = _get_graph_batch_size(batch_size) + graph_batch_size = VllmConfig.get_graph_batch_size(batch_size) assert graph_batch_size >= batch_size return graph_batch_size - batch_size @@ -1023,7 +1012,7 @@ def __init__( self.sliding_window = model_config.get_sliding_window() self.block_size = cache_config.block_size self.max_seq_len_to_capture = self.model_config.max_seq_len_to_capture - self.max_batchsize_to_capture = _get_max_graph_batch_size( + self.max_batchsize_to_capture = VllmConfig.get_max_graph_batch_size( self.scheduler_config.max_num_seqs) self.graph_runners: List[Dict[int, CUDAGraphRunner]] = [ @@ -1333,14 +1322,7 @@ def profile_run(self) -> None: dtype=self.model_config.dtype, device=self.device) - graph_batch_size = self.max_batchsize_to_capture - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - if self.model_config.enforce_eager: - batch_size_capture_list = [] - with set_compile_context(batch_size_capture_list): - self.execute_model(model_input, kv_caches, intermediate_tensors) + self.execute_model(model_input, kv_caches, intermediate_tensors) torch.cuda.synchronize() return @@ -1459,18 +1441,14 @@ def capture_model(self, kv_caches: List[List[torch.Tensor]]) -> None: dtype=self.model_config.dtype, device=self.device) - graph_batch_size = self.max_batchsize_to_capture - batch_size_capture_list = [ - bs for bs in _BATCH_SIZES_TO_CAPTURE if bs <= graph_batch_size - ] - with self.attn_state.graph_capture( max_batch_size), graph_capture() as graph_capture_context: # NOTE: Capturing the largest batch size first may help reduce the # memory usage of CUDA graph. for virtual_engine in range( self.parallel_config.pipeline_parallel_size): - for batch_size in reversed(batch_size_capture_list): + for batch_size in \ + self.vllm_config.compilation_config.capture_sizes: attn_metadata = ( self.attn_state.graph_capture_get_metadata_for_batch( batch_size, @@ -1993,37 +1971,3 @@ def forward( return self.output_buffers["hidden_states"] return self.output_buffers - - -def _get_graph_batch_size(batch_size: int) -> int: - """Returns the padded batch size given actual batch size. - - Batch sizes are 1, 2, 4, _BATCH_SIZE_ALIGNMENT, - 2*_BATCH_SIZE_ALIGNMENT, 3*_BATCH_SIZE_ALIGNMENT... - """ - if batch_size <= 2: - return batch_size - elif batch_size <= 4: - return 4 - else: - return ((batch_size + _BATCH_SIZE_ALIGNMENT - 1) // - _BATCH_SIZE_ALIGNMENT * _BATCH_SIZE_ALIGNMENT) - - -def _get_max_graph_batch_size(max_num_seqs: int) -> int: - """ - max_num_seqs: Maximum number of sequences in a batch. - _BATCH_SIZES_TO_CAPTURE: all the sizes that we want to capture. - - pad the max_num_seqs if necessary by calling _get_graph_batch_size, - which will deal with some edge cases like 1, 2, 4. - - if the padded size is in _BATCH_SIZES_TO_CAPTURE, return the padded size. - if not, it means the padded size is larger than the largest size in - _BATCH_SIZES_TO_CAPTURE, return the largest size in _BATCH_SIZES_TO_CAPTURE. - """ - padded_size = _get_graph_batch_size(max_num_seqs) - if padded_size in _BATCH_SIZES_TO_CAPTURE: - return padded_size - assert padded_size > _BATCH_SIZES_TO_CAPTURE[-1] - return _BATCH_SIZES_TO_CAPTURE[-1] From ef51831ee8dbd64833b25e042d4e984d169202f9 Mon Sep 17 00:00:00 2001 From: Russell Bryant Date: Tue, 3 Dec 2024 01:46:07 -0500 Subject: [PATCH 10/13] [Doc] Add github links for source code references (#10672) Signed-off-by: Russell Bryant Signed-off-by: DarkLight1337 Co-authored-by: DarkLight1337 --- docs/requirements-docs.txt | 3 +- docs/source/conf.py | 66 ++++++++++++++++++++++++++++++++++++++ 2 files changed, 68 insertions(+), 1 deletion(-) diff --git a/docs/requirements-docs.txt b/docs/requirements-docs.txt index 8ea240f59c38f..5c80645b405ae 100644 --- a/docs/requirements-docs.txt +++ b/docs/requirements-docs.txt @@ -16,4 +16,5 @@ mistral_common >= 1.5.0 aiohttp starlette openai # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args -partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args \ No newline at end of file +partial-json-parser # Required by docs/source/serving/openai_compatible_server.md's vllm.entrypoints.openai.cli_args +requests diff --git a/docs/source/conf.py b/docs/source/conf.py index 96ad9a4c26b09..4a1a5fb455ff3 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -10,11 +10,13 @@ # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. +import inspect import logging import os import sys from typing import List +import requests from sphinx.ext import autodoc logger = logging.getLogger(__name__) @@ -34,6 +36,7 @@ extensions = [ "sphinx.ext.napoleon", "sphinx.ext.viewcode", + "sphinx.ext.linkcode", "sphinx.ext.intersphinx", "sphinx_copybutton", "sphinx.ext.autodoc", @@ -94,6 +97,69 @@ def setup(app): generate_examples() +_cached_base: str = "" +_cached_branch: str = "" + + +def get_repo_base_and_branch(pr_number): + global _cached_base, _cached_branch + if _cached_base and _cached_branch: + return _cached_base, _cached_branch + + url = f"https://api.github.com/repos/vllm-project/vllm/pulls/{pr_number}" + response = requests.get(url) + if response.status_code == 200: + data = response.json() + _cached_base = data['head']['repo']['full_name'] + _cached_branch = data['head']['ref'] + return _cached_base, _cached_branch + else: + logger.error("Failed to fetch PR details: %s", response) + return None, None + + +def linkcode_resolve(domain, info): + if domain != 'py': + return None + if not info['module']: + return None + filename = info['module'].replace('.', '/') + module = info['module'] + + # try to determine the correct file and line number to link to + obj = sys.modules[module] + + # get as specific as we can + lineno: int = 0 + filename: str = "" + try: + for part in info['fullname'].split('.'): + obj = getattr(obj, part) + + if not (inspect.isclass(obj) or inspect.isfunction(obj) + or inspect.ismethod(obj)): + obj = obj.__class__ # Get the class of the instance + + lineno = inspect.getsourcelines(obj)[1] + filename = (inspect.getsourcefile(obj) + or f"{filename}.py").split("vllm/", 1)[1] + except Exception: + # For some things, like a class member, won't work, so + # we'll use the line number of the parent (the class) + pass + + if filename.startswith("checkouts/"): + # a PR build on readthedocs + pr_number = filename.split("/")[1] + filename = filename.split("/", 2)[2] + base, branch = get_repo_base_and_branch(pr_number) + if base and branch: + return f"https://github.com/{base}/blob/{branch}/{filename}#L{lineno}" + + # Otherwise, link to the source file on the main branch + return f"https://github.com/vllm-project/vllm/blob/main/{filename}#L{lineno}" + + # Mock out external dependencies here, otherwise the autodoc pages may be blank. autodoc_mock_imports = [ "compressed_tensors", From 3257d449fa0fd3e05aa20cc8c5fff79ad101984f Mon Sep 17 00:00:00 2001 From: Cyrus Leung Date: Tue, 3 Dec 2024 14:52:57 +0800 Subject: [PATCH 11/13] [Misc] Remove deprecated names (#10817) Signed-off-by: DarkLight1337 --- vllm/engine/async_llm_engine.py | 8 +++++-- vllm/engine/llm_engine.py | 5 ++-- vllm/engine/multiprocessing/__init__.py | 5 +++- vllm/engine/multiprocessing/client.py | 7 ++++-- vllm/entrypoints/llm.py | 11 +++++++++ vllm/inputs/__init__.py | 31 ------------------------- vllm/inputs/data.py | 31 ------------------------- vllm/model_executor/models/aria.py | 5 ++-- vllm/multimodal/__init__.py | 15 ------------ vllm/multimodal/base.py | 15 ------------ 10 files changed, 31 insertions(+), 102 deletions(-) diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 7b1bb7b05708d..4395588d29cda 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -6,6 +6,8 @@ List, Mapping, Optional, Set, Tuple, Type, Union, overload) from weakref import ReferenceType +from typing_extensions import deprecated + import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, ParallelConfig, SchedulerConfig, VllmConfig) @@ -422,7 +424,8 @@ async def get_tokenizer_async(self, return await ( self.get_tokenizer_group().get_lora_tokenizer_async(lora_request)) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") async def add_request_async( self, request_id: str, @@ -894,7 +897,8 @@ async def run_engine_loop(engine_ref: ReferenceType): # This method does not need to be async, but kept that way # for backwards compatibility. - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def add_request( self, request_id: str, diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7911dc8d04500..dd55aa2818621 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -10,7 +10,7 @@ from typing import Set, Type, Union, cast, overload import torch -from typing_extensions import TypeVar +from typing_extensions import TypeVar, deprecated import vllm.envs as envs from vllm.config import (DecodingConfig, LoRAConfig, ModelConfig, @@ -719,7 +719,8 @@ def _add_processed_request( def stop_remote_worker_execution_loop(self) -> None: self.model_executor.stop_remote_worker_execution_loop() - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def add_request( self, request_id: str, diff --git a/vllm/engine/multiprocessing/__init__.py b/vllm/engine/multiprocessing/__init__.py index 34c161e9395ae..7020012e8bb86 100644 --- a/vllm/engine/multiprocessing/__init__.py +++ b/vllm/engine/multiprocessing/__init__.py @@ -2,6 +2,8 @@ from enum import Enum from typing import List, Mapping, Optional, Union, overload +from typing_extensions import deprecated + from vllm import PoolingParams from vllm.inputs import PromptType from vllm.lora.request import LoRARequest @@ -32,7 +34,8 @@ class RPCProcessRequest: prompt_adapter_request: Optional[PromptAdapterRequest] = None priority: int = 0 - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def __init__( self, *, diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index d26728e8c6e67..8383e774db20f 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -9,6 +9,7 @@ import psutil import zmq import zmq.asyncio +from typing_extensions import deprecated from zmq import Frame # type: ignore[attr-defined] from zmq.asyncio import Socket @@ -414,7 +415,8 @@ def errored(self) -> bool: def dead_error(self) -> BaseException: return ENGINE_DEAD_ERROR(self._errored_with) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def generate( self, *, @@ -485,7 +487,8 @@ def generate( lora_request, trace_headers, prompt_adapter_request, priority) - @overload # DEPRECATED + @overload + @deprecated("'inputs' will be renamed to 'prompt") def encode( self, *, diff --git a/vllm/entrypoints/llm.py b/vllm/entrypoints/llm.py index a25c401b4ea10..65fa9873df28c 100644 --- a/vllm/entrypoints/llm.py +++ b/vllm/entrypoints/llm.py @@ -6,6 +6,7 @@ Union, cast, overload) from tqdm import tqdm +from typing_extensions import deprecated from vllm import envs from vllm.beam_search import (BeamSearchInstance, BeamSearchOutput, @@ -256,6 +257,7 @@ def set_tokenizer(self, tokenizer: AnyTokenizer) -> None: tokenizer_group.tokenizer = get_cached_tokenizer(tokenizer) @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: str, @@ -268,6 +270,7 @@ def generate( ... @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: List[str], @@ -280,6 +283,7 @@ def generate( ... @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: Optional[str] = None, @@ -293,6 +297,7 @@ def generate( ... @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: Optional[List[str]] = None, @@ -306,6 +311,7 @@ def generate( ... @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts") def generate( self, prompts: None, @@ -671,6 +677,7 @@ def chat( ) @overload # LEGACY: single (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: str, @@ -683,6 +690,7 @@ def encode( ... @overload # LEGACY: multi (prompt + optional token ids) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: List[str], @@ -695,6 +703,7 @@ def encode( ... @overload # LEGACY: single (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: Optional[str] = None, @@ -708,6 +717,7 @@ def encode( ... @overload # LEGACY: multi (token ids + optional prompt) + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: Optional[List[str]] = None, @@ -721,6 +731,7 @@ def encode( ... @overload # LEGACY: single or multi token ids [pos-only] + @deprecated("'prompt_token_ids' will become part of 'prompts") def encode( self, prompts: None, diff --git a/vllm/inputs/__init__.py b/vllm/inputs/__init__.py index 54fbd7a321a6f..d4402e77a3886 100644 --- a/vllm/inputs/__init__.py +++ b/vllm/inputs/__init__.py @@ -38,34 +38,3 @@ "InputProcessingContext", "InputRegistry", ] - - -def __getattr__(name: str): - import warnings - - if name == "PromptInput": - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - if name == "LLMInputs": - msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return DecoderOnlyInputs - - if name == "EncoderDecoderLLMInputs": - msg = ( - "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return EncoderDecoderInputs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/inputs/data.py b/vllm/inputs/data.py index fb7dbbebd7b90..e8fc78f1a66f6 100644 --- a/vllm/inputs/data.py +++ b/vllm/inputs/data.py @@ -358,34 +358,3 @@ def to_enc_dec_tuple_list( return [(enc_dec_prompt["encoder_prompt"], enc_dec_prompt["decoder_prompt"]) for enc_dec_prompt in enc_dec_prompts] - - -def __getattr__(name: str): - import warnings - - if name == "PromptInput": - msg = ("PromptInput has been renamed to PromptType. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return PromptType - - if name == "LLMInputs": - msg = ("LLMInputs has been renamed to DecoderOnlyInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return DecoderOnlyInputs - - if name == "EncoderDecoderLLMInputs": - msg = ( - "EncoderDecoderLLMInputs has been renamed to EncoderDecoderInputs. " - "The original name will be removed in an upcoming version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return EncoderDecoderInputs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/model_executor/models/aria.py b/vllm/model_executor/models/aria.py index fa6b95f5481ad..dd4b0c75cb84d 100644 --- a/vllm/model_executor/models/aria.py +++ b/vllm/model_executor/models/aria.py @@ -32,9 +32,8 @@ maybe_prefix, merge_multimodal_embeddings) from vllm.multimodal import MULTIMODAL_REGISTRY -from vllm.multimodal.base import MultiModalInputs from vllm.multimodal.image import cached_get_image_processor -from vllm.multimodal.inputs import NestedTensors +from vllm.multimodal.inputs import MultiModalKwargs, NestedTensors from vllm.multimodal.utils import (cached_get_tokenizer, repeat_and_pad_placeholder_tokens) from vllm.sequence import IntermediateTensors @@ -451,7 +450,7 @@ def get_max_multimodal_tokens(ctx): def input_mapper_for_aria(ctx, data): - return MultiModalInputs(data) + return MultiModalKwargs(data) def input_processor(ctx, llm_inputs): diff --git a/vllm/multimodal/__init__.py b/vllm/multimodal/__init__.py index 03a5f3a91f7a1..928c31a2f2843 100644 --- a/vllm/multimodal/__init__.py +++ b/vllm/multimodal/__init__.py @@ -27,18 +27,3 @@ "MULTIMODAL_REGISTRY", "MultiModalRegistry", ] - - -def __getattr__(name: str): - import warnings - - if name == "MultiModalInputs": - msg = ("MultiModalInputs has been renamed to MultiModalKwargs. " - "The original name will take another meaning in an upcoming " - "version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return MultiModalKwargs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") diff --git a/vllm/multimodal/base.py b/vllm/multimodal/base.py index bbb8fb4bc1cd1..f93722523728d 100644 --- a/vllm/multimodal/base.py +++ b/vllm/multimodal/base.py @@ -433,18 +433,3 @@ def index_map(self) -> "IndexMap": return MultiModalPlaceholderMap.IndexMap(src=src_indices, dest=dest_indices) - - -def __getattr__(name: str): - import warnings - - if name == "MultiModalInputs": - msg = ("MultiModalInputs has been renamed to MultiModalKwargs. " - "The original name will take another meaning in an upcoming " - "version.") - - warnings.warn(DeprecationWarning(msg), stacklevel=2) - - return MultiModalKwargs - - raise AttributeError(f"module {__name__!r} has no attribute {name!r}") From 9323a3153b20d4a2ca7ac04a2784609d6ce656e0 Mon Sep 17 00:00:00 2001 From: Aaron Pham Date: Tue, 3 Dec 2024 02:17:00 -0500 Subject: [PATCH 12/13] [Core][Performance] Add XGrammar support for guided decoding and set it as default (#10785) Signed-off-by: Aaron Pham Signed-off-by: mgoin Co-authored-by: mgoin --- docs/source/conf.py | 1 + requirements-common.txt | 1 + tests/entrypoints/llm/test_guided_generate.py | 27 ++ .../model_executor/test_guided_processors.py | 3 +- vllm/config.py | 15 +- vllm/engine/arg_utils.py | 9 +- vllm/engine/async_llm_engine.py | 18 +- vllm/engine/llm_engine.py | 15 +- vllm/engine/multiprocessing/client.py | 5 +- .../guided_decoding/__init__.py | 73 ++++- .../guided_decoding/xgrammar_decoding.py | 251 ++++++++++++++++++ 11 files changed, 385 insertions(+), 33 deletions(-) create mode 100644 vllm/model_executor/guided_decoding/xgrammar_decoding.py diff --git a/docs/source/conf.py b/docs/source/conf.py index 4a1a5fb455ff3..e9d9ac68c9560 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -178,6 +178,7 @@ def linkcode_resolve(domain, info): "tensorizer", "pynvml", "outlines", + "xgrammar," "librosa", "soundfile", "gguf", diff --git a/requirements-common.txt b/requirements-common.txt index 02e3d65fb774c..818f72e14be96 100644 --- a/requirements-common.txt +++ b/requirements-common.txt @@ -19,6 +19,7 @@ prometheus-fastapi-instrumentator >= 7.0.0 tiktoken >= 0.6.0 # Required for DBRX tokenizer lm-format-enforcer >= 0.10.9, < 0.11 outlines >= 0.0.43, < 0.1 +xgrammar typing_extensions >= 4.10 filelock >= 3.16.1 # need to contain https://github.com/tox-dev/filelock/pull/317 partial-json-parser # used for parsing partial JSON outputs diff --git a/tests/entrypoints/llm/test_guided_generate.py b/tests/entrypoints/llm/test_guided_generate.py index 67c79415f322a..c3706f696b264 100644 --- a/tests/entrypoints/llm/test_guided_generate.py +++ b/tests/entrypoints/llm/test_guided_generate.py @@ -159,3 +159,30 @@ def test_validation_against_both_guided_decoding_options(sample_regex, llm): sampling_params=sampling_params, use_tqdm=True, guided_options_request=dict(guided_regex=sample_regex)) + + +@pytest.mark.skip_global_cleanup +def test_guided_json_object(llm): + sampling_params = SamplingParams( + temperature=1.0, + max_tokens=100, + guided_decoding=GuidedDecodingParams(json_object=True)) + + outputs = llm.generate( + prompts=("Generate a JSON object describing a person with name " + "and age for John Smith who is 31 years old."), + sampling_params=sampling_params, + use_tqdm=True) + + assert outputs is not None + for output in outputs: + assert output is not None + assert isinstance(output, RequestOutput) + + generated_text = output.outputs[0].text + print(generated_text) + assert generated_text is not None + + # Parse to verify it is valid JSON + parsed_json = json.loads(generated_text) + assert isinstance(parsed_json, dict) diff --git a/tests/model_executor/test_guided_processors.py b/tests/model_executor/test_guided_processors.py index 45fab8e96b968..9f4d81b583141 100644 --- a/tests/model_executor/test_guided_processors.py +++ b/tests/model_executor/test_guided_processors.py @@ -36,7 +36,8 @@ def test_guided_logits_processors(sample_regex, sample_json_schema): @pytest.mark.asyncio -@pytest.mark.parametrize("backend", ["outlines", "lm-format-enforcer"]) +@pytest.mark.parametrize("backend", + ["outlines", "lm-format-enforcer", "xgrammar"]) async def test_guided_logits_processor_black_box(backend: str, sample_regex, sample_json_schema): tokenizer = AutoTokenizer.from_pretrained('HuggingFaceH4/zephyr-7b-beta') diff --git a/vllm/config.py b/vllm/config.py index 326340d3fa655..971eb36d677b8 100644 --- a/vllm/config.py +++ b/vllm/config.py @@ -1789,15 +1789,15 @@ class PoolerConfig: step_tag_id: Optional[int] = None """ - If set, only the score corresponding to the ``step_tag_id`` in the + If set, only the score corresponding to the ``step_tag_id`` in the generated sentence should be returned. Otherwise, the scores for all tokens are returned. """ returned_token_ids: Optional[List[int]] = None """ - A list of indices for the vocabulary dimensions to be extracted, - such as the token IDs of ``good_token`` and ``bad_token`` in the + A list of indices for the vocabulary dimensions to be extracted, + such as the token IDs of ``good_token`` and ``bad_token`` in the ``math-shepherd-mistral-7b-prm`` model. """ @@ -2031,11 +2031,12 @@ def get_served_model_name(model: str, class DecodingConfig: """Dataclass which contains the decoding strategy of the engine""" - # Which guided decoding algo to use. 'outlines' / 'lm-format-enforcer' - guided_decoding_backend: str = 'outlines' + # Which guided decoding algo to use. + # 'outlines' / 'lm-format-enforcer' / 'xgrammar' + guided_decoding_backend: str = 'xgrammar' def __post_init__(self): - valid_guided_backends = ['outlines', 'lm-format-enforcer'] + valid_guided_backends = ['outlines', 'lm-format-enforcer', 'xgrammar'] backend = self.guided_decoding_backend if backend not in valid_guided_backends: raise ValueError(f"Invalid guided_decoding_backend '{backend}," @@ -2222,7 +2223,7 @@ class CompilationConfig(BaseModel): from Python, functions can also be passed directly via Python object constructor, e.g. `CompilationConfig(inductor_passes={"a": func})` - custom inductor passes: see PassConfig for more details - + Why we have different sizes for cudagraph and inductor: - cudagraph: a cudagraph captured for a specific size can only be used for the same size. We need to capture all the sizes we want to use. diff --git a/vllm/engine/arg_utils.py b/vllm/engine/arg_utils.py index 4aa0eebd976c9..3b776c1d9d39f 100644 --- a/vllm/engine/arg_utils.py +++ b/vllm/engine/arg_utils.py @@ -168,7 +168,7 @@ class EngineArgs: scheduler_delay_factor: float = 0.0 enable_chunked_prefill: Optional[bool] = None - guided_decoding_backend: str = 'outlines' + guided_decoding_backend: str = 'xgrammar' # Speculative decoding configuration. speculative_model: Optional[str] = None speculative_model_quantization: Optional[str] = None @@ -364,11 +364,12 @@ def add_cli_args(parser: FlexibleArgumentParser) -> FlexibleArgumentParser: parser.add_argument( '--guided-decoding-backend', type=str, - default='outlines', - choices=['outlines', 'lm-format-enforcer'], + default='xgrammar', + choices=['outlines', 'lm-format-enforcer', 'xgrammar'], help='Which engine will be used for guided decoding' ' (JSON schema / regex etc) by default. Currently support ' - 'https://github.com/outlines-dev/outlines and ' + 'https://github.com/outlines-dev/outlines,' + 'https://github.com/mlc-ai/xgrammar, and ' 'https://github.com/noamgat/lm-format-enforcer.' ' Can be overridden per request via guided_decoding_backend' ' parameter.') diff --git a/vllm/engine/async_llm_engine.py b/vllm/engine/async_llm_engine.py index 4395588d29cda..60dccd7a0812c 100644 --- a/vllm/engine/async_llm_engine.py +++ b/vllm/engine/async_llm_engine.py @@ -1,4 +1,5 @@ import asyncio +import copy import time import weakref from functools import partial @@ -507,7 +508,8 @@ async def add_request_async( sampling_params=params, tokenizer=await self.get_tokenizer_async(lora_request), default_guided_backend=self.decoding_config. - guided_decoding_backend) + guided_decoding_backend, + model_config=self.model_config) self._add_processed_request( request_id=request_id, @@ -528,22 +530,30 @@ async def check_health_async(self) -> None: async def build_guided_decoding_logits_processor_async( sampling_params: SamplingParams, tokenizer: AnyTokenizer, - default_guided_backend: str) -> SamplingParams: + default_guided_backend: str, + model_config: ModelConfig) -> SamplingParams: """Constructs logits processors based on the guided_decoding, logits_bias, and allowed_token_ids fields in sampling_params. Deletes those fields and adds the constructed logits processors to the logits_processors field. Modifies sampling params in-place and returns the modified sampling params.""" - if (guided_decoding := sampling_params.guided_decoding) is None: + if sampling_params.guided_decoding is None: return sampling_params + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding + logger.debug("Building guided decoding logits processor. " "Params: %s", guided_decoding) guided_decoding.backend = guided_decoding.backend or default_guided_backend processor = await get_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=model_config) if processor: if sampling_params.logits_processors is None: diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index dd55aa2818621..af66b307028cf 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -1,3 +1,4 @@ +import copy import time from collections import Counter as collectionsCounter from collections import deque @@ -1024,9 +1025,9 @@ def _update_num_computed_tokens_for_multi_step_prefill( This function updates num_computed_tokens for prompt sequences when Multi-Step is enabled. - seq_group: SequenceGroup to update the num_computed_tokens for. + seq_group: SequenceGroup to update the num_computed_tokens for. seq_group_meta: Metadata of the given SequenceGroup. - is_first_step_output: Optional[bool] - + is_first_step_output: Optional[bool] - When available, is_first_step_output indicates if the appended output token is the output of the first-step in multi-step. A value of None indicates that outputs from all steps in @@ -2036,7 +2037,11 @@ def _build_logits_processors( logits_processors = [] - if (guided_decoding := sampling_params.guided_decoding) is not None: + if sampling_params.guided_decoding is not None: + # Defensively copy sampling params since guided decoding logits + # processors can have different state for each request + sampling_params = copy.copy(sampling_params) + guided_decoding = sampling_params.guided_decoding logger.debug( "Building guided decoding logits processor in " @@ -2047,7 +2052,9 @@ def _build_logits_processors( self.decoding_config.guided_decoding_backend processor = get_local_guided_decoding_logits_processor( - guided_params=guided_decoding, tokenizer=tokenizer) + guided_params=guided_decoding, + tokenizer=tokenizer, + model_config=self.model_config) if processor: logits_processors.append(processor) diff --git a/vllm/engine/multiprocessing/client.py b/vllm/engine/multiprocessing/client.py index 8383e774db20f..d21136c03d7d2 100644 --- a/vllm/engine/multiprocessing/client.py +++ b/vllm/engine/multiprocessing/client.py @@ -474,8 +474,8 @@ def generate( trace_headers: OpenTelemetry trace headers. prompt_adapter_request: Prompt Adapter request to use for generation, if any. - priority: Priority of the request (lower means earlier handling). - Any priority other than 0 will lead to an error if the + priority: Priority of the request (lower means earlier handling). + Any priority other than 0 will lead to an error if the scheduling policy is not "priority". """ if inputs is not None: @@ -589,6 +589,7 @@ async def _process_request( default_guided_backend=(self.decoding_config.guided_decoding_backend if self.decoding_config else DecodingConfig.guided_decoding_backend), + model_config=self.model_config ) # 1) Create output queue for this requests. diff --git a/vllm/model_executor/guided_decoding/__init__.py b/vllm/model_executor/guided_decoding/__init__.py index d7b67425fcbc0..23c31fcfd7f05 100644 --- a/vllm/model_executor/guided_decoding/__init__.py +++ b/vllm/model_executor/guided_decoding/__init__.py @@ -1,14 +1,54 @@ -from typing import Optional +from __future__ import annotations -from vllm.logits_process import LogitsProcessor -from vllm.sampling_params import GuidedDecodingParams +from typing import TYPE_CHECKING + +from vllm.logger import init_logger + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.logits_process import LogitsProcessor + from vllm.sampling_params import GuidedDecodingParams + +logger = init_logger(__name__) + + +def maybe_backend_fallback( + guided_params: GuidedDecodingParams) -> GuidedDecodingParams: + # lm-format-enforce doesn't support grammar, fallback to xgrammar + if (guided_params.backend == "lm-format-enforcer" + and guided_params.grammar is not None): + logger.warning( + "lm-format-enforcer does not support grammar guided decoding. " + "Falling back to use xgrammar instead.") + guided_params.backend = "xgrammar" + + if guided_params.backend == "xgrammar": + # xgrammar doesn't support regex or choice, fallback to outlines + if guided_params.regex is not None or guided_params.choice is not None: + logger.warning( + "xgrammar only supports json or grammar guided decoding. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + # xgrammar only supports EBNF grammars and uses the GBNF format + # https://github.com/ggerganov/llama.cpp/blob/master/grammars/README.md + elif (guided_params.grammar is not None + and "::=" not in guided_params.grammar): + logger.warning("xgrammar only supports EBNF grammars. " + "Falling back to use outlines instead.") + guided_params.backend = "outlines" + + return guided_params async def get_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_outlines_guided_decoding_logits_processor) @@ -19,17 +59,23 @@ async def get_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") def get_local_guided_decoding_logits_processor( - guided_params: GuidedDecodingParams, - tokenizer) -> Optional[LogitsProcessor]: + guided_params: GuidedDecodingParams, tokenizer: PreTrainedTokenizer, + model_config: ModelConfig) -> LogitsProcessor | None: + guided_params = maybe_backend_fallback(guided_params) # CFG grammar not supported by LMFE, so we use outlines instead - if guided_params.backend == 'outlines' or guided_params.grammar: + if guided_params.backend == 'outlines': # NOTE: lazy import outlines to avoid https://github.com/vllm-project/vllm/issues/4193 from vllm.model_executor.guided_decoding.outlines_decoding import ( # noqa get_local_outlines_guided_decoding_logits_processor) @@ -40,7 +86,12 @@ def get_local_guided_decoding_logits_processor( get_local_lm_format_enforcer_guided_decoding_logits_processor) return get_local_lm_format_enforcer_guided_decoding_logits_processor( guided_params, tokenizer) + if guided_params.backend == 'xgrammar': + from vllm.model_executor.guided_decoding.xgrammar_decoding import ( # noqa + get_local_xgrammar_guided_decoding_logits_processor) + return get_local_xgrammar_guided_decoding_logits_processor( + guided_params, tokenizer, model_config) raise ValueError( f"Unknown guided decoding backend '{guided_params.backend}'. " - "Must be one of 'outlines, 'lm-format-enforcer'") + "Must be one of 'outlines, 'lm-format-enforcer', 'xgrammar'") diff --git a/vllm/model_executor/guided_decoding/xgrammar_decoding.py b/vllm/model_executor/guided_decoding/xgrammar_decoding.py new file mode 100644 index 0000000000000..8287cd6cf3aa0 --- /dev/null +++ b/vllm/model_executor/guided_decoding/xgrammar_decoding.py @@ -0,0 +1,251 @@ +# noqa: UP007 +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import TYPE_CHECKING, Any, NamedTuple + +import torch +from transformers import PreTrainedTokenizerFast + +try: + import xgrammar as xgr + from xgrammar.base import _core as xgr_core +except ImportError: + pass + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from vllm.config import ModelConfig + from vllm.sampling_params import GuidedDecodingParams + + +# TODO: passing batch size to max threads here +def get_local_xgrammar_guided_decoding_logits_processor( + guided_params: GuidedDecodingParams, + tokenizer: PreTrainedTokenizer, + model_config: ModelConfig, + max_threads: int = 8): + config = GrammarConfig.from_guided_params(guided_params=guided_params, + model_config=model_config, + tokenizer=tokenizer, + max_threads=max_threads) + return XGrammarLogitsProcessor(config) + + +class TokenizerData(NamedTuple): + """Immutable container for cached tokenizer data.""" + encoded_vocab: list[str] + stop_token_ids: list[int] | None + backend_str: str + + +class TokenizerDataCache: + """Cache manager for tokenizer data to avoid repeated processing.""" + _cache: dict[int, TokenizerData] = {} + + @classmethod + def get_tokenizer_data(cls, + tokenizer: PreTrainedTokenizer) -> TokenizerData: + tokenizer_hash = hash(tokenizer) + + if tokenizer_hash not in cls._cache: + # Vendored from xgrammar logic since we cannot pickle the tokenizer + # https://github.com/mlc-ai/xgrammar/blob/d77c0a0173ef14779c918e3be7966ba852f7910f/python/xgrammar/tokenizer_info.py#L98 # noqa: E501 + try: + encoded_vocab = [ + token for token, _ in sorted(tokenizer.get_vocab().items(), + key=lambda x: x[1]) + ] + except AttributeError as e: + raise ValueError( + f"Cannot get the vocabulary of the tokenizer " + f"{type(tokenizer)}. The tokenizer should have a " + "get_vocab method.") from e + + stop_token_ids = None + backend_str = xgr.VocabType.RAW + if isinstance(tokenizer, PreTrainedTokenizerFast): + backend_str = tokenizer.backend_tokenizer.to_str() + if stop_token_ids is None and hasattr( + tokenizer, + "eos_token_id") and tokenizer.eos_token_id is not None: + stop_token_ids = [tokenizer.eos_token_id] + + cls._cache[tokenizer_hash] = TokenizerData( + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str) + + return cls._cache[tokenizer_hash] + + +class GrammarCompilerCache: + """ + Cache for GrammarCompiler instances based on tokenizer. + + This cache reduces the overhead of creating new compiler instances when + using the same tokenizer configuration. + """ + _cache: dict[str, xgr.GrammarCompiler] = {} + + @classmethod + def get_compiler(cls, config: GrammarConfig) -> xgr.GrammarCompiler: + cache_key = str(config.tokenizer_hash) + + if cache_key not in cls._cache: + assert config.encoded_vocab is not None + tokenizer_info = xgr.TokenizerInfo._create_from_handle( + xgr_core.TokenizerInfo.from_huggingface( + config.encoded_vocab, config.backend_str, + config.vocab_size, config.stop_token_ids)) + cls._cache[cache_key] = xgr.GrammarCompiler( + tokenizer_info, max_threads=config.max_threads) + + return cls._cache[cache_key] + + +@dataclass +class GrammarConfig: + """Serializable configuration for grammar compilation""" + tokenizer_hash: int + vocab_size: int + json_str: str | None = None + grammar_str: str | None = None + json_object: bool | None = None + max_threads: int = 8 + # Only populated if tokenizer_hash not in cache + encoded_vocab: list[str] | None = None + stop_token_ids: list[int] | None = None + backend_str: str | None = None + + @classmethod + def from_guided_params(cls, + guided_params: GuidedDecodingParams, + model_config: ModelConfig, + tokenizer: PreTrainedTokenizer, + max_threads: int = 8) -> GrammarConfig: + + tokenizer_hash = hash(tokenizer) + # Only get tokenizer data if not already cached + if tokenizer_hash in TokenizerDataCache._cache: + encoded_vocab = None + stop_token_ids = None + backend_str = None + else: + tokenizer_data = TokenizerDataCache.get_tokenizer_data(tokenizer) + encoded_vocab = tokenizer_data.encoded_vocab + stop_token_ids = tokenizer_data.stop_token_ids + backend_str = tokenizer_data.backend_str + + if guided_params.json: + if not isinstance(guided_params.json, str): + json_str = json.dumps(guided_params.json) + else: + json_str = guided_params.json + return cls(json_str=json_str, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.grammar: + return cls(grammar_str=guided_params.grammar, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + elif guided_params.json_object: + return cls(json_object=True, + vocab_size=model_config.hf_config.vocab_size, + encoded_vocab=encoded_vocab, + stop_token_ids=stop_token_ids, + backend_str=backend_str, + tokenizer_hash=tokenizer_hash, + max_threads=max_threads) + else: + raise ValueError( + "Currently only support JSON and EBNF grammar mode for xgrammar" + ) + + +@dataclass +class XGrammarLogitsProcessor: + """Wrapper class to support pickle protocol""" + config: GrammarConfig + + ctx: xgr.CompiledGrammar | None = None + token_bitmask: torch.Tensor = None # type: ignore[assignment] + matchers: list[xgr.GrammarMatcher] = field(default_factory=list) + batch_size: int = field(default=1) + prefilled: bool = field(default=False) + + def __getstate__(self) -> dict[str, Any]: + return {'config': self.config} + + def __setstate__(self, state: dict[str, Any]): + self.config = state['config'] + + self.ctx = None + self.matchers = [] + self.batch_size = 1 + self.token_bitmask = None # type: ignore[assignment] + self.prefilled = False + + def _ensure_ctx(self): + """Lazily initialize the processor in the worker process""" + if self.ctx is None: + compiler = GrammarCompilerCache.get_compiler(self.config) + if self.config.json_str is not None: + self.ctx = compiler.compile_json_schema(self.config.json_str) + elif self.config.grammar_str is not None: + self.ctx = compiler.compile_grammar(self.config.grammar_str) + elif self.config.json_object: + self.ctx = compiler.compile_builtin_json_grammar() + else: + raise ValueError( + "Invalid configuration for xgrammar logits processor") + + def __call__(self, input_ids: list[int], + scores: torch.Tensor) -> torch.Tensor: + if self.ctx is None: + self._ensure_ctx() + + if len(self.matchers) == 0: + self.matchers = [ + xgr.GrammarMatcher(self.ctx) for _ in range(self.batch_size) + ] + self.token_bitmask = xgr.allocate_token_bitmask( + self.batch_size, self.config.vocab_size) + + if not self.prefilled: + # Have not sampled a token yet + self.prefilled = True + else: + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + sampled_token = input_ids[-1] + assert self.matchers[i].accept_token(sampled_token) + + for i, matcher in enumerate(self.matchers): + if not matcher.is_terminated(): + # @ubospica: ideally, fill_next_token_bitmask should be + # parallelized with model decoding + # See https://github.com/vllm-project/vllm/pull/10785/files#r1864278303 + matcher.fill_next_token_bitmask(self.token_bitmask, i) + + # token_bitmask is a CPU tensor for use with accept_token and + # fill_next_token_bitmask so we move it to the device of scores + device_type = scores.device.type + if device_type != "cuda": + scores = scores.to("cpu") + xgr.apply_token_bitmask_inplace(scores, + self.token_bitmask.to(scores.device)) + if device_type != "cuda": + scores = scores.to(device_type) + + return scores From f6084f63248a89df52bed9d9c24d6604f87e51f3 Mon Sep 17 00:00:00 2001 From: Yang Zheng <50227060+zhengy001@users.noreply.github.com> Date: Tue, 3 Dec 2024 17:01:39 +0800 Subject: [PATCH 13/13] [Speculative Decoding] Move indices to device before filtering output (#10850) Co-authored-by: Yang Zheng(SW)(Alex) --- vllm/spec_decode/multi_step_worker.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/vllm/spec_decode/multi_step_worker.py b/vllm/spec_decode/multi_step_worker.py index d249b37c780e4..676ac5eb3609d 100644 --- a/vllm/spec_decode/multi_step_worker.py +++ b/vllm/spec_decode/multi_step_worker.py @@ -120,6 +120,9 @@ def sampler_output( indices_of_seq_with_bonus_tokens) model_outputs.append(model_output) + # move indices to device to avoid stream sync + indices_of_seq_with_bonus_tokens = torch.tensor( + indices_of_seq_with_bonus_tokens, device=self.device) filtered_model_outputs = self._filter_model_output( model_outputs, indices_of_seq_with_bonus_tokens) return filtered_model_outputs, True @@ -189,7 +192,7 @@ def _expand_execute_model_request( @staticmethod def _filter_model_output( expanded_batch_outputs: List[SamplerOutput], - output_indices_to_retain: List[int]) -> List[SamplerOutput]: + output_indices_to_retain: torch.Tensor) -> List[SamplerOutput]: """ Filters the model output to include only the specified sequence outputs. This method contracts the expanded batch output from the @@ -199,8 +202,8 @@ def _filter_model_output( Args: expanded_batch_output (List[SamplerOutput]): The expanded output batch from the model. - output_indices_to_retain (List[int]): Indices of the model outputs - to retain. + output_indices_to_retain (torch.Tensor): Indices of the model + outputs to retain. Returns: List[SamplerOutput]: A list containing the filtered model