From 2242c38a30e51e0d3bce83363fc066f1dbdde4f0 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 25 Nov 2024 18:22:05 +0000 Subject: [PATCH 1/5] feat: classifier free guidance - take 2 --- aphrodite/cfg/__init__.py | 0 aphrodite/cfg/cfg_model_runner.py | 163 +++++++++++++++ aphrodite/cfg/cfg_worker.py | 194 ++++++++++++++++++ aphrodite/cfg/separated_worker.py | 77 +++++++ aphrodite/common/config.py | 38 ++++ aphrodite/common/sampling_params.py | 3 + aphrodite/common/sequence.py | 53 ++++- aphrodite/engine/aphrodite_engine.py | 78 +++++-- aphrodite/engine/args_tools.py | 30 ++- aphrodite/engine/async_aphrodite.py | 18 +- .../engine/output_processor/single_step.py | 2 + aphrodite/executor/executor_base.py | 10 +- aphrodite/executor/gpu_executor.py | 5 + aphrodite/inputs/data.py | 23 ++- aphrodite/modeling/models/llama.py | 27 ++- aphrodite/processing/block_manager_v2.py | 49 +++++ aphrodite/processing/scheduler.py | 33 ++- aphrodite/task_handler/model_runner.py | 3 + aphrodite/task_handler/worker.py | 13 +- 19 files changed, 775 insertions(+), 44 deletions(-) create mode 100644 aphrodite/cfg/__init__.py create mode 100644 aphrodite/cfg/cfg_model_runner.py create mode 100644 aphrodite/cfg/cfg_worker.py create mode 100644 aphrodite/cfg/separated_worker.py diff --git a/aphrodite/cfg/__init__.py b/aphrodite/cfg/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/aphrodite/cfg/cfg_model_runner.py b/aphrodite/cfg/cfg_model_runner.py new file mode 100644 index 000000000..b171f6d79 --- /dev/null +++ b/aphrodite/cfg/cfg_model_runner.py @@ -0,0 +1,163 @@ +from typing import List, Optional, Union + +import torch + +from aphrodite.common.sequence import IntermediateTensors, SamplerOutput +from aphrodite.distributed import get_pp_group +from aphrodite.multimodal import MultiModalInputs +from aphrodite.task_handler.model_runner import ( + FLASHINFER_WORKSPACE_BUFFER_SIZE, BatchDecodeWithPagedKVCacheWrapper, + BatchPrefillWithPagedKVCacheWrapper, ModelInputForGPUWithSamplingMetadata, + ModelRunner) + + +class CFGModelRunner(ModelRunner): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.inference_mode() + def model_execute( + self, + model_input: ModelInputForGPUWithSamplingMetadata, + kv_caches: List[torch.Tensor], + intermediate_tensors: Optional[IntermediateTensors] = None, + num_steps: int = 1, + ) -> torch.Tensor: + if num_steps > 1: + raise ValueError("num_steps > 1 is not supported in ModelRunner") + + if self.lora_config: + assert model_input.lora_requests is not None + assert model_input.lora_mapping is not None + self.set_active_loras(model_input.lora_requests, + model_input.lora_mapping) + + if self.prompt_adapter_config: + assert model_input.prompt_adapter_requests is not None + assert model_input.prompt_adapter_mapping is not None + self.set_active_prompt_adapters( + model_input.prompt_adapter_requests, + model_input.prompt_adapter_mapping) + + if self.attn_backend.get_name() == "flashinfer": + assert model_input.attn_metadata is not None + assert model_input.input_tokens is not None + if self.flashinfer_decode_workspace_buffer is None: + self.flashinfer_decode_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + self.flashinfer_decode_wrapper = \ + BatchDecodeWithPagedKVCacheWrapper( + self.flashinfer_decode_workspace_buffer, "NHD") + self.flashinfer_prefill_workspace_buffer = torch.empty( + FLASHINFER_WORKSPACE_BUFFER_SIZE, + dtype=torch.uint8, + device=self.device) + 
self.flashinfer_prefill_wrapper = \ + BatchPrefillWithPagedKVCacheWrapper( + self.flashinfer_prefill_workspace_buffer, "NHD") + + model_input.attn_metadata.prefill_wrapper = \ + self.flashinfer_prefill_wrapper + if model_input.attn_metadata.use_cuda_graph: + batch_size = model_input.input_tokens.shape[0] + model_input.attn_metadata.decode_wrapper = self.graph_runners[ + model_input. + virtual_engine][batch_size].flashinfer_decode_wrapper + else: + model_input.attn_metadata.decode_wrapper = \ + self.flashinfer_decode_wrapper + model_input.attn_metadata.begin_forward() + + # Currently cuda graph is only supported by the decode phase. + assert model_input.attn_metadata is not None + prefill_meta = model_input.attn_metadata.prefill_metadata + decode_meta = model_input.attn_metadata.decode_metadata + # TODO(andoorve): We can remove this once all + # virtual engines share the same kv cache. + virtual_engine = model_input.virtual_engine + if prefill_meta is None and decode_meta.use_cuda_graph: + assert model_input.input_tokens is not None + graph_batch_size = model_input.input_tokens.shape[0] + model_executable = self.graph_runners[virtual_engine][ + graph_batch_size] + else: + model_executable = self.model + + multi_modal_kwargs = model_input.multi_modal_kwargs or {} + seqlen_agnostic_kwargs = { + "finished_requests_ids": model_input.finished_requests_ids, + "request_ids_to_seq_ids": model_input.request_ids_to_seq_ids, + } if self.has_seqlen_agnostic else {} + + hidden_or_intermediate_states = model_executable( + input_ids=model_input.input_tokens, + positions=model_input.input_positions, + kv_caches=kv_caches, + attn_metadata=model_input.attn_metadata, + intermediate_tensors=intermediate_tensors, + **MultiModalInputs.as_kwargs(multi_modal_kwargs, + device=self.device), + **seqlen_agnostic_kwargs) + + return hidden_or_intermediate_states + + @torch.inference_mode() + def get_logits( + self, + hidden_or_intermediate_states: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model._get_logits(hidden_or_intermediate_states, + model_input.sampling_metadata) + + @torch.inference_mode() + def compute_logits( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model.compute_logits(logits, + model_input.sampling_metadata) + + @torch.inference_mode() + def do_sample( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ): + if not self.is_driver_worker: + return [] + + # Sample the next token. 
+        output: SamplerOutput = self.model.sample(
+            logits=logits,
+            sampling_metadata=model_input.sampling_metadata,
+        )
+
+        if self.return_hidden_states:
+            raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner")
+
+        return [output]
+
+    @torch.inference_mode()
+    def execute_model(
+        self,
+        model_input: ModelInputForGPUWithSamplingMetadata,
+        kv_caches: List[torch.Tensor],
+        intermediate_tensors: Optional[IntermediateTensors] = None,
+        num_steps: int = 1,
+    ) -> Optional[Union[List[SamplerOutput], IntermediateTensors]]:
+
+        hidden_or_intermediate_states = self.model_execute(
+            model_input, kv_caches, intermediate_tensors, num_steps)
+
+        if not get_pp_group().is_last_rank:
+            return hidden_or_intermediate_states
+
+        hidden_or_intermediate_states = self.get_logits(
+            hidden_or_intermediate_states, model_input)
+        logits = self.compute_logits(hidden_or_intermediate_states, model_input)
+
+        return self.do_sample(logits, model_input)
diff --git a/aphrodite/cfg/cfg_worker.py b/aphrodite/cfg/cfg_worker.py
new file mode 100644
index 000000000..7aee2c170
--- /dev/null
+++ b/aphrodite/cfg/cfg_worker.py
@@ -0,0 +1,194 @@
+import copy
+from typing import Dict, List, Optional, Tuple
+
+import torch
+
+from aphrodite.cfg.cfg_model_runner import CFGModelRunner
+from aphrodite.cfg.separated_worker import SeparatedWorker
+from aphrodite.common.config import CFGConfig, ParallelConfig
+from aphrodite.common.sequence import (ExecuteModelRequest, SamplerOutput,
+                                       SequenceData, SequenceGroupMetadata)
+from aphrodite.distributed import get_pp_group, get_tp_group
+from aphrodite.task_handler.worker_base import (LoraNotSupportedWorkerBase,
+                                                WorkerBase)
+
+
+def create_cfg_worker(*args, **kwargs) -> "CFGWorker":
+    assert "cfg_config" in kwargs
+    cfg_config: CFGConfig = kwargs.get("cfg_config")
+    assert cfg_config is not None
+    kwargs.pop("cfg_config")
+
+    kwargs["model_runner_cls"] = CFGModelRunner
+    root_worker = SeparatedWorker(*args, **kwargs)
+
+    guidance_model_config = cfg_config.guidance_model_config
+    guidance_parallel_config = cfg_config.guidance_parallel_config
+    kwargs.update(
+        model_config=guidance_model_config,
+        parallel_config=guidance_parallel_config,
+    )
+    guidance_worker = SeparatedWorker(*args, **kwargs)
+
+    return CFGWorker(
+        root_worker=root_worker,
+        guidance_worker=guidance_worker,
+        is_driver_worker=kwargs["is_driver_worker"],
+        parallel_config=kwargs["parallel_config"],
+    )
+
+
+class CFGWorker(LoraNotSupportedWorkerBase):
+    def __init__(
+        self,
+        root_worker: WorkerBase,
+        guidance_worker: WorkerBase,
+        is_driver_worker: bool,
+        parallel_config: ParallelConfig,
+    ):
+        self.root_worker = root_worker
+        self.guidance_worker = guidance_worker
+        self.is_driver_worker = is_driver_worker
+        self.parallel_config = parallel_config
+        assert self.parallel_config.pipeline_parallel_size == 1
+
+    def init_device(self):
+        self.root_task_handler.init_device()
+        self.guidance_worker.init_device()
+
+    def load_model(self):
+        self.root_worker.load_model()
+        self.guidance_worker.share_model(self.root_worker)
+
+    def determine_num_available_blocks(self) -> Tuple[int, int]:
+        (
+            num_gpu_blocks,
+            num_cpu_blocks,
+        ) = self.root_worker.determine_num_available_blocks()
+
+        root_cache_block_size_bytes = (
+            self.root_worker.get_cache_block_size_bytes()
+        )
+        guidance_cache_block_size_bytes = (
+            self.guidance_worker.get_cache_block_size_bytes()
+        )
+
+        new_num_gpu_blocks = int(
+            num_gpu_blocks
+            * root_cache_block_size_bytes
+            / (guidance_cache_block_size_bytes +
root_cache_block_size_bytes) + ) + return new_num_gpu_blocks, num_cpu_blocks + + def initialize_cache( + self, num_gpu_blocks: int, num_cpu_blocks: int + ) -> None: + self.root_worker.initialize_cache( + num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks + ) + self.guidance_worker.initialize_cache( + num_gpu_blocks=num_gpu_blocks, num_cpu_blocks=num_cpu_blocks + ) + + @property + def do_metadata_broadcast(self) -> bool: + return self.parallel_config.tensor_parallel_size > 1 + + @torch.inference_mode() + def execute_model( + self, execute_model_req: Optional[ExecuteModelRequest] = None + ) -> List[SamplerOutput]: + # prepare negative request with shallow copy + if execute_model_req is not None: + negative_seq_group_metadata_list: List[SequenceGroupMetadata] = [] + negative_excute_model_req = execute_model_req.clone( + negative_seq_group_metadata_list + ) + for seq_group_metadata in execute_model_req.seq_group_metadata_list: + negative_seq_group_metadata = copy.copy(seq_group_metadata) + negative_seq_data: Dict[int, SequenceData] = {} + negative_block_tables: Dict[int, List[int]] = {} + assert len(seq_group_metadata.seq_data) == 1 + for seq_id in seq_group_metadata.seq_data.keys(): + negative_seq_data[ + seq_id + ] = seq_group_metadata.negative_seq_data + negative_block_tables[ + seq_id + ] = seq_group_metadata.negative_block_table + + if negative_seq_group_metadata.is_prompt: + negative_seq_group_metadata.token_chunk_size = list( + negative_seq_data.values() + )[0].get_len() + + negative_seq_group_metadata.seq_data = negative_seq_data + negative_seq_group_metadata.block_tables = negative_block_tables + negative_seq_group_metadata.negative_seq_data = None + negative_seq_group_metadata.negative_block_table = None + negative_seq_group_metadata_list.append( + negative_seq_group_metadata + ) + negative_excute_model_req.seq_group_metadata_list = ( + negative_seq_group_metadata_list + ) + else: + negative_excute_model_req = None + + inputs = self.root_worker.prepare_input(execute_model_req) + negative_inputs = self.guidance_worker.prepare_input( + negative_excute_model_req + ) + if inputs is None: + assert negative_inputs is None + return None + + # get root models's logits + condition_logits = self.root_worker.execute_model_part(inputs) + # get unconditional logits + unconditional_logits = self.guidance_worker.execute_model_part( + negative_inputs + ) + + # do classifier free guidance logist process + model_input, _ = inputs + if condition_logits is not None: + for seq_group in model_input.sampling_metadata.seq_groups: + seq_ids = seq_group.seq_ids + guidance_scale = seq_group.sampling_params.guidance_scale + if guidance_scale == 1.0: + break + for seq_id, logits_row_idx in zip( + seq_ids, seq_group.sample_indices + ): + logits_row = torch.nn.functional.log_softmax( + condition_logits[logits_row_idx], dim=-1 + ) + unconditional_logits_row = torch.nn.functional.log_softmax( + unconditional_logits[logits_row_idx], dim=-1 + ) + condition_logits[logits_row_idx] = ( + guidance_scale * (logits_row - unconditional_logits_row) + + unconditional_logits_row + ) + + # do logist_processor + scores = self.root_worker.compute_logits(condition_logits, model_input) + if not self.is_driver_worker: + return [] + + # do sample + output = self.root_worker.do_sample(scores, model_input) + + if not get_pp_group().is_last_rank: + # output is IntermediateTensors + get_pp_group().send_tensor_dict( + output.tensors, all_gather_group=get_tp_group() + ) + return [None] + + # output is List[SamplerOutput] + return 
output + + def get_cache_block_size_bytes(self): + raise NotImplementedError diff --git a/aphrodite/cfg/separated_worker.py b/aphrodite/cfg/separated_worker.py new file mode 100644 index 000000000..121a8807e --- /dev/null +++ b/aphrodite/cfg/separated_worker.py @@ -0,0 +1,77 @@ +from typing import List, Optional, Tuple + +import torch + +from aphrodite.common.sequence import IntermediateTensors, SamplerOutput +from aphrodite.distributed import get_pp_group, get_tp_group +from aphrodite.task_handler.model_runner import ( + ModelInputForGPUWithSamplingMetadata) +from aphrodite.task_handler.model_runner_base import BroadcastableModelInput +from aphrodite.task_handler.worker import Worker +from aphrodite.task_handler.worker_base import WorkerInput + + +class SeparatedWorker(Worker): + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @torch.inference_mode() + def get_logits( + self, + hidden_or_intermediate_states: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model_runner.get_logits( + hidden_or_intermediate_states, model_input) + + @torch.inference_mode() + def compute_logits( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> torch.Tensor: + return self.model_runner.compute_logits(logits, model_input) + + @torch.inference_mode() + def do_sample( + self, + logits: torch.Tensor, + model_input: ModelInputForGPUWithSamplingMetadata, + ) -> List[SamplerOutput]: + return self.model_runner.do_sample(logits, model_input) + + @torch.inference_mode() + def execute_model_part( + self, + inputs: Tuple[BroadcastableModelInput, WorkerInput], + ) -> Optional[List[SamplerOutput]]: + + model_input, worker_input = inputs + num_steps = worker_input.num_steps + + self.execute_worker(worker_input) + + # If there is no input, we don't need to execute the model. + if worker_input.num_seq_groups == 0: + return [] + + intermediate_tensors = None + if not get_pp_group().is_first_rank: + intermediate_tensors = IntermediateTensors( + get_pp_group().recv_tensor_dict(all_gather_group=get_tp_group())) + + hidden_or_intermediate_states = self.model_runner.model_execute( + model_input, + self.kv_cache[worker_input.virtual_engine] + if self.kv_cache is not None else None, + intermediate_tensors, + num_steps + ) + + # Compute the logits in the last pipeline stage. 
+ if not get_pp_group().is_last_rank: + return hidden_or_intermediate_states + + logits = self.get_logits(hidden_or_intermediate_states, model_input) + + return logits diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index ffd4602b2..591618af3 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -1555,6 +1555,43 @@ def __repr__(self) -> str: return f"SpeculativeConfig({draft_model=}, {num_spec_tokens=})" +class CFGConfig: + @staticmethod + def maybe_create_spec_config( + target_model_config: ModelConfig, + target_parallel_config: ParallelConfig, + guidance_model: Optional[str], + ): + if guidance_model is None: + return None + + guidance_parallel_config = target_parallel_config + assert target_model_config.model == guidance_model + guidance_model_config = target_model_config + + return CFGConfig( + guidance_model_config, + guidance_parallel_config + ) + + def __init__( + self, + guidance_model_config: ModelConfig, + guidance_parallel_config: ParallelConfig, + ): + self.guidance_model_config = guidance_model_config + self.guidance_parallel_config = guidance_parallel_config + + def _verify_args(self) -> None: + if self.guidance_model_config: + self.guidance_model_config.verify_with_parallel_config( + self.guidance_parallel_config) + + def __repr__(self) -> str: + guidance_model = self.guidance_model_config.model + return f"CFGConfig({guidance_model=})" + + @dataclass class LoRAConfig: max_lora_rank: int @@ -1877,6 +1914,7 @@ class EngineConfig: speculative_config: Optional[SpeculativeConfig] decoding_config: Optional[DecodingConfig] prompt_adapter_config: Optional[PromptAdapterConfig] + cfg_config: Optional[CFGConfig] def __post_init__(self): """Verify configs are valid & consistent with each other. diff --git a/aphrodite/common/sampling_params.py b/aphrodite/common/sampling_params.py index 08ef8af8f..d52851d99 100644 --- a/aphrodite/common/sampling_params.py +++ b/aphrodite/common/sampling_params.py @@ -175,6 +175,7 @@ class SamplingParams( Defaults to None. skew: Bias the token selection towards higher or lower probability tokens. Defaults to 0 (disabled). + guidance_scale: The scale of CFG guidance to apply. """ n: int = 1 @@ -227,6 +228,7 @@ class SamplingParams( dry_allowed_length: int = 2 dry_sequence_breaker_ids: List[int] = [] skew: float = 0.0 + guidance_scale: Optional[float] = None # The below fields are not supposed to be used as an input. # They are set in post_init. 
     output_text_buffer_length: int = 0
@@ -279,6 +281,7 @@ class SamplingParams(
         "dry_allowed_length": 2,
         "dry_sequence_breaker_ids": [],
         "skew": 0.0,
+        "guidance_scale": None,
     }

     def __post_init__(self) -> None:
diff --git a/aphrodite/common/sequence.py b/aphrodite/common/sequence.py
index d43b0e02c..238184ae3 100644
--- a/aphrodite/common/sequence.py
+++ b/aphrodite/common/sequence.py
@@ -330,6 +330,7 @@ def __init__(
         lora_request: Optional[LoRARequest] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
         from_decoder_prompt: bool = True,
+        from_negative_prompt: bool = False,
     ) -> None:
         self.seq_id = seq_id
         self.inputs = inputs
@@ -338,6 +339,7 @@ def __init__(
         self.lora_request = lora_request
         self.prompt_adapter_request = prompt_adapter_request
         self.from_decoder_prompt = from_decoder_prompt
+        self.from_negative_prompt = from_negative_prompt
         self._prompt: Optional[str] = None
         self._prompt_token_ids: Optional[List[int]] = None
@@ -395,8 +397,12 @@ def prompt(self) -> Optional[str]:
         # Select decoder or encoder input prompt str,
         # as appropriate
-        prompt_key: str = ("prompt"
-                           if self.from_decoder_prompt else "encoder_prompt")
+        prompt_key: str = "prompt"
+        if not self.from_decoder_prompt:
+            prompt_key = "encoder_prompt"
+        if self.from_negative_prompt:
+            assert self.from_decoder_prompt is True
+            prompt_key = "negative_prompt"

         # Cache prompt
         self._prompt = cast(Optional[str], self.inputs.get(prompt_key))
@@ -410,9 +416,12 @@ def prompt_token_ids(self) -> List[int]:
         # Select decoder or encoder input prompt
         # token ids, as appropriate
-        prompt_token_ids_key: str = ("prompt_token_ids"
-                                     if self.from_decoder_prompt else
-                                     "encoder_prompt_token_ids")
+        prompt_token_ids_key: str = "prompt_token_ids"
+        if not self.from_decoder_prompt:
+            prompt_token_ids_key = "encoder_prompt_token_ids"
+        if self.from_negative_prompt:
+            assert self.from_decoder_prompt is True
+            prompt_token_ids_key = "negative_prompt_token_ids"

         # Cache computed prompt token ids
         self._prompt_token_ids = cast(List[int],
@@ -476,6 +485,9 @@ def get_output_len(self) -> int:
     def get_token_ids(self) -> List[int]:
         return self.data.get_token_ids()

+    def get_negative_token_ids(self) -> List[int]:
+        return self.data.get_negative_token_ids()
+
     def get_prompt_token_ids(self) -> Tuple[int, ...]:
         return self.data.get_prompt_token_ids()
@@ -532,7 +544,8 @@ def is_prefill(self) -> bool:
     def __repr__(self) -> str:
         return (f"Sequence(seq_id={self.seq_id}, "
                 f"status={self.status.name}, "
-                f"num_blocks={self.n_blocks}, ")
+                f"num_blocks={self.n_blocks}, "
+                f"data={self.data})")


 class SequenceGroupState(
@@ -576,6 +589,7 @@ def __init__(
         embeddings: Optional[List[float]] = None,
         pooling_params: Optional[PoolingParams] = None,
         encoder_seq: Optional[Sequence] = None,
+        negative_seq: Optional[Sequence] = None,
         prompt_adapter_request: Optional[PromptAdapterRequest] = None,
     ) -> None:
         self.request_id = request_id
@@ -596,6 +610,9 @@ def __init__(
         self.prompt_adapter_request = prompt_adapter_request
         self.encoder_seq = encoder_seq

+        assert self.is_single_seq is True
+        self.negative_seq = negative_seq
+
     @property
     def prompt(self) -> Optional[str]:
         # All sequences in the group should have the same prompt.
@@ -624,6 +641,22 @@ def encoder_prompt_token_ids(self) -> Optional[List[int]]:
         return (self.encoder_seq.prompt_token_ids
                 if self.encoder_seq is not None else None)

+    @property
+    def negative_prompt(self) -> Optional[str]:
+        # There are either 0 or 1 negative sequences
+        # We use the prompt of an arbitrary sequence.
+ assert self.is_single_seq is True + return (self.negative_seq.prompt + if self.negative_seq is not None else None) + + @property + def negative_prompt_token_ids(self) -> List[int]: + # All sequences in the group should have the same prompt. + # We use the prompt of an arbitrary sequence. + assert self.is_single_seq is True + return (self.negative_seq.prompt_token_ids + if self.negative_seq is not None else None) + @property def multi_modal_data(self) -> "MultiModalDataDict": # All sequences in the group should have the same multi-modal data. @@ -723,6 +756,12 @@ def is_encoder_decoder(self) -> bool: def get_encoder_seq(self) -> Optional[Sequence]: return self.encoder_seq + def has_negative_prompt(self) -> bool: + return self.negative_seq is not None + + def get_negative_seq(self) -> Optional[Sequence]: + return self.negative_seq + def get_unfinished_seqs(self) -> List[Sequence]: if self.is_single_seq: return self.seqs if not self.seqs[0].is_finished() else [] @@ -921,6 +960,8 @@ class SequenceGroupMetadata( multi_modal_data: Optional[Any] = None encoder_seq_data: Optional[SequenceData] = None cross_block_table: Optional[List[int]] = None + negative_seq_data: Optional[SequenceData] = None + negative_block_table: Optional[List[int]] = None prompt_adapter_request: Optional[PromptAdapterRequest] = None token_chunk_size: Optional[int] = None diff --git a/aphrodite/engine/aphrodite_engine.py b/aphrodite/engine/aphrodite_engine.py index 3762c2db5..1714d0c17 100644 --- a/aphrodite/engine/aphrodite_engine.py +++ b/aphrodite/engine/aphrodite_engine.py @@ -9,9 +9,9 @@ from transformers import PreTrainedTokenizer from typing_extensions import assert_never -from aphrodite.common.config import (CacheConfig, DecodingConfig, DeviceConfig, - EngineConfig, LoadConfig, LoRAConfig, - ModelConfig, ParallelConfig, +from aphrodite.common.config import (CacheConfig, CFGConfig, DecodingConfig, + DeviceConfig, EngineConfig, LoadConfig, + LoRAConfig, ModelConfig, ParallelConfig, PromptAdapterConfig, SchedulerConfig, SpeculativeConfig) from aphrodite.common.logger import setup_logger @@ -70,9 +70,11 @@ def _load_generation_config_dict(model_config: ModelConfig) -> Dict[str, Any]: _O = TypeVar("_O", RequestOutput, EmbeddingRequestOutput) PromptComponents = Tuple[Optional[str], List[int], - Optional[MultiModalDataDict]] + Optional[MultiModalDataDict], + Optional[None], Optional[None]] DecoderPromptComponents = Tuple[Optional[str], Optional[List[int]], - Optional[MultiModalDataDict]] + Optional[MultiModalDataDict], + Optional[None], Optional[None]] class AphroditeEngine: @@ -171,6 +173,7 @@ def __init__( speculative_config: Optional[SpeculativeConfig], decoding_config: Optional[DecodingConfig], prompt_adapter_config: Optional[PromptAdapterConfig], + cfg_config: Optional[CFGConfig], executor_class: Type[ExecutorBase], log_stats: bool, stat_loggers: Optional[Dict[str, StatLoggerBase]] = None, @@ -185,6 +188,7 @@ def __init__( config_dict = { "Model": model_config.model, "Speculative Config": speculative_config, + "CFG Config": cfg_config, "DataType": model_config.dtype, "Model Load Format": load_config.load_format, "Tensor Parallel Size": parallel_config.tensor_parallel_size, @@ -233,6 +237,7 @@ def __init__( self.load_config = load_config self.decoding_config = decoding_config or DecodingConfig() self.prompt_adapter_config = prompt_adapter_config + self.cfg_config = cfg_config self.log_stats = log_stats if not self.model_config.skip_tokenizer_init: @@ -269,6 +274,7 @@ def get_tokenizer_for_seq(sequence: Sequence) -> 
PreTrainedTokenizer: speculative_config=speculative_config, load_config=load_config, prompt_adapter_config=prompt_adapter_config, + cfg_config=cfg_config, ) if not self.model_config.embedding_mode: @@ -533,6 +539,16 @@ def _add_processed_request( seq = Sequence(seq_id, processed_inputs, block_size, eos_token_id, lora_request, prompt_adapter_request) + negative_seq = None + if 'negative_prompt_token_ids' in processed_inputs: + negative_seq = Sequence(seq_id, + processed_inputs, + block_size, + eos_token_id, + lora_request, + prompt_adapter_request, + from_negative_prompt=True) + encoder_seq = None if 'encoder_prompt_token_ids' in processed_inputs: encoder_seq = Sequence(seq_id, @@ -553,6 +569,7 @@ def _add_processed_request( lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, encoder_seq=encoder_seq, + negative_seq=negative_seq, ) elif isinstance(params, PoolingParams): seq_group = self._create_sequence_group_with_pooling( @@ -661,6 +678,8 @@ def _extract_prompt_components( lora_request=lora_request, ) multi_modal_data = None + negative_prompt = None + negative_prompt_token_ids = None elif isinstance(inputs, dict): if "prompt_token_ids" in inputs: prompt = None @@ -674,11 +693,27 @@ def _extract_prompt_components( lora_request=lora_request, ) + if "negative_prompt_token_ids" in inputs: + negative_prompt = None + negative_prompt_token_ids = inputs["negative_prompt_token_ids"] + elif "negative_prompt" in inputs: + negative_prompt = parsed_negative_prompt = inputs[ + "negative_prompt"] + negative_prompt_token_ids = self._tokenize_prompt( + parsed_negative_prompt, + request_id=request_id, + lora_request=lora_request, + ) + else: + negative_prompt = None + negative_prompt_token_ids = None + multi_modal_data = inputs.get("multi_modal_data") else: assert_never(inputs) - return prompt, prompt_token_ids, multi_modal_data + return (prompt, prompt_token_ids, multi_modal_data, + negative_prompt, negative_prompt_token_ids) def _apply_prompt_adapter( self, @@ -728,8 +763,10 @@ def _build_enc_dec_llm_inputs( encoder_comps: PromptComponents, decoder_comps: DecoderPromptComponents, ) -> EncoderDecoderLLMInputs: - encoder_prompt, encoder_prompt_ids, encoder_mm_data = encoder_comps - decoder_prompt, decoder_prompt_ids, decoder_mm_data = decoder_comps + encoder_prompt, encoder_prompt_ids, encoder_mm_data, \ + encoder_negative_prompt, encoder_negative_prompt_ids = encoder_comps + decoder_prompt, decoder_prompt_ids, decoder_mm_data, \ + decoder_negative_prompt, decoder_negative_prompt_ids= decoder_comps if encoder_mm_data is not None or decoder_mm_data is not None: raise ValueError("Multi-modal encoder-decoder models are " @@ -737,12 +774,18 @@ def _build_enc_dec_llm_inputs( decoder_prompt_ids = ( self._prepare_decoder_input_ids_for_generation(decoder_prompt_ids)) + decoder_negative_prompt_ids = ( + self._prepare_decoder_input_ids_for_generation(decoder_negative_prompt_ids)) return EncoderDecoderLLMInputs( prompt_token_ids=decoder_prompt_ids, prompt=decoder_prompt, + negative_prompt_token_ids=decoder_negative_prompt_ids, + negative_prompt=decoder_negative_prompt, encoder_prompt_token_ids=encoder_prompt_ids, encoder_prompt=encoder_prompt, + encoder_negative_prompt_token_ids=encoder_negative_prompt_ids, + encoder_negative_prompt=encoder_negative_prompt, ) def _process_encoder_decoder_prompt( @@ -787,7 +830,7 @@ def _process_encoder_decoder_prompt( ) if (decoder_input := inputs["decoder_prompt"]) is None: - decoder_comps = None, None, None + decoder_comps = None, None, None, None, None else: 
decoder_comps = self._extract_prompt_components( decoder_input, @@ -799,7 +842,7 @@ def _process_encoder_decoder_prompt( request_id=request_id, ) - decoder_comps = None, None, None + decoder_comps = None, None, None, None, None return self._build_enc_dec_llm_inputs(encoder_comps, decoder_comps) @@ -808,14 +851,17 @@ def _build_decoder_only_llm_inputs( prompt_comps: PromptComponents, prompt_adapter_request: Optional[PromptAdapterRequest], ) -> LLMInputs: - prompt, prompt_token_ids, multi_modal_data = prompt_comps + prompt, prompt_token_ids, multi_modal_data, \ + negative_prompt, negative_prompt_token_ids = prompt_comps prompt_token_ids = self._apply_prompt_adapter( prompt_token_ids, prompt_adapter_request=prompt_adapter_request) return LLMInputs(prompt_token_ids=prompt_token_ids, prompt=prompt, - multi_modal_data=multi_modal_data) + multi_modal_data=multi_modal_data, + negative_prompt_token_ids=negative_prompt_token_ids, + negative_prompt=negative_prompt) def _process_decoder_only_prompt( self, @@ -960,6 +1006,7 @@ def _create_sequence_group_with_sampling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + negative_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with SamplingParams.""" max_logprobs = self.get_model_config().max_logprobs @@ -984,7 +1031,8 @@ def _create_sequence_group_with_sampling( sampling_params=sampling_params, lora_request=lora_request, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + negative_seq=negative_seq) return seq_group @@ -997,6 +1045,7 @@ def _create_sequence_group_with_pooling( lora_request: Optional[LoRARequest], prompt_adapter_request: Optional[PromptAdapterRequest] = None, encoder_seq: Optional[Sequence] = None, + negative_seq: Optional[Sequence] = None, ) -> SequenceGroup: """Creates a SequenceGroup with PoolingParams.""" # Defensive copy of PoolingParams, which are used by the pooler @@ -1009,7 +1058,8 @@ def _create_sequence_group_with_pooling( lora_request=lora_request, pooling_params=pooling_params, prompt_adapter_request=prompt_adapter_request, - encoder_seq=encoder_seq) + encoder_seq=encoder_seq, + negative_seq=negative_seq) return seq_group diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index bb8028e56..ee16ecf12 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -8,12 +8,12 @@ from loguru import logger -from aphrodite.common.config import (CacheConfig, ConfigFormat, DecodingConfig, - DeviceConfig, EngineConfig, LoadConfig, - LoadFormat, LoRAConfig, ModelConfig, - ParallelConfig, PromptAdapterConfig, - SchedulerConfig, SpeculativeConfig, - TokenizerPoolConfig) +from aphrodite.common.config import (CacheConfig, CFGConfig, ConfigFormat, + DecodingConfig, DeviceConfig, + EngineConfig, LoadConfig, LoadFormat, + LoRAConfig, ModelConfig, ParallelConfig, + PromptAdapterConfig, SchedulerConfig, + SpeculativeConfig, TokenizerPoolConfig) from aphrodite.common.utils import FlexibleArgumentParser, is_cpu from aphrodite.executor.executor_base import ExecutorBase from aphrodite.quantization import QUANTIZATION_METHODS @@ -149,6 +149,8 @@ class EngineArgs: max_prompt_adapter_token: int = 0 # Log Options disable_log_stats: bool = False + # Classifier-Free-Guidance (CFG) options + cfg_model: Optional[str] = None def __post_init__(self): if self.tokenizer is None: @@ -855,6 +857,14 @@ def add_cli_args(parser: 
FlexibleArgumentParser) -> FlexibleArgumentParser: "disable logging statistics", ) + # CFG Options + parser.add_argument( + "--cfg-model", + type=str, + default=EngineArgs.cfg_model, + help="The name of the model to be used in CFG." + ) + return parser @classmethod @@ -1033,6 +1043,11 @@ def create_engine_config(self, ) -> EngineConfig: if speculative_config is None \ else speculative_config.num_lookahead_slots + cfg_config = CFGConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + guidance_model=self.cfg_model) + scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, max_num_seqs=self.max_num_seqs, @@ -1099,7 +1114,8 @@ def create_engine_config(self, ) -> EngineConfig: speculative_config=speculative_config, load_config=load_config, decoding_config=decoding_config, - prompt_adapter_config=prompt_adapter_config) + prompt_adapter_config=prompt_adapter_config, + cfg_config=cfg_config) @dataclass diff --git a/aphrodite/engine/async_aphrodite.py b/aphrodite/engine/async_aphrodite.py index af4bc35fa..4263e4935 100644 --- a/aphrodite/engine/async_aphrodite.py +++ b/aphrodite/engine/async_aphrodite.py @@ -437,6 +437,7 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) multi_modal_data = None + negative_prompt = negative_prompt_token_ids = None elif isinstance(inputs, dict): if "prompt_token_ids" in inputs: prompt = None @@ -450,11 +451,26 @@ async def _extract_prompt_components_async( lora_request=lora_request, ) + if "negative_prompt_token_ids" in inputs: + negative_prompt = None + negative_prompt_token_ids = inputs["negative_prompt_token_ids"] + elif "negative_prompt" in inputs: + negative_prompt = parsed_negative_prompt = inputs[ + "negative_prompt"] + negative_prompt_token_ids = await self._tokenize_prompt_async( + parsed_negative_prompt, + request_id=request_id, + lora_request=lora_request, + ) + else: + negative_prompt = negative_prompt_token_ids = None + multi_modal_data = inputs.get("multi_modal_data") else: assert_never(inputs) - return prompt, prompt_token_ids, multi_modal_data + return (prompt, prompt_token_ids, multi_modal_data, + negative_prompt, negative_prompt_token_ids) async def _process_encoder_decoder_prompt_async( self, diff --git a/aphrodite/engine/output_processor/single_step.py b/aphrodite/engine/output_processor/single_step.py index 5f26a89cc..981a2733e 100644 --- a/aphrodite/engine/output_processor/single_step.py +++ b/aphrodite/engine/output_processor/single_step.py @@ -84,6 +84,8 @@ def _process_sequence_group_outputs(self, seq_group: SequenceGroup, # only have one sequence seq = seq_group.seqs[0] seq.append_token_id(sample.output_token, sample.logprobs) + negative_seq = seq_group.negative_seq + negative_seq.append_token_id(sample.output_token, sample.logprobs) if sampling_params.detokenize and self.detokenizer: new_char_count = self.detokenizer.decode_sequence_inplace( seq, sampling_params) diff --git a/aphrodite/executor/executor_base.py b/aphrodite/executor/executor_base.py index 6ab9137e5..93dd4ce76 100644 --- a/aphrodite/executor/executor_base.py +++ b/aphrodite/executor/executor_base.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from typing import List, Optional, Set, Tuple -from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig, - LoRAConfig, ModelConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from aphrodite.common.config import (CacheConfig, CFGConfig, DeviceConfig, + LoadConfig, 
LoRAConfig, ModelConfig, + ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig) from aphrodite.common.sequence import ExecuteModelRequest, SamplerOutput from aphrodite.lora.request import LoRARequest from aphrodite.prompt_adapter.request import PromptAdapterRequest @@ -31,6 +31,7 @@ def __init__( lora_config: Optional[LoRAConfig], speculative_config: Optional[SpeculativeConfig], prompt_adapter_config: Optional[PromptAdapterConfig], + cfg_config: Optional[CFGConfig], ) -> None: self.model_config = model_config self.cache_config = cache_config @@ -41,6 +42,7 @@ def __init__( self.device_config = device_config self.speculative_config = speculative_config self.prompt_adapter_config = prompt_adapter_config + self.cfg_config = cfg_config self._init_executor() diff --git a/aphrodite/executor/gpu_executor.py b/aphrodite/executor/gpu_executor.py index eaa0dfbbf..04f2bc964 100644 --- a/aphrodite/executor/gpu_executor.py +++ b/aphrodite/executor/gpu_executor.py @@ -57,6 +57,7 @@ def _get_worker_kwargs( lora_config=self.lora_config, speculative_config=self.speculative_config, prompt_adapter_config=self.prompt_adapter_config, + cfg_config=self.cfg_config, is_driver_worker=(not self.parallel_config) or (rank % self.parallel_config.tensor_parallel_size == 0), ) @@ -76,6 +77,10 @@ def _get_create_worker_kwargs( worker_kwargs.update( worker_module_name="aphrodite.spec_decode.spec_decode_worker", worker_class_name="create_spec_worker") + elif self.cfg_config: + worker_kwargs.update( + worker_module_name="aphrodite.cfg.cfg_worker", + worker_class_name="create_cfg_worker") else: worker_kwargs.update( worker_module_name="aphrodite.task_handler.worker", diff --git a/aphrodite/inputs/data.py b/aphrodite/inputs/data.py index 2c298cba3..4bb50820c 100644 --- a/aphrodite/inputs/data.py +++ b/aphrodite/inputs/data.py @@ -33,7 +33,22 @@ class TokensPrompt(TypedDict): """ -SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt] +class NegativeTextPrompt(TypedDict): + """Schema for a text prompt.""" + + negative_prompt: str + """The input text to be tokenized before passing to the model.""" + + +class NegativeTokensPrompt(TypedDict): + """Schema for a tokenized prompt.""" + + negative_prompt_token_ids: List[int] + """A list of token IDs to pass to the model.""" + + +SingletonPromptInputs = Union[str, TextPrompt, TokensPrompt, + NegativeTextPrompt, NegativeTokensPrompt] """ Set of possible schemas for a single LLM input: - A text prompt (:class:`str` or :class:`TextPrompt`) @@ -116,6 +131,9 @@ class LLMInputs(TypedDict): if the model supports it. """ + negative_prompt_token_ids: NotRequired[Optional[List[int]]] + negative_prompt: NotRequired[Optional[str]] + class EncoderDecoderLLMInputs(LLMInputs): """ @@ -132,6 +150,9 @@ class EncoderDecoderLLMInputs(LLMInputs): available. 
""" + encoder_negative_prompt_token_ids: NotRequired[Optional[List[int]]] + encoder_negative_prompt: NotRequired[Optional[str]] + _T1 = TypeVar("_T1", bound=SingletonPromptInputs, diff --git a/aphrodite/modeling/models/llama.py b/aphrodite/modeling/models/llama.py index abbb13f0b..b42f9f761 100644 --- a/aphrodite/modeling/models/llama.py +++ b/aphrodite/modeling/models/llama.py @@ -34,13 +34,15 @@ from aphrodite.distributed import (get_current_tp_rank_partition_size, get_pp_group, get_tensor_model_parallel_rank, - get_tensor_model_parallel_world_size) + get_tensor_model_parallel_world_size, + tensor_model_parallel_gather) from aphrodite.modeling.layers.activation import SiluAndMul from aphrodite.modeling.layers.layernorm import RMSNorm from aphrodite.modeling.layers.linear import (MergedColumnParallelLinear, QKVParallelLinear, RowParallelLinear) -from aphrodite.modeling.layers.logits_processor import LogitsProcessor +from aphrodite.modeling.layers.logits_processor import (LogitsProcessor, + _prune_hidden_states) from aphrodite.modeling.layers.rotary_embedding import get_rope from aphrodite.modeling.layers.sampler import Sampler from aphrodite.modeling.layers.vocab_parallel_embedding import ( @@ -429,7 +431,9 @@ def __init__( logit_scale = getattr(config, "logit_scale", 1.0) self.logits_processor = LogitsProcessor(self.unpadded_vocab_size, config.vocab_size, - logit_scale) + logit_scale, + logits_as_input=True) + self.org_vocab_size = config.vocab_size self.sampler = Sampler() else: self.lm_head = PPMissingLayer() @@ -446,6 +450,23 @@ def forward( attn_metadata, intermediate_tensors) return model_output + def _get_logits(self, + hidden_states: torch.Tensor, + sampling_metadata: SamplingMetadata, + ) -> torch.Tensor: + hidden_states = _prune_hidden_states(hidden_states, sampling_metadata) + # Get the logits for the next tokens. + logits = self.lm_head.linear_method.apply( + self.lm_head, + hidden_states, + bias=None, + ) + logits = tensor_model_parallel_gather(logits) + # Remove paddings in vocab (if any). 
+ if logits is not None: + logits = logits[:, :self.org_vocab_size] + return logits + def compute_logits( self, hidden_states: torch.Tensor, diff --git a/aphrodite/processing/block_manager_v2.py b/aphrodite/processing/block_manager_v2.py index 38154fa79..f01a1ac45 100644 --- a/aphrodite/processing/block_manager_v2.py +++ b/aphrodite/processing/block_manager_v2.py @@ -17,6 +17,7 @@ from aphrodite.processing.interfaces import AllocStatus, BlockSpaceManager SeqId = int +NegativeSeqId = str EncoderSeqId = str @@ -98,6 +99,7 @@ def __init__( ) self.block_tables: Dict[SeqId, BlockTable] = {} + self.negative_block_tables: Dict[NegativeSeqId, BlockTable] = {} self.cross_block_tables: Dict[EncoderSeqId, BlockTable] = {} self._computed_blocks_tracker = ComputedBlocksTracker( @@ -123,6 +125,11 @@ def can_allocate(self, seq_group: SequenceGroup) -> AllocStatus: block_size=self.block_size, ) + if seq_group.has_negative_prompt(): + num_required_blocks += BlockTable.get_num_required_blocks( + seq_group.get_negative_seq().get_token_ids(), + block_size=self.block_size) + if self.max_block_sliding_window is not None: num_required_blocks = min(num_required_blocks, self.max_block_sliding_window) @@ -183,6 +190,15 @@ def allocate(self, seq_group: SequenceGroup) -> None: assert (request_id not in self.cross_block_tables), \ "block table already exists" + assert (request_id + not in self.negative_block_tables), \ + "block table already exists" + + if seq_group.has_negative_prompt(): + block_table = self._allocate_sequence( + seq_group.get_negative_seq()) + self.negative_block_tables[request_id] = block_table + check_no_caching_or_swa_for_blockmgr_encdec(self, seq_group) @@ -215,6 +231,15 @@ def can_append_slots(self, seq_group: SequenceGroup, num_lookahead_slots=num_lookahead_slots, )) + negative_block_table = self.negative_block_tables[ + seq_group.request_id] + num_touched_blocks += ( + negative_block_table.get_num_blocks_touched_by_append_slots( + token_ids=negative_block_table.get_unseen_token_ids( + seq_group.get_negative_seq().get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + )) + num_free_gpu_blocks = self.block_allocator.get_num_free_blocks( Device.GPU) return num_touched_blocks <= num_free_gpu_blocks @@ -223,6 +248,7 @@ def append_slots( self, seq: Sequence, num_lookahead_slots: int, + seq_group: SequenceGroup, ) -> List[Tuple[int, int]]: block_table = self.block_tables[seq.seq_id] @@ -232,6 +258,15 @@ def append_slots( num_lookahead_slots=num_lookahead_slots, num_computed_slots=seq.data.get_num_computed_tokens(), ) + + negative_block_table = self.negative_block_tables[seq_group.request_id] + negative_seq = seq_group.negative_seq + negative_block_table.append_token_ids( + token_ids=negative_block_table.get_unseen_token_ids( + negative_seq.get_token_ids()), + num_lookahead_slots=num_lookahead_slots, + num_computed_slots=negative_seq.data.get_num_computed_tokens(), + ) # Return any new copy-on-writes. 
new_cows = self.block_allocator.clear_copy_on_writes() return new_cows @@ -263,6 +298,13 @@ def free_cross(self, seq_group: SequenceGroup) -> None: self.cross_block_tables[request_id].free() del self.cross_block_tables[request_id] + def free_negative(self, seq_group: SequenceGroup) -> None: + request_id = seq_group.request_id + if request_id not in self.negative_block_tables: + return + self.negative_block_tables[request_id].free() + del self.negative_block_tables[request_id] + def get_block_table(self, seq: Sequence) -> List[int]: block_ids = self.block_tables[seq.seq_id].physical_block_ids return block_ids # type: ignore @@ -274,6 +316,13 @@ def get_cross_block_table(self, seq_group: SequenceGroup) -> List[int]: assert all(b is not None for b in block_ids) return block_ids # type: ignore + def get_negative_block_table(self, seq_group: SequenceGroup) -> List[int]: + request_id = seq_group.request_id + assert request_id in self.negative_block_tables + block_ids = self.negative_block_tables[request_id].physical_block_ids + assert all(b is not None for b in block_ids) + return block_ids + def access_all_blocks_in_seq(self, seq: Sequence, now: float): if self.enable_caching: # Record the latest access time for the sequence. The actual update diff --git a/aphrodite/processing/scheduler.py b/aphrodite/processing/scheduler.py index 5159265df..e4a866f66 100644 --- a/aphrodite/processing/scheduler.py +++ b/aphrodite/processing/scheduler.py @@ -441,6 +441,7 @@ def abort_seq_group(self, request_id: Union[str, Iterable[str]]) -> None: self.free_seq(seq) self._free_seq_group_cross_attn_blocks(aborted_group) + self._free_seq_group_negative_blocks(aborted_group) def _free_seq_group_cross_attn_blocks( self, @@ -453,6 +454,13 @@ def _free_seq_group_cross_attn_blocks( if seq_group.is_encoder_decoder(): self.block_manager.free_cross(seq_group) + def _free_seq_group_negative_blocks( + self, + seq_group: SequenceGroup, + ) -> None: + if seq_group.has_negative_prompt(): + self.block_manager.free_negative(seq_group) + def has_unfinished_seqs(self) -> bool: return len(self.waiting) != 0 or len(self.running) != 0 or len( self.swapped) != 0 @@ -1036,7 +1044,8 @@ def _can_append_slots(self, seq_group: SequenceGroup) -> bool: return self.block_manager.can_append_slots( seq_group=seq_group, - num_lookahead_slots=self._get_num_lookahead_slots(is_prefill), + num_lookahead_slots=self._get_num_lookahead_slots( + is_prefill, seq_group), ) def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: @@ -1073,6 +1082,14 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: encoder_seq_data = None cross_block_table = None + if seq_group.has_negative_prompt(): + negative_seq_data = seq_group.get_negative_seq().data + negative_block_table = ( + self.block_manager.get_negative_block_table(seq_group)) + else: + negative_seq_data = None + negative_block_table = None + for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): seq_id = seq.seq_id seq_data[seq_id] = seq.data @@ -1120,6 +1137,8 @@ def schedule(self) -> Tuple[List[SequenceGroupMetadata], SchedulerOutputs]: computed_block_nums=common_computed_block_nums, encoder_seq_data=encoder_seq_data, cross_block_table=cross_block_table, + negative_seq_data=negative_seq_data, + negative_block_table=negative_block_table, state=seq_group.state, # `multi_modal_data` will only be present for the 1st comm # between engine and worker. 
@@ -1169,6 +1188,7 @@ def free_finished_seq_groups(self) -> None: if seq_group.is_finished(): # Free cross-attention block table, if it exists self._free_seq_group_cross_attn_blocks(seq_group) + self._free_seq_group_negative_blocks(seq_group) # Add the finished requests to the finished requests list. # This list will be used to update the Mamba cache in the # next step. @@ -1198,11 +1218,14 @@ def _append_slots( the new source and destination block indices for the appended slots. """ - num_lookahead_slots = self._get_num_lookahead_slots(is_prefill=False) + num_lookahead_slots = self._get_num_lookahead_slots( + is_prefill=False, seq_group=seq_group) seq_group.init_multi_step(num_scheduler_steps=num_lookahead_slots + 1) for seq in seq_group.get_seqs(status=SequenceStatus.RUNNING): - cows = self.block_manager.append_slots(seq, num_lookahead_slots) + cows = self.block_manager.append_slots(seq, num_lookahead_slots, + seq_group) + assert len(cows) == 0 if len(cows) > 0: blocks_to_copy.extend(cows) @@ -1313,7 +1336,9 @@ def _passed_delay(self, now: float) -> bool: passed_delay = True return passed_delay - def _get_num_lookahead_slots(self, is_prefill: bool) -> int: + def _get_num_lookahead_slots(self, is_prefill: bool, + seq_group: Optional[SequenceGroup] = None + ) -> int: """The number of slots to allocate per sequence per step, beyond known token ids. Speculative decoding uses these slots to store KV activations of tokens which may or may not be accepted. diff --git a/aphrodite/task_handler/model_runner.py b/aphrodite/task_handler/model_runner.py index c29bc39b8..78bd104a9 100644 --- a/aphrodite/task_handler/model_runner.py +++ b/aphrodite/task_handler/model_runner.py @@ -979,6 +979,9 @@ def load_model(self) -> None: def get_model_memory_usage(self): return self.model_memory_usage + def share_model(self, model: nn.Module) -> None: + self.model = model + def save_sharded_state( self, path: str, diff --git a/aphrodite/task_handler/worker.py b/aphrodite/task_handler/worker.py index 77089c7eb..ad39aaef6 100644 --- a/aphrodite/task_handler/worker.py +++ b/aphrodite/task_handler/worker.py @@ -8,10 +8,10 @@ import torch.distributed from loguru import logger -from aphrodite.common.config import (CacheConfig, DeviceConfig, LoadConfig, - LoRAConfig, ModelConfig, ParallelConfig, - PromptAdapterConfig, SchedulerConfig, - SpeculativeConfig) +from aphrodite.common.config import (CacheConfig, CFGConfig, DeviceConfig, + LoadConfig, LoRAConfig, ModelConfig, + ParallelConfig, PromptAdapterConfig, + SchedulerConfig, SpeculativeConfig) from aphrodite.common.sequence import (ExecuteModelRequest, IntermediateTensors, SamplerOutput, SequenceGroupMetadata, @@ -56,6 +56,7 @@ def __init__( lora_config: Optional[LoRAConfig] = None, speculative_config: Optional[SpeculativeConfig] = None, prompt_adapter_config: Optional[PromptAdapterConfig] = None, + cfg_config: Optional[CFGConfig] = None, is_driver_worker: bool = False, model_runner_cls: Optional[Type[GPUModelRunnerBase]] = None, ) -> None: @@ -70,6 +71,7 @@ def __init__( self.distributed_init_method = distributed_init_method self.lora_config = lora_config self.prompt_adapter_config = prompt_adapter_config + self.cfg_config = cfg_config self.load_config = load_config self.is_driver_worker = is_driver_worker if parallel_config and is_driver_worker: @@ -155,6 +157,9 @@ def init_device(self) -> None: def load_model(self): self.model_runner.load_model() + def share_model(self, shared_worker) -> None: + self.model_runner.share_model(shared_worker.model_runner.model) + def 
save_sharded_state( self, path: str, From 2242975cfc8ee2c700e46054d81bc4587893dcb1 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 25 Nov 2024 18:28:43 +0000 Subject: [PATCH 2/5] fix --- aphrodite/cfg/cfg_worker.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/aphrodite/cfg/cfg_worker.py b/aphrodite/cfg/cfg_worker.py index 7aee2c170..7e0cf62a5 100644 --- a/aphrodite/cfg/cfg_worker.py +++ b/aphrodite/cfg/cfg_worker.py @@ -53,7 +53,7 @@ def __init__( assert self.parallel_config.pipeline_parallel_size == 1 def init_device(self): - self.root_task_handler.init_device() + self.root_worker.init_device() self.guidance_worker.init_device() def load_model(self): From 2242cf6fc5b2ad146e27073b715562efec6bd06b Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 25 Nov 2024 18:38:16 +0000 Subject: [PATCH 3/5] clean-up and example script --- aphrodite/cfg/cfg_model_runner.py | 3 +- aphrodite/cfg/cfg_worker.py | 2 +- examples/offline_inference/cfg_inference.py | 42 +++++++++++++++++++++ 3 files changed, 45 insertions(+), 2 deletions(-) create mode 100644 examples/offline_inference/cfg_inference.py diff --git a/aphrodite/cfg/cfg_model_runner.py b/aphrodite/cfg/cfg_model_runner.py index b171f6d79..f0f0693e0 100644 --- a/aphrodite/cfg/cfg_model_runner.py +++ b/aphrodite/cfg/cfg_model_runner.py @@ -137,7 +137,8 @@ def do_sample( ) if self.return_hidden_states: - raise NotImplementedError("return_hidden_states is not supported in CFGModelRunner") + raise NotImplementedError("return_hidden_states is not supported " + "in CFGModelRunner") return [output] diff --git a/aphrodite/cfg/cfg_worker.py b/aphrodite/cfg/cfg_worker.py index 7e0cf62a5..64f8b5282 100644 --- a/aphrodite/cfg/cfg_worker.py +++ b/aphrodite/cfg/cfg_worker.py @@ -109,7 +109,7 @@ def execute_model( negative_seq_data: Dict[int, SequenceData] = {} negative_block_tables: Dict[int, List[int]] = {} assert len(seq_group_metadata.seq_data) == 1 - for seq_id in seq_group_metadata.seq_data.keys(): + for seq_id in seq_group_metadata.seq_data: negative_seq_data[ seq_id ] = seq_group_metadata.negative_seq_data diff --git a/examples/offline_inference/cfg_inference.py b/examples/offline_inference/cfg_inference.py new file mode 100644 index 000000000..1e274d499 --- /dev/null +++ b/examples/offline_inference/cfg_inference.py @@ -0,0 +1,42 @@ +from typing import List +from aphrodite import LLM, SamplingParams +from aphrodite.inputs import PromptInputs + +llm = LLM( + model="NousResearch/Meta-Llama-3.1-8B-Instruct", + use_v2_block_manager=True, + cfg_model="NousResearch/Meta-Llama-3.1-8B-Instruct", + max_model_len=8192, +) + +prompt_pairs = [ + { + "prompt": "Hello, my name is", + "negative_prompt": "I am uncertain and confused about who I am" + }, + { + "prompt": "The president of the United States is", + "negative_prompt": "I don't know anything about US politics or leadership" + }, +] + +tokenizer = llm.get_tokenizer() + +inputs: List[PromptInputs] = [ + { + "prompt_token_ids": tokenizer.encode(text=pair["prompt"]), + "negative_prompt_token_ids": tokenizer.encode(text=pair["negative_prompt"]) + } + for pair in prompt_pairs +] + +sampling_params = SamplingParams(guidance_scale=5.0, max_tokens=128) +outputs = llm.generate(inputs, sampling_params) + +for i, output in enumerate(outputs): + prompt_pair = prompt_pairs[i] + generated_text = output.outputs[0].text + print(f"Prompt: {prompt_pair['prompt']!r}") + print(f"Negative Prompt: {prompt_pair['negative_prompt']!r}") + print(f"Generated text: {generated_text!r}") + print("-" * 50) \ No 
newline at end of file From 22425c186b46ba48b0683cce9da8534feb995209 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 25 Nov 2024 18:40:12 +0000 Subject: [PATCH 4/5] formatting --- examples/offline_inference/cfg_inference.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/examples/offline_inference/cfg_inference.py b/examples/offline_inference/cfg_inference.py index 1e274d499..969e1bbd4 100644 --- a/examples/offline_inference/cfg_inference.py +++ b/examples/offline_inference/cfg_inference.py @@ -1,4 +1,5 @@ from typing import List + from aphrodite import LLM, SamplingParams from aphrodite.inputs import PromptInputs @@ -16,7 +17,7 @@ }, { "prompt": "The president of the United States is", - "negative_prompt": "I don't know anything about US politics or leadership" + "negative_prompt": "I don't know anything about US politics or leadership" # noqa: E501 }, ] @@ -25,7 +26,7 @@ inputs: List[PromptInputs] = [ { "prompt_token_ids": tokenizer.encode(text=pair["prompt"]), - "negative_prompt_token_ids": tokenizer.encode(text=pair["negative_prompt"]) + "negative_prompt_token_ids": tokenizer.encode(text=pair["negative_prompt"]) # noqa: E501 } for pair in prompt_pairs ] From 22425a03f8d26eae6915b07ee52b6391473442a9 Mon Sep 17 00:00:00 2001 From: AlpinDale Date: Mon, 25 Nov 2024 19:01:19 +0000 Subject: [PATCH 5/5] guard against using block manager v1 --- aphrodite/common/config.py | 13 ++++++++++++- aphrodite/engine/args_tools.py | 10 ++++++---- 2 files changed, 18 insertions(+), 5 deletions(-) diff --git a/aphrodite/common/config.py b/aphrodite/common/config.py index 591618af3..e1ef98abf 100644 --- a/aphrodite/common/config.py +++ b/aphrodite/common/config.py @@ -1560,6 +1560,7 @@ class CFGConfig: def maybe_create_spec_config( target_model_config: ModelConfig, target_parallel_config: ParallelConfig, + target_scheduler_config: SchedulerConfig, guidance_model: Optional[str], ): if guidance_model is None: @@ -1567,25 +1568,35 @@ def maybe_create_spec_config( guidance_parallel_config = target_parallel_config assert target_model_config.model == guidance_model + guidance_scheduler_config = target_scheduler_config guidance_model_config = target_model_config return CFGConfig( guidance_model_config, - guidance_parallel_config + guidance_parallel_config, + guidance_scheduler_config, ) def __init__( self, guidance_model_config: ModelConfig, guidance_parallel_config: ParallelConfig, + guidance_scheduler_config: SchedulerConfig, ): self.guidance_model_config = guidance_model_config self.guidance_parallel_config = guidance_parallel_config + self.guidance_scheduler_config = guidance_scheduler_config + self._verify_args() def _verify_args(self) -> None: if self.guidance_model_config: self.guidance_model_config.verify_with_parallel_config( self.guidance_parallel_config) + if not self.guidance_scheduler_config.use_v2_block_manager: + raise ValueError( + "CFG requires usage of the V2 " + "block manager. 
Enable it with --use-v2-block-manager " + "or use_v2_block_manager=True.") def __repr__(self) -> str: guidance_model = self.guidance_model_config.model diff --git a/aphrodite/engine/args_tools.py b/aphrodite/engine/args_tools.py index ee16ecf12..4d5ab627b 100644 --- a/aphrodite/engine/args_tools.py +++ b/aphrodite/engine/args_tools.py @@ -1043,10 +1043,6 @@ def create_engine_config(self, ) -> EngineConfig: if speculative_config is None \ else speculative_config.num_lookahead_slots - cfg_config = CFGConfig.maybe_create_spec_config( - target_model_config=model_config, - target_parallel_config=parallel_config, - guidance_model=self.cfg_model) scheduler_config = SchedulerConfig( max_num_batched_tokens=self.max_num_batched_tokens, @@ -1064,6 +1060,12 @@ def create_engine_config(self, ) -> EngineConfig: parallel_config.use_ray), ) + cfg_config = CFGConfig.maybe_create_spec_config( + target_model_config=model_config, + target_parallel_config=parallel_config, + target_scheduler_config=scheduler_config, + guidance_model=self.cfg_model) + if not HAS_TRITON and self.enable_lora: raise ValueError("Triton is not installed, LoRA will not work.")
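
For readers of the series: the core guidance step in CFGWorker.execute_model (PATCH 1/5) blends each conditional logits row with the matching row from the guidance (negative-prompt) pass in log-probability space, computing uncond + guidance_scale * (cond - uncond); a guidance_scale of 1.0 leaves the conditional distribution unchanged. A minimal, standalone sketch of just that combination, with dummy tensors standing in for the worker plumbing:

import torch
import torch.nn.functional as F

def combine_cfg_logits(cond_logits: torch.Tensor,
                       uncond_logits: torch.Tensor,
                       guidance_scale: float) -> torch.Tensor:
    # Same combination as CFGWorker.execute_model: move both rows into
    # log-probability space, then interpolate/extrapolate with the scale.
    cond = F.log_softmax(cond_logits, dim=-1)
    uncond = F.log_softmax(uncond_logits, dim=-1)
    return guidance_scale * (cond - uncond) + uncond

# Dummy rows standing in for one sequence's conditional and unconditional logits.
cond = torch.randn(1, 32000)
uncond = torch.randn(1, 32000)
guided = combine_cfg_logits(cond, uncond, guidance_scale=5.0)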
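
Note on the GPU block budget: CFGWorker.determine_num_available_blocks (PATCH 1/5) rescales the profiled block count to new_num_gpu_blocks = num_gpu_blocks * root_block_size / (root_block_size + guidance_block_size), so the root and guidance sequences share the same KV-cache budget. Because CFGConfig currently requires the guidance model to be the same model (and parallel config) as the target, the two block sizes are equal and the budget is effectively halved, e.g. 1000 profiled blocks become int(1000 * 1 / 2) = 500.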
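
End-to-end usage, condensed from the bundled examples/offline_inference/cfg_inference.py (PATCH 3/5) and the guard added in PATCH 5/5; the model name and prompts are the ones used in that example and are illustrative only:

from aphrodite import LLM, SamplingParams

# cfg_model must currently be the same model as `model`, and CFG requires the
# v2 block manager (enforced by CFGConfig._verify_args in PATCH 5/5).
llm = LLM(
    model="NousResearch/Meta-Llama-3.1-8B-Instruct",
    cfg_model="NousResearch/Meta-Llama-3.1-8B-Instruct",
    use_v2_block_manager=True,
)
tokenizer = llm.get_tokenizer()

inputs = [{
    "prompt_token_ids": tokenizer.encode("Hello, my name is"),
    "negative_prompt_token_ids": tokenizer.encode(
        "I am uncertain and confused about who I am"),
}]

outputs = llm.generate(inputs, SamplingParams(guidance_scale=5.0, max_tokens=128))
print(outputs[0].outputs[0].text)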