diff --git a/server/lorax_server/models/causal_lm.py b/server/lorax_server/models/causal_lm.py
index 927e04c50..a99721a8b 100644
--- a/server/lorax_server/models/causal_lm.py
+++ b/server/lorax_server/models/causal_lm.py
@@ -1,12 +1,9 @@
-from collections import defaultdict
-import json
-import torch
-import inspect
-
 from dataclasses import dataclass
+from typing import Optional, Tuple, List, Type, Dict
+
+import torch
 from opentelemetry import trace
 from transformers import AutoTokenizer, AutoModelForCausalLM, PreTrainedTokenizerBase
-from typing import Optional, Tuple, List, Type, Dict
 
 from lorax_server.models import Model
 from lorax_server.models.types import (
@@ -17,9 +14,9 @@
 )
 from lorax_server.pb import generate_pb2
 from lorax_server.utils import NextTokenChooser, StoppingCriteria, Sampling
-from lorax_server.utils.tokenizer import TokenizerManager
-from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata, BatchedLoraWeights
+from lorax_server.utils.lora import AdapterBatchData, AdapterBatchMetadata
 from lorax_server.utils.segments import SegmentConcatBuilder, find_segments
+from lorax_server.utils.tokenizer import TokenizerManager
 
 
 tracer = trace.get_tracer(__name__)
@@ -95,7 +92,10 @@ def from_pb(
             requests_idx_mapping[r.id] = i
             req_inputs = tokenizers.get_inputs(r, tokenizer)
             inputs.append(req_inputs)
-            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters, device))
+            next_token_choosers.append(NextTokenChooser.from_pb(r.parameters,
+                                                                device,
+                                                                tokenizers.get_tokenizer(adapter_indices_list[i],
+                                                                                         tokenizer)))
             stopping_criteria = StoppingCriteria.from_pb(
                 r.stopping_parameters, tokenizer
             )
@@ -212,7 +212,7 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
             stopping_criteria = self.stopping_criterias[idx]
             stopping_criterias.append(stopping_criteria)
             remaining_decode_tokens = (
-                stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
+                    stopping_criteria.max_new_tokens - stopping_criteria.current_tokens
             )
             total_remaining_decode_tokens += remaining_decode_tokens
             new_padding_right_offset = max(
@@ -226,12 +226,14 @@ def filter(self, request_ids: List[int]) -> Optional["CausalLMBatch"]:
         position_ids = self.position_ids[keep_indices]
         adapter_indices = self.adapter_meta.adapter_indices[keep_indices]
         self.attention_mask = self.attention_mask[
-            keep_indices,
-            -(self.padding_right_offset + max_input_length) : (
-                self.attention_mask.shape[1] - self.padding_right_offset
-            )
-            + new_padding_right_offset,
-        ]
+                              keep_indices,
+                              -(self.padding_right_offset + max_input_length): (
+                                      self.attention_mask.shape[
+                                          1] -
+                                      self.padding_right_offset
+                              )
+                              + new_padding_right_offset,
+                              ]
 
         # Ensure that past_key_values tensors can be updated in-place
         if type(self.past_key_values[0]) == tuple:
@@ -371,17 +373,17 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
             # and to remove unused allocated space
             left_offset = max_input_length - batch.max_input_length
             batch_left_offset = (
-                batch.attention_mask.shape[1]
-                - batch.max_input_length
-                - batch.padding_right_offset
+                    batch.attention_mask.shape[1]
+                    - batch.max_input_length
+                    - batch.padding_right_offset
             )
             attention_mask[
-                start_index:end_index,
-                left_offset:-padding_right_offset,
+            start_index:end_index,
+            left_offset:-padding_right_offset,
             ] = batch.attention_mask[
                 :,
-                batch_left_offset : -batch.padding_right_offset,
-            ]
+                batch_left_offset: -batch.padding_right_offset,
+                ]
 
             # Create empty tensor
             # position_ids is always of shape [batch_size, 1]
@@ -405,7 +407,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
 
             # Add eventual padding tokens that were added while concatenating
             max_tokens += batch.max_tokens + (
-                max_input_length - batch.max_input_length
+                    max_input_length - batch.max_input_length
             ) * len(batch)
 
             start_index = end_index
@@ -447,12 +449,12 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                 past_seq_len = batch.max_input_length - 1
                 if batch.keys_head_dim_last:
                     padded_past_keys[
-                        start_index:end_index, :, -past_seq_len:, :
+                    start_index:end_index, :, -past_seq_len:, :
                     ] = past_keys[:, :, -past_seq_len:, :]
                 else:
                     # BLOOM case
                     padded_past_keys[
-                        start_index:end_index, :, :, -past_seq_len:
+                    start_index:end_index, :, :, -past_seq_len:
                     ] = past_keys[:, :, :, -past_seq_len:]
 
                 del past_keys
@@ -472,7 +474,7 @@ def concatenate(cls, batches: List["CausalLMBatch"]) -> "CausalLMBatch":
                 # We slice the past values to remove the padding from previous batches
                 past_seq_len = batch.max_input_length - 1
                 padded_past_values[
-                    start_index:end_index, :, -past_seq_len:, :
+                start_index:end_index, :, -past_seq_len:, :
                 ] = past_values[:, :, -past_seq_len:, :]
 
                 del past_values
@@ -525,7 +527,7 @@ def __init__(
     ):
         if compile:
             raise ValueError("`--compile` is not supported with CausalLM")
-        
+
         if torch.cuda.is_available():
             device = torch.device("cuda")
             dtype = torch.float16 if dtype is None else dtype
@@ -580,7 +582,7 @@ def __init__(
     @property
     def batch_type(self) -> Type[CausalLMBatch]:
         return CausalLMBatch
-    
+
     @property
     def has_adapter_data(self) -> bool:
         return False
@@ -592,19 +594,19 @@ def decode(self, generated_ids: List[int]) -> str:
 
     def forward(
         self,
-        input_ids,
-        attention_mask,
-        position_ids,
-        past_key_values: Optional = None,
+            input_ids,
+            attention_mask,
+            position_ids,
+            past_key_values: Optional = None,
         adapter_data: Optional[AdapterBatchData] = None
     ) -> Tuple[torch.Tensor, List[Tuple[torch.Tensor, torch.Tensor]]]:
         # Model Forward
         kwargs = {
-            "input_ids": input_ids,
-            "attention_mask": attention_mask,
+            "input_ids" : input_ids,
+            "attention_mask" : attention_mask,
             "past_key_values": past_key_values,
-            "use_cache": True,
-            "return_dict": True,
+            "use_cache" : True,
+            "return_dict" : True,
         }
         if self.has_position_ids:
             kwargs["position_ids"] = position_ids
@@ -653,14 +655,14 @@ def generate_token(
 
         # For each member of the batch
        for i, (
-            request,
-            input_length,
-            prefix_offset,
-            read_offset,
-            logits,
-            next_token_chooser,
-            stopping_criteria,
-            all_input_ids,
+                request,
+                input_length,
+                prefix_offset,
+                read_offset,
+                logits,
+                next_token_chooser,
+                stopping_criteria,
+                all_input_ids,
         ) in enumerate(iterator):
             # Select next token
             next_token_id, logprobs = next_token_chooser(
@@ -693,7 +695,7 @@ def generate_token(
             if stop:
                 # Decode generated tokens
                 output_text = self.decode(
-                    all_input_ids[-stopping_criteria.current_tokens :, 0]
+                    all_input_ids[-stopping_criteria.current_tokens:, 0]
                 )
                 # Get seed
                 if isinstance(next_token_chooser.choice, Sampling):
@@ -713,8 +715,8 @@ def generate_token(
                 prefill_logprobs = [float("nan")] + torch.log_softmax(
                     logits, -1
                 ).gather(1, all_input_ids[1:]).squeeze(1)[
-                    -new_input_length:-1
-                ].tolist()
+                                   -new_input_length:-1
+                                   ].tolist()
                 prefill_token_ids = all_input_ids[-new_input_length:-1]
                 prefill_texts = self.tokenizer.batch_decode(
                     prefill_token_ids,
diff --git a/server/lorax_server/models/flash_causal_lm.py b/server/lorax_server/models/flash_causal_lm.py
index cabf94570..a93c2db3b 100644
--- a/server/lorax_server/models/flash_causal_lm.py
+++ b/server/lorax_server/models/flash_causal_lm.py
@@ -111,14 +111,14 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        tokenizer_mgr: TokenizerManager,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
         batch_inputs = []
         max_truncation = 0
         for r in pb.requests:
-            inputs = tokenizer_mgr.get_inputs(r, tokenizer)
+            inputs = tokenizers.get_inputs(r, tokenizer)
             batch_inputs.append(inputs)
             max_truncation = max(max_truncation, r.truncate)
 
@@ -240,7 +240,7 @@ def from_pb(
         adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)
 
         request_tokenizers = [
-            tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
+            tokenizers.get_tokenizer(r.adapter_index, tokenizer)
             for r in pb.requests
         ]
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
diff --git a/server/lorax_server/models/flash_mistral.py b/server/lorax_server/models/flash_mistral.py
index 2feb9bda8..aabda653d 100644
--- a/server/lorax_server/models/flash_mistral.py
+++ b/server/lorax_server/models/flash_mistral.py
@@ -57,7 +57,7 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        tokenizer_mgr: TokenizerManager,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -67,7 +67,7 @@ def from_pb(
         batch_inputs = []
         max_truncation = 0
         for r in pb.requests:
-            inputs = tokenizer_mgr.get_inputs(r, tokenizer)
+            inputs = tokenizers.get_inputs(r, tokenizer)
             batch_inputs.append(inputs)
             max_truncation = max(max_truncation, r.truncate)
 
@@ -206,7 +206,7 @@ def from_pb(
         adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)
 
         request_tokenizers = [
-            tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
+            tokenizers.get_tokenizer(r.adapter_index, tokenizer)
             for r in pb.requests
         ]
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
diff --git a/server/lorax_server/models/flash_mixtral.py b/server/lorax_server/models/flash_mixtral.py
index 33ed124cf..675f78385 100644
--- a/server/lorax_server/models/flash_mixtral.py
+++ b/server/lorax_server/models/flash_mixtral.py
@@ -64,7 +64,7 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        tokenizer_mgr: TokenizerManager,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "FlashCausalLMBatch":
@@ -74,7 +74,7 @@ def from_pb(
         batch_inputs = []
         max_truncation = 0
         for r in pb.requests:
-            inputs = tokenizer_mgr.get_inputs(r, tokenizer)
+            inputs = tokenizers.get_inputs(r, tokenizer)
             batch_inputs.append(inputs)
             max_truncation = max(max_truncation, r.truncate)
 
@@ -213,7 +213,7 @@ def from_pb(
         adapter_indices = torch.cat(adapter_indices_list).to(dtype=torch.int64, device=device)
 
         request_tokenizers = [
-            tokenizer_mgr.get_tokenizer(r.adapter_index, tokenizer)
+            tokenizers.get_tokenizer(r.adapter_index, tokenizer)
             for r in pb.requests
         ]
         next_token_chooser = HeterogeneousNextTokenChooser.from_pb(
diff --git a/server/lorax_server/models/types.py b/server/lorax_server/models/types.py
index e352c7ca4..610c41c6c 100644
--- a/server/lorax_server/models/types.py
+++ b/server/lorax_server/models/types.py
@@ -22,7 +22,7 @@ def from_pb(
         cls,
         pb: generate_pb2.Batch,
         tokenizer: PreTrainedTokenizerBase,
-        tokenizer_mgr: TokenizerManager,
+        tokenizers: TokenizerManager,
         dtype: torch.dtype,
         device: torch.device,
     ) -> "Batch":
diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py
index 177f4bd8b..8b6d53798 100644
--- a/server/lorax_server/utils/logits_process.py
+++ b/server/lorax_server/utils/logits_process.py
@@ -432,8 +432,8 @@ class HeterogeneousSchemaLogitsProcessor(LogitsProcessor):
 
     def __init__(
         self,
-        schemas: List[Optional[str]] = None,
-        tokenizers: List[Optional[PreTrainedTokenizerBase]] = None,
+        schemas: Optional[List[Optional[str]]] = None,
+        tokenizers: Optional[List[Optional[PreTrainedTokenizerBase]]] = None,
     ):
         if schemas is None:
             schemas = []
@@ -441,7 +441,7 @@ def __init__(
             tokenizers = []
 
         self.sequence_processors = [
-            None if schema is None else OutlinesLogitsProcessor(schema, tokenizer)
+            None if schema is None or tokenizer is None else OutlinesLogitsProcessor(schema, tokenizer)
             for schema, tokenizer in zip(schemas, tokenizers)
         ]
 
@@ -469,7 +469,7 @@ def concatenate(
 
 
 # Source: https://github.com/outlines-dev/outlines/blob/main/outlines/serve/vllm.py
-class OutlinesLogitsProcessor:
+class OutlinesLogitsProcessor(LogitsProcessor):
     def __init__(self, schema: str, tokenizer: PreTrainedTokenizerBase):
         """Compile the FSM that drives the regex-guided generation.
 
diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py
index 4cb441240..f9258fa24 100644
--- a/server/lorax_server/utils/tokens.py
+++ b/server/lorax_server/utils/tokens.py
@@ -18,7 +18,7 @@
     HeterogeneousTopPLogitsWarper,
     HeterogeneousTypicalLogitsWarper,
     HeterogeneousProcessorWrapper,
-    HeterogeneousSchemaLogitsProcessor,
+    HeterogeneousSchemaLogitsProcessor, OutlinesLogitsProcessor,
 )
 
 
@@ -30,12 +30,14 @@ class NextTokenChooser:
         watermark (bool): Whether to apply watermark processing to logits. Default is False.
         temperature (float): The temperature value for warping logits. Default is 1.0.
         repetition_penalty (float): The penalty value for repetition in logits. Default is 1.0.
+        schema (str): A JSON schema string for Outlines logits warping.
         top_k (int): The value for top-k warping of logits. Default is None.
         top_p (float): The value for top-p warping of logits. Default is None.
         typical_p (float): The value for typical-p warping of logits. Default is None.
         do_sample (bool): Whether to perform sampling. Default is False.
         seed (int): The seed value for random number generation. Default is 0.
         device (str): The device to use for computation. Default is "cpu".
+        tokenizer (PreTrainedTokenizerBase): A tokenizer to use for processing the tokens.
 
     Returns:
         next_id (torch.Tensor): The next token ID.
@@ -44,15 +46,17 @@ class NextTokenChooser:
 
     def __init__(
         self,
-        watermark=False,
-        temperature=1.0,
-        repetition_penalty=1.0,
-        top_k=None,
-        top_p=None,
-        typical_p=None,
-        do_sample=False,
-        seed=0,
-        device="cpu",
+        watermark: bool = False,
+        temperature: float = 1.0,
+        repetition_penalty: float = 1.0,
+        schema: str = None,
+        top_k: int = None,
+        top_p: float = None,
+        typical_p: float = None,
+        do_sample: bool = False,
+        seed: int = 0,
+        device: str = "cpu",
+        tokenizer: Optional[PreTrainedTokenizerBase] = None,
     ):
         self.watermark_processor = (
             WatermarkLogitsProcessor(device=device) if watermark else None
@@ -63,6 +67,12 @@ def __init__(
             else None
         )
 
+        self.schema_processor = (
+            OutlinesLogitsProcessor(schema, tokenizer)
+            if schema is not None and tokenizer is not None
+            else None
+        )
+
         has_warpers = (
             (temperature is not None and temperature != 1.0)
             or (top_k is not None and top_k != 0)
@@ -84,6 +94,8 @@ def __call__(self, input_ids, scores):
             scores = self.watermark_processor(input_ids, scores)
         if self.repetition_processor is not None:
             scores = self.repetition_processor(input_ids, scores)
+        if self.schema_processor is not None:
+            scores = self.schema_processor(input_ids, scores)
 
         if self.static_warper is None:
             next_logprob = torch.log_softmax(scores, -1)
@@ -99,6 +111,7 @@ def from_pb(
         cls,
         pb: generate_pb2.NextTokenChooserParameters,
         device: torch.device,
+        tokenizer: PreTrainedTokenizerBase,
     ) -> "NextTokenChooser":
         """
         Create a NextTokenChooser instance from a protobuf message.
@@ -106,6 +119,7 @@ def from_pb(
         Args:
             pb (generate_pb2.NextTokenChooserParameters): The protobuf message containing the parameters.
             device (torch.device): The device to use for computation.
+            tokenizer (PreTrainedTokenizerBase): A tokenizer for use in processing the tokens.
 
         Returns:
             NextTokenChooser: The NextTokenChooser instance.
@@ -114,12 +128,14 @@ def from_pb(
             watermark=pb.watermark,
             temperature=pb.temperature,
             repetition_penalty=pb.repetition_penalty,
+            schema=pb.schema,
             top_k=pb.top_k,
             top_p=pb.top_p,
            typical_p=pb.typical_p,
            do_sample=pb.do_sample,
             seed=pb.seed,
             device=device,
+            tokenizer=tokenizer,
         )
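
Reviewer note: below is a minimal usage sketch of the schema-guided decoding path this diff adds to `NextTokenChooser` in `tokens.py`. It is illustrative only and not part of the patch: the `gpt2` checkpoint and the JSON schema are hypothetical placeholders, and constructing `OutlinesLogitsProcessor` assumes the `outlines` dependency referenced in `logits_process.py` is installed.

    import torch
    from transformers import AutoTokenizer

    from lorax_server.utils.tokens import NextTokenChooser

    # Hypothetical tokenizer and schema, chosen only for illustration.
    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    schema = '{"type": "object", "properties": {"name": {"type": "string"}}, "required": ["name"]}'

    # With both `schema` and `tokenizer` provided, the constructor builds an
    # OutlinesLogitsProcessor; __call__ applies it after the repetition penalty
    # and before warping/sampling, constraining generation to the schema.
    chooser = NextTokenChooser(
        temperature=1.0,
        schema=schema,
        do_sample=False,
        device="cpu",
        tokenizer=tokenizer,
    )

    # input_ids: [1, seq_len] token history; scores: [1, vocab_size] logits
    # for the last position, as in generate_token in causal_lm.py.
    input_ids = torch.tensor([[tokenizer.eos_token_id]])
    scores = torch.randn(1, tokenizer.vocab_size)
    next_id, next_logprob = chooser(input_ids, scores)

When `schema` or `tokenizer` is left as `None` (the defaults), `schema_processor` stays `None` and `__call__` behaves exactly as before this change.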