From 5621569b751497adfa2fc87eb074e93aacd72184 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 19 Jan 2024 01:36:58 +0800 Subject: [PATCH 1/5] complete coding --- examples/offline_inference.py | 10 +++------- vllm/engine/llm_engine.py | 18 ++++++++++++++++++ vllm/model_executor/model_loader.py | 4 ++-- .../parallel_utils/communication_op.py | 4 ++-- vllm/worker/worker.py | 7 ++++--- 5 files changed, 29 insertions(+), 14 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index 9b758fa2479f6..e8f091f375aa8 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,17 +1,13 @@ from vllm import LLM, SamplingParams # Sample prompts. -prompts = [ - "Hello, my name is", - "The president of the United States is", - "The capital of France is", - "The future of AI is", -] +prompts = ["who are you"] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="facebook/opt-125m") +llm = LLM(model="/home/leejee/Data/LLM_Pretrained/chatglm3_backup", + trust_remote_code=True) # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. outputs = llm.generate(prompts, sampling_params) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 7072a8bbc5b3e..a03569b82c2a1 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -108,6 +108,7 @@ def __init__( os.environ["RAY_USAGE_STATS_ENABLED"] = "0" self._init_workers_ray(placement_group) else: + self._init_single_gpu_config() self._init_workers() # Profile the memory usage and initialize the cache. @@ -917,3 +918,20 @@ def _run_workers( ray_worker_outputs = ray.get(ray_worker_outputs) return [driver_worker_output] + ray_worker_outputs + + def _init_single_gpu_config(self) -> None: + + def _parallel_rank_mp(*args, **kargs) -> int: + return 0 + + def _parallel_world_size_mp(*args, **kargs) -> int: + return 1 + + def _parallel_group_mp(*args, **kargs) -> int: + return 1 + + import vllm.model_executor.parallel_utils.parallel_state + + vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_world_size = _parallel_world_size_mp + vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_rank = _parallel_rank_mp + vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_group = _parallel_group_mp diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index 37543d8c9838e..fb56924f4eeeb 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -8,8 +8,6 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry -from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) @contextlib.contextmanager @@ -33,6 +31,8 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: + from vllm.model_executor.weight_utils import (get_quant_config, + initialize_dummy_weights) model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. 
diff --git a/vllm/model_executor/parallel_utils/communication_op.py b/vllm/model_executor/parallel_utils/communication_op.py index 8bf04f3d1f056..6b95fd91a4e87 100644 --- a/vllm/model_executor/parallel_utils/communication_op.py +++ b/vllm/model_executor/parallel_utils/communication_op.py @@ -82,7 +82,7 @@ def tensor_model_parallel_gather(input_, dst=0, dim=-1): def broadcast(input_, src=0): """Broadcast the input tensor.""" - world_size = torch.distributed.get_world_size() + world_size = get_tensor_model_parallel_world_size() assert 0 <= src < world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. @@ -95,7 +95,7 @@ def broadcast(input_, src=0): def broadcast_object_list(obj_list, src=0): """Broadcast the input object list.""" - world_size = torch.distributed.get_world_size() + world_size = get_tensor_model_parallel_world_size() assert 0 <= src < world_size, f"Invalid src rank ({src})" # Bypass the function if we are using only 1 GPU. diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index c2a2ac148085b..de7bae73dd29d 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -70,9 +70,10 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) - # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + if self.parallel_config.worker_use_ray: + # Initialize the distributed environment. + _init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) # Initialize the model. set_random_seed(self.model_config.seed) From 4219d19844b0748b73d14f31356ea34bbb50f718 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 19 Jan 2024 01:44:33 +0800 Subject: [PATCH 2/5] modify offline_inference.py --- examples/offline_inference.py | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/examples/offline_inference.py b/examples/offline_inference.py index e8f091f375aa8..23cc6e8539431 100644 --- a/examples/offline_inference.py +++ b/examples/offline_inference.py @@ -1,13 +1,17 @@ from vllm import LLM, SamplingParams # Sample prompts. -prompts = ["who are you"] +prompts = [ + "Hello, my name is", + "The president of the United States is", + "The capital of France is", + "The future of AI is", +] # Create a sampling params object. sampling_params = SamplingParams(temperature=0.8, top_p=0.95) # Create an LLM. -llm = LLM(model="/home/leejee/Data/LLM_Pretrained/chatglm3_backup", - trust_remote_code=True) +llm = LLM(model="facebook/opt-125m") # Generate texts from the prompts. The output is a list of RequestOutput objects # that contain the prompt, generated text, and other information. 
outputs = llm.generate(prompts, sampling_params) @@ -15,4 +19,4 @@ for output in outputs: prompt = output.prompt generated_text = output.outputs[0].text - print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") + print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}") \ No newline at end of file From 999dda22f68501e1c4c63de2d8eb9ba8ed5fffb7 Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 19 Jan 2024 14:36:22 +0800 Subject: [PATCH 3/5] refactor code --- vllm/engine/llm_engine.py | 161 +++++++++++++++++----------- vllm/model_executor/model_loader.py | 6 +- vllm/worker/worker.py | 9 +- 3 files changed, 105 insertions(+), 71 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index a03569b82c2a1..545dbdf2763ff 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,11 +2,9 @@ from collections import defaultdict import os import time -from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, - Union) +from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union -from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, - SchedulerConfig) +from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import record_metrics @@ -14,10 +12,15 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, - SequenceGroupOutput, SequenceOutput, SequenceStatus) -from vllm.transformers_utils.tokenizer import (detokenize_incrementally, - get_tokenizer) +from vllm.sequence import ( + SamplerOutput, + Sequence, + SequenceGroup, + SequenceGroupOutput, + SequenceOutput, + SequenceStatus, +) +from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port if ray: @@ -97,7 +100,8 @@ def __init__( tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision) + revision=model_config.revision, + ) self.seq_counter = Counter() # Create the parallel GPU workers. @@ -108,7 +112,6 @@ def __init__( os.environ["RAY_USAGE_STATS_ENABLED"] = "0" self._init_workers_ray(placement_group) else: - self._init_single_gpu_config() self._init_workers() # Profile the memory usage and initialize the cache. @@ -129,9 +132,9 @@ def _init_workers(self): # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker - assert self.parallel_config.world_size == 1, ( - "Ray is required if parallel_config.world_size > 1.") - + assert (self.parallel_config.world_size == 1 + ), "Ray is required if parallel_config.world_size > 1." + self._init_single_gpu_config() self.workers: List[Worker] = [] distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" self.driver_worker = Worker( @@ -326,9 +329,11 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": # Initialize the cluster. placement_group = initialize_cluster(parallel_config) # Create the LLM engine. 
- engine = cls(*engine_configs, - placement_group, - log_stats=not engine_args.disable_log_stats) + engine = cls( + *engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats, + ) return engine def add_request( @@ -397,8 +402,8 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) # Check whether the input specifies prefix - prefix = self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None + prefix = (self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None) # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, @@ -450,13 +455,13 @@ def _check_beam_search_early_stopping( if early_stopping is True: return True - current_worst_score = (current_worst_seq.get_beam_search_score( + current_worst_score = current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.eos_token_id) if early_stopping is False: - highest_attainable_score = (best_running_seq.get_beam_search_score( + highest_attainable_score = best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + eos_token_id=self.tokenizer.eos_token_id) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -466,20 +471,21 @@ def _check_beam_search_early_stopping( max_possible_length = max( best_running_seq.get_prompt_len() + sampling_params.max_tokens, - self.scheduler_config.max_model_len) - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - seq_len=max_possible_length)) + self.scheduler_config.max_model_len, + ) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length, + ) else: # Otherwise, beam search will prefer shorter sequences. The # highest attainable score calculation is based on the current # sequence length. - highest_attainable_score = ( - best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id)) + highest_attainable_score = best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + ) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -568,10 +574,12 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq.is_finished()] all_finished_seqs = existing_finished_seqs + new_finished_seqs # Sort the finished sequences by their scores. - all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + all_finished_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True, + ) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: # A newly generated child sequence finishes and has a high @@ -596,10 +604,12 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, running_child_seqs = [(seq, parent) for seq, parent in child_seqs if not seq.is_finished()] # Sort the running sequences by their scores. 
- running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True) + running_child_seqs.sort( + key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True, + ) # Check if we can stop the beam search. if len(running_child_seqs) == 0: @@ -614,7 +624,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( seq_group.sampling_params.early_stopping, - seq_group.sampling_params, best_running_seq, current_worst_seq) + seq_group.sampling_params, + best_running_seq, + current_worst_seq, + ) if stop_beam_search: # Stop the beam search and remove all the running sequences from @@ -747,7 +760,8 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }) + }, + ) # Only the driver worker returns the sampling results. output = all_outputs[0] @@ -797,15 +811,15 @@ def _log_system_stats( avg_generation_throughput = 0.0 total_num_gpu_blocks = self.cache_config.num_gpu_blocks - num_free_gpu_blocks = ( - self.scheduler.block_manager.get_num_free_gpu_blocks()) + num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks( + ) num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks total_num_cpu_blocks = self.cache_config.num_cpu_blocks if total_num_cpu_blocks > 0: - num_free_cpu_blocks = ( - self.scheduler.block_manager.get_num_free_cpu_blocks()) + num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks( + ) num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks else: @@ -834,16 +848,20 @@ def _log_system_stats( def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" - (new_tokens, new_output_text, prefix_offset, - read_offset) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) + ( + new_tokens, + new_output_text, + prefix_offset, + read_offset, + ) = detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) if seq.tokens is None: seq.tokens = new_tokens else: @@ -878,8 +896,8 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. 
- if ((not sampling_params.ignore_eos) - and seq.get_last_token_id() == self.tokenizer.eos_token_id): + if (not sampling_params.ignore_eos + ) and seq.get_last_token_id() == self.tokenizer.eos_token_id: seq.status = SequenceStatus.FINISHED_STOPPED return @@ -920,6 +938,13 @@ def _run_workers( return [driver_worker_output] + ray_worker_outputs def _init_single_gpu_config(self) -> None: + _NEED_RELOAD_MODULES = [ + "vllm.model_executor.parallel_utils.communication_op", + "vllm.model_executor.layers.linear", + "vllm.model_executor.layers.activation", + "vllm.model_executor.layers.sampler", + "vllm.model_executor.layers.vocab_parallel_embedding", + ] def _parallel_rank_mp(*args, **kargs) -> int: return 0 @@ -930,8 +955,18 @@ def _parallel_world_size_mp(*args, **kargs) -> int: def _parallel_group_mp(*args, **kargs) -> int: return 1 - import vllm.model_executor.parallel_utils.parallel_state - - vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_world_size = _parallel_world_size_mp - vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_rank = _parallel_rank_mp - vllm.model_executor.parallel_utils.parallel_state.get_tensor_model_parallel_group = _parallel_group_mp + import sys + import importlib + import vllm.model_executor.parallel_utils.parallel_state as ps_module + + ps_module.get_tensor_model_parallel_world_size = _parallel_world_size_mp + ps_module.get_tensor_model_parallel_rank = _parallel_rank_mp + ps_module.get_tensor_model_parallel_group = _parallel_group_mp + for module_name in _NEED_RELOAD_MODULES: + if module_name in sys.modules: + module_before = sys.modules.get(module_name, None) + _ = importlib.reload(module_before) # retrurn reloaded module + module_worker = "vllm.worker.worker" + module = sys.modules.get(module_worker, None) + assert module + module._init_distributed_environment = lambda *args, **kargs: 0 diff --git a/vllm/model_executor/model_loader.py b/vllm/model_executor/model_loader.py index fb56924f4eeeb..40f441fc9cb50 100644 --- a/vllm/model_executor/model_loader.py +++ b/vllm/model_executor/model_loader.py @@ -8,6 +8,8 @@ from vllm.config import ModelConfig from vllm.model_executor.models import ModelRegistry +from vllm.model_executor.weight_utils import (get_quant_config, + initialize_dummy_weights) @contextlib.contextmanager @@ -31,8 +33,6 @@ def _get_model_architecture(config: PretrainedConfig) -> Type[nn.Module]: def get_model(model_config: ModelConfig) -> nn.Module: - from vllm.model_executor.weight_utils import (get_quant_config, - initialize_dummy_weights) model_class = _get_model_architecture(model_config.hf_config) # Get the (maybe quantized) linear method. @@ -71,4 +71,4 @@ def get_model(model_config: ModelConfig) -> nn.Module: # Load the weights from the cached or downloaded files. model.load_weights(model_config.model, model_config.download_dir, model_config.load_format, model_config.revision) - return model.eval() + return model.eval() \ No newline at end of file diff --git a/vllm/worker/worker.py b/vllm/worker/worker.py index de7bae73dd29d..db41eac9f2631 100644 --- a/vllm/worker/worker.py +++ b/vllm/worker/worker.py @@ -70,10 +70,9 @@ def init_model(self) -> None: _check_if_gpu_supports_dtype(self.model_config.dtype) - if self.parallel_config.worker_use_ray: - # Initialize the distributed environment. - _init_distributed_environment(self.parallel_config, self.rank, - self.distributed_init_method) + # Initialize the distributed environment. 
+ _init_distributed_environment(self.parallel_config, self.rank, + self.distributed_init_method) # Initialize the model. set_random_seed(self.model_config.seed) @@ -242,4 +241,4 @@ def _check_if_gpu_supports_dtype(torch_dtype: torch.dtype): f"of at least 8.0. Your {gpu_name} GPU has compute capability " f"{compute_capability[0]}.{compute_capability[1]}. " "You can use float16 instead by explicitly setting the" - "`dtype` flag in CLI, for example: --dtype=half.") + "`dtype` flag in CLI, for example: --dtype=half.") \ No newline at end of file From a5f9036fcf520b370c871cbc8f54c354022c140f Mon Sep 17 00:00:00 2001 From: jeejeeli Date: Fri, 19 Jan 2024 14:48:09 +0800 Subject: [PATCH 4/5] format code --- vllm/engine/llm_engine.py | 131 ++++++++++++++++---------------------- 1 file changed, 56 insertions(+), 75 deletions(-) diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py index 545dbdf2763ff..a9317574876bc 100644 --- a/vllm/engine/llm_engine.py +++ b/vllm/engine/llm_engine.py @@ -2,9 +2,11 @@ from collections import defaultdict import os import time -from typing import TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, Union +from typing import (TYPE_CHECKING, Any, Dict, Iterable, List, Optional, Tuple, + Union) -from vllm.config import CacheConfig, ModelConfig, ParallelConfig, SchedulerConfig +from vllm.config import (CacheConfig, ModelConfig, ParallelConfig, + SchedulerConfig) from vllm.core.scheduler import Scheduler, SchedulerOutputs from vllm.engine.arg_utils import EngineArgs from vllm.engine.metrics import record_metrics @@ -12,15 +14,10 @@ from vllm.logger import init_logger from vllm.outputs import RequestOutput from vllm.sampling_params import SamplingParams -from vllm.sequence import ( - SamplerOutput, - Sequence, - SequenceGroup, - SequenceGroupOutput, - SequenceOutput, - SequenceStatus, -) -from vllm.transformers_utils.tokenizer import detokenize_incrementally, get_tokenizer +from vllm.sequence import (SamplerOutput, Sequence, SequenceGroup, + SequenceGroupOutput, SequenceOutput, SequenceStatus) +from vllm.transformers_utils.tokenizer import (detokenize_incrementally, + get_tokenizer) from vllm.utils import Counter, set_cuda_visible_devices, get_ip, get_open_port if ray: @@ -100,8 +97,7 @@ def __init__( tokenizer_mode=model_config.tokenizer_mode, trust_remote_code=model_config.trust_remote_code, tokenizer_revision=model_config.tokenizer_revision, - revision=model_config.revision, - ) + revision=model_config.revision) self.seq_counter = Counter() # Create the parallel GPU workers. @@ -132,8 +128,8 @@ def _init_workers(self): # before CUDA_VISIBLE_DEVICES is set in the Worker from vllm.worker.worker import Worker - assert (self.parallel_config.world_size == 1 - ), "Ray is required if parallel_config.world_size > 1." + assert self.parallel_config.world_size == 1, ( + "Ray is required if parallel_config.world_size > 1.") self._init_single_gpu_config() self.workers: List[Worker] = [] distributed_init_method = f"tcp://{get_ip()}:{get_open_port()}" @@ -329,11 +325,9 @@ def from_engine_args(cls, engine_args: EngineArgs) -> "LLMEngine": # Initialize the cluster. placement_group = initialize_cluster(parallel_config) # Create the LLM engine. 
- engine = cls( - *engine_configs, - placement_group, - log_stats=not engine_args.disable_log_stats, - ) + engine = cls(*engine_configs, + placement_group, + log_stats=not engine_args.disable_log_stats) return engine def add_request( @@ -402,8 +396,8 @@ def add_request( seq = Sequence(seq_id, prompt, prompt_token_ids, block_size) # Check whether the input specifies prefix - prefix = (self.scheduler.prefix_pool.add_or_get_prefix( - prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None) + prefix = self.scheduler.prefix_pool.add_or_get_prefix( + prompt_token_ids[:prefix_pos]) if prefix_pos is not None else None # Create the sequence group. seq_group = SequenceGroup(request_id, [seq], sampling_params, @@ -455,13 +449,13 @@ def _check_beam_search_early_stopping( if early_stopping is True: return True - current_worst_score = current_worst_seq.get_beam_search_score( + current_worst_score = (current_worst_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id) + eos_token_id=self.tokenizer.eos_token_id)) if early_stopping is False: - highest_attainable_score = best_running_seq.get_beam_search_score( + highest_attainable_score = (best_running_seq.get_beam_search_score( length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id) + eos_token_id=self.tokenizer.eos_token_id)) else: assert early_stopping == "never" if length_penalty > 0.0: @@ -471,21 +465,20 @@ def _check_beam_search_early_stopping( max_possible_length = max( best_running_seq.get_prompt_len() + sampling_params.max_tokens, - self.scheduler_config.max_model_len, - ) - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - seq_len=max_possible_length, - ) + self.scheduler_config.max_model_len) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id, + seq_len=max_possible_length)) else: # Otherwise, beam search will prefer shorter sequences. The # highest attainable score calculation is based on the current # sequence length. - highest_attainable_score = best_running_seq.get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id, - ) + highest_attainable_score = ( + best_running_seq.get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id)) return current_worst_score >= highest_attainable_score def _process_sequence_group_outputs(self, seq_group: SequenceGroup, @@ -574,12 +567,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, if seq.is_finished()] all_finished_seqs = existing_finished_seqs + new_finished_seqs # Sort the finished sequences by their scores. - all_finished_seqs.sort( - key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True, - ) + all_finished_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) for seq, parent, is_new in all_finished_seqs[:beam_width]: if is_new: # A newly generated child sequence finishes and has a high @@ -604,12 +595,10 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, running_child_seqs = [(seq, parent) for seq, parent in child_seqs if not seq.is_finished()] # Sort the running sequences by their scores. 
- running_child_seqs.sort( - key=lambda x: x[0].get_beam_search_score( - length_penalty=length_penalty, - eos_token_id=self.tokenizer.eos_token_id), - reverse=True, - ) + running_child_seqs.sort(key=lambda x: x[0].get_beam_search_score( + length_penalty=length_penalty, + eos_token_id=self.tokenizer.eos_token_id), + reverse=True) # Check if we can stop the beam search. if len(running_child_seqs) == 0: @@ -624,10 +613,7 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, current_worst_seq = all_finished_seqs[beam_width - 1][0] stop_beam_search = self._check_beam_search_early_stopping( seq_group.sampling_params.early_stopping, - seq_group.sampling_params, - best_running_seq, - current_worst_seq, - ) + seq_group.sampling_params, best_running_seq, current_worst_seq) if stop_beam_search: # Stop the beam search and remove all the running sequences from @@ -760,8 +746,7 @@ def step(self) -> List[RequestOutput]: "blocks_to_swap_in": scheduler_outputs.blocks_to_swap_in, "blocks_to_swap_out": scheduler_outputs.blocks_to_swap_out, "blocks_to_copy": scheduler_outputs.blocks_to_copy, - }, - ) + }) # Only the driver worker returns the sampling results. output = all_outputs[0] @@ -811,15 +796,15 @@ def _log_system_stats( avg_generation_throughput = 0.0 total_num_gpu_blocks = self.cache_config.num_gpu_blocks - num_free_gpu_blocks = self.scheduler.block_manager.get_num_free_gpu_blocks( - ) + num_free_gpu_blocks = ( + self.scheduler.block_manager.get_num_free_gpu_blocks()) num_used_gpu_blocks = total_num_gpu_blocks - num_free_gpu_blocks gpu_cache_usage = num_used_gpu_blocks / total_num_gpu_blocks total_num_cpu_blocks = self.cache_config.num_cpu_blocks if total_num_cpu_blocks > 0: - num_free_cpu_blocks = self.scheduler.block_manager.get_num_free_cpu_blocks( - ) + num_free_cpu_blocks = ( + self.scheduler.block_manager.get_num_free_cpu_blocks()) num_used_cpu_blocks = total_num_cpu_blocks - num_free_cpu_blocks cpu_cache_usage = num_used_cpu_blocks / total_num_cpu_blocks else: @@ -848,20 +833,16 @@ def _log_system_stats( def _decode_sequence(self, seq: Sequence, prms: SamplingParams) -> None: """Decodes the new token for a sequence.""" - ( - new_tokens, - new_output_text, - prefix_offset, - read_offset, - ) = detokenize_incrementally( - self.tokenizer, - all_input_ids=seq.get_token_ids(), - prev_tokens=seq.tokens, - prefix_offset=seq.prefix_offset, - read_offset=seq.read_offset, - skip_special_tokens=prms.skip_special_tokens, - spaces_between_special_tokens=prms.spaces_between_special_tokens, - ) + (new_tokens, new_output_text, prefix_offset, + read_offset) = detokenize_incrementally( + self.tokenizer, + all_input_ids=seq.get_token_ids(), + prev_tokens=seq.tokens, + prefix_offset=seq.prefix_offset, + read_offset=seq.read_offset, + skip_special_tokens=prms.skip_special_tokens, + spaces_between_special_tokens=prms.spaces_between_special_tokens, + ) if seq.tokens is None: seq.tokens = new_tokens else: @@ -896,8 +877,8 @@ def _check_stop(self, seq: Sequence, return # Check if the sequence has generated the EOS token. 
-        if (not sampling_params.ignore_eos
-            ) and seq.get_last_token_id() == self.tokenizer.eos_token_id:
+        if ((not sampling_params.ignore_eos)
+                and seq.get_last_token_id() == self.tokenizer.eos_token_id):
             seq.status = SequenceStatus.FINISHED_STOPPED
             return

From b2eedc0ec7cdafb2a4e61ee67b1fe035a3bdedd0 Mon Sep 17 00:00:00 2001
From: jeejeeli
Date: Fri, 19 Jan 2024 18:23:01 +0800
Subject: [PATCH 5/5] refactor code and add comment

---
 vllm/engine/llm_engine.py | 44 ++++++++++++++++++++++++++++------------------
 1 file changed, 26 insertions(+), 18 deletions(-)

diff --git a/vllm/engine/llm_engine.py b/vllm/engine/llm_engine.py
index a9317574876bc..39da131fcae5e 100644
--- a/vllm/engine/llm_engine.py
+++ b/vllm/engine/llm_engine.py
@@ -919,6 +919,20 @@ def _run_workers(
         return [driver_worker_output] + ray_worker_outputs
 
     def _init_single_gpu_config(self) -> None:
+        """Use monkey patching to avoid initializing a distributed group on a single GPU.
+
+        Details
+        - Step 1: As shown in the code below, use monkey patching to replace
+          `get_tensor_model_parallel_rank`, `get_tensor_model_parallel_world_size`
+          and `get_tensor_model_parallel_group`.
+        - Step 2: Because of Python's import mechanism, we must reload certain
+          modules (those listed in `_NEED_RELOAD_MODULES`) so that the monkey
+          patching in Step 1 takes effect.
+        - Step 3: Use monkey patching to replace `_init_distributed_environment`
+          in the `vllm.worker.worker` module.
+
+
+        """
         _NEED_RELOAD_MODULES = [
             "vllm.model_executor.parallel_utils.communication_op",
             "vllm.model_executor.layers.linear",
@@ -926,28 +940,22 @@ def _init_single_gpu_config(self) -> None:
             "vllm.model_executor.layers.sampler",
             "vllm.model_executor.layers.vocab_parallel_embedding",
         ]
-
-        def _parallel_rank_mp(*args, **kargs) -> int:
-            return 0
-
-        def _parallel_world_size_mp(*args, **kargs) -> int:
-            return 1
-
-        def _parallel_group_mp(*args, **kargs) -> int:
-            return 1
-
         import sys
         import importlib
         import vllm.model_executor.parallel_utils.parallel_state as ps_module
-
-        ps_module.get_tensor_model_parallel_world_size = _parallel_world_size_mp
-        ps_module.get_tensor_model_parallel_rank = _parallel_rank_mp
-        ps_module.get_tensor_model_parallel_group = _parallel_group_mp
+        assert self.parallel_config.world_size == 1, (
+            "single-GPU mode requires world_size == 1.")
+        # Step 1
+        ps_module.get_tensor_model_parallel_rank = lambda *args, **kargs: 0
+        ps_module.get_tensor_model_parallel_world_size = lambda *args, **kargs: 1
+        ps_module.get_tensor_model_parallel_group = lambda *args, **kargs: 1
+        # Step 2
        for module_name in _NEED_RELOAD_MODULES:
             if module_name in sys.modules:
                 module_before = sys.modules.get(module_name, None)
                 _ = importlib.reload(module_before)  # retrurn reloaded module
-        module_worker = "vllm.worker.worker"
-        module = sys.modules.get(module_worker, None)
-        assert module
-        module._init_distributed_environment = lambda *args, **kargs: 0
+        # Step 3
+        module_worker_name = "vllm.worker.worker"
+        module_worker = sys.modules.get(module_worker_name, None)
+        assert module_worker
+        module_worker._init_distributed_environment = lambda *args, **kargs: None
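
Note on Step 2 of the docstring above: the modules listed in _NEED_RELOAD_MODULES bind the parallel-state helpers at import time (communication_op.py, for example, calls get_tensor_model_parallel_world_size() unqualified, so it holds its own copy of the function object). Patching vllm.model_executor.parallel_utils.parallel_state alone does not update those copies; reloading the dependent modules re-executes their imports against the patched module. The standalone sketch below is not part of the patch; it illustrates the same patch-then-reload pattern with two throwaway modules written to a temp directory, and fake_state / fake_comm are made-up names used only for this illustration.

# Minimal sketch of the "patch the source module, then reload its importers" pattern.
import importlib
import os
import sys
import tempfile
import textwrap

tmp_dir = tempfile.mkdtemp()
with open(os.path.join(tmp_dir, "fake_state.py"), "w") as f:
    # Stand-in for the parallel-state module: raises until a group exists.
    f.write(textwrap.dedent("""\
        def get_rank():
            raise RuntimeError("distributed group is not initialized")
    """))
with open(os.path.join(tmp_dir, "fake_comm.py"), "w") as f:
    # Binds get_rank into fake_comm's namespace at import time,
    # the same way the vLLM layers import the parallel-state helpers.
    f.write(textwrap.dedent("""\
        from fake_state import get_rank

        def who_am_i():
            return f"rank {get_rank()}"
    """))
sys.path.insert(0, tmp_dir)

import fake_comm
import fake_state

# Equivalent of Step 1: patch the accessor to a single-GPU constant.
fake_state.get_rank = lambda *args, **kwargs: 0

# Without a reload, fake_comm still holds the original function object.
# Equivalent of Step 2: reload the importer so `from ... import` re-binds
# to the patched function.
importlib.reload(fake_comm)

print(fake_comm.who_am_i())  # -> "rank 0"; the stale binding is gone after the reload

If the dependent modules instead imported the module itself and looked the helpers up at call time (ps_module.get_tensor_model_parallel_rank()), no reload step would be needed; the reload list exists to work around the from-import style bindings.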