diff --git a/proto/generate.proto b/proto/generate.proto index a44bcdb54..9744d17a9 100644 --- a/proto/generate.proto +++ b/proto/generate.proto @@ -91,14 +91,18 @@ message NextTokenChooserParameters { uint64 seed = 6; /// repetition penalty float repetition_penalty = 7; + /// frequency penalty + float frequency_penalty = 8; + /// presence penalty + float presence_penalty = 9; /// token watermarking using "A Watermark for Large Language Models" - bool watermark = 8; + bool watermark = 10; /// adapter to use with lora exchange - string adapter_id = 9; + string adapter_id = 11; /// JSON schema used for constrained decoding (Outlines) - optional string schema = 10; + optional string schema = 12; /// returning the k highest probability alternatives - uint32 return_k_alternatives = 11; + uint32 return_k_alternatives = 13; } message StoppingCriteriaParameters { diff --git a/router/client/src/client.rs b/router/client/src/client.rs index 0751347da..1224ec9da 100644 --- a/router/client/src/client.rs +++ b/router/client/src/client.rs @@ -154,6 +154,8 @@ impl Client { do_sample: false, seed: 0, repetition_penalty: 1.2, + frequency_penalty: 0.5, + presence_penalty: 0.5, watermark: true, adapter_id: "".to_string(), schema: None, diff --git a/router/src/health.rs b/router/src/health.rs index 1c6aaed2d..af9efeca5 100644 --- a/router/src/health.rs +++ b/router/src/health.rs @@ -59,6 +59,8 @@ impl Health { do_sample: false, seed: 0, repetition_penalty: 1.0, + frequency_penalty: 0.0, + presence_penalty: 0.0, watermark: false, adapter_id: "".to_string(), schema: None, diff --git a/router/src/infer.rs b/router/src/infer.rs index 0dbec8557..1bead7c75 100644 --- a/router/src/infer.rs +++ b/router/src/infer.rs @@ -361,7 +361,7 @@ impl Infer { let truncate = request.parameters.truncate; let encoding = self .validation - .tokenize(inputs, truncate) + .tokenize(inputs, request.add_special_tokens, truncate) .await .map_err(|err| { tracing::error!("Error occurred during tokenization. 
{err}"); @@ -530,7 +530,7 @@ impl Infer { let inputs = request.inputs.clone(); let (tokenized_inputs, input_length) = self .validation - .validate_input(request.inputs, None, Some(1)) + .validate_input(request.inputs, true, None, Some(1)) .await?; let valid_request = ValidEmbedRequest { @@ -631,7 +631,7 @@ impl Infer { let inputs = request.inputs.clone(); let (tokenized_inputs, input_length) = self .validation - .validate_input(request.inputs, None, Some(1)) + .validate_input(request.inputs, true, None, Some(1)) .await?; let valid_request = ValidClassifyRequest { @@ -755,7 +755,10 @@ impl Infer { let futures: Vec<_> = request .inputs .iter() - .map(|input| self.validation.validate_input(input.clone(), None, Some(1))) + .map(|input| { + self.validation + .validate_input(input.clone(), true, None, Some(1)) + }) .collect(); let all_tokenized_inputs = try_join_all(futures).await?; diff --git a/router/src/lib.rs b/router/src/lib.rs index 3eaeeca3f..db080eec0 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -17,10 +17,11 @@ use lorax_client::{AdapterParameters as AdapterParametersMessage, Entity as Enti use lorax_client::{MajoritySignMethod, MergeStrategy}; use batch::Entry; -use infer::Infer; +use infer::{Infer, InferError}; use loader::AdapterLoader; use serde::{Deserialize, Serialize}; use serde_json::json; +use server::prepare_chat_input; use utoipa::ToSchema; use validation::Validation; @@ -204,6 +205,27 @@ pub(crate) struct GenerateParameters { example = 1.03 )] pub repetition_penalty: Option, + + #[serde(default)] + #[schema( + exclusive_minimum = -2.0, + exclusive_maximum = 2.0, + nullable = true, + default = "null", + example = 0.1 + )] + pub frequency_penalty: Option, + + #[serde(default)] + #[schema( + exclusive_minimum = -2.0, + exclusive_maximum = 2.0, + nullable = true, + default = "null", + example = 0.1 + )] + pub presence_penalty: Option, + #[serde(default)] #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)] pub top_k: Option, @@ -285,6 +307,8 @@ fn default_parameters() -> GenerateParameters { best_of: None, temperature: None, repetition_penalty: None, + frequency_penalty: None, + presence_penalty: None, top_k: None, top_p: None, typical_p: None, @@ -310,6 +334,16 @@ pub(crate) struct GenerateRequest { pub inputs: String, #[serde(default = "default_parameters")] pub parameters: GenerateParameters, + + /// This is used internally because some requests + /// already contain the templated input therefore + /// we shouldn't add the special tokens. + #[serde(default = "default_true", skip)] + pub add_special_tokens: bool, +} + +fn default_true() -> bool { + true } #[derive(Clone, Debug, Deserialize, ToSchema)] @@ -321,6 +355,12 @@ pub(crate) struct CompatGenerateRequest { #[serde(default)] #[schema(default = "false")] pub stream: bool, + + /// This is used internally because some requests + /// already contain the templated input therefore + /// we shouldn't add the special tokens. 
+ #[serde(default = "default_true", skip)] + pub add_special_tokens: bool, } impl From for GenerateRequest { @@ -328,6 +368,7 @@ impl From for GenerateRequest { Self { inputs: req.inputs, parameters: req.parameters, + add_special_tokens: req.add_special_tokens, } } } @@ -665,6 +706,144 @@ struct ChatCompletionRequest { pub guideline: Option, } +impl ChatCompletionRequest { + fn try_into_generate(self, infer: &Infer) -> Result<(CompatGenerateRequest, bool), InferError> { + let ChatCompletionRequest { + model, + max_tokens, + messages, + seed, + stop, + stream, + tools, + tool_choice, + tool_prompt, + temperature, + response_format, + guideline, + repetition_penalty, + presence_penalty, + frequency_penalty, + top_p, + top_k, + n, + adapter_source, + api_token, + ignore_eos_token, + .. + } = self; + + let mut adapter_id = Some(model.clone()); + if model == "" { + adapter_id = None; + } + + // Modify input values to ResponseFormat to be OpenAI API compatible + let response_format: Option = match response_format { + None => None, + Some(openai_format) => { + let response_format_type = openai_format.response_format_type.clone(); + match response_format_type { + // Ignore when type is text + ResponseFormatType::Text => None, + + // For json_object, use the fixed schema. + // For backwards compatibility, also support non-standard `schema` field + ResponseFormatType::JsonObject => openai_format.schema.map_or_else( + || { + Some(ResponseFormat { + r#type: response_format_type.clone(), + schema: default_json_schema(), + }) + }, + |schema_value: serde_json::Value| { + Some(ResponseFormat { + r#type: response_format_type.clone(), + schema: Some(schema_value), + }) + }, + ), + + // For json_schema, use schema_value if available, otherwise fallback to the fixed schema + ResponseFormatType::JsonSchema => openai_format + .json_schema + .and_then(|schema| schema.schema) + .map_or_else( + || { + Some(ResponseFormat { + r#type: response_format_type.clone(), + schema: default_json_schema(), + }) + }, + |schema_value: serde_json::Value| { + Some(ResponseFormat { + r#type: response_format_type.clone(), + schema: Some(schema_value), + }) + }, + ), + } + } + }; + + let tool_prompt = tool_prompt + .filter(|s| !s.is_empty()) + .unwrap_or_else(default_tool_prompt); + + // enable greedy only when temperature is 0 + let (do_sample, temperature) = match temperature { + Some(temperature) if temperature == 0.0 => (false, None), + other => (true, other), + }; + + let (inputs, response_format, using_tools) = prepare_chat_input( + &infer, + response_format, + tools, + tool_choice, + &tool_prompt, + guideline, + messages, + )?; + + Ok(( + CompatGenerateRequest { + inputs: inputs.to_string(), + add_special_tokens: false, + parameters: GenerateParameters { + adapter_id, + adapter_source, + adapter_parameters: None, + api_token, + best_of: n.map(|x| x as usize), + temperature, + repetition_penalty, + frequency_penalty, + presence_penalty, + top_k, + top_p, + typical_p: None, + do_sample, + max_new_tokens: max_tokens.map(|x| x as u32), + return_full_text: None, + stop, + truncate: None, + watermark: false, + details: true, + decoder_input_details: false, + seed, + ignore_eos_token: ignore_eos_token.unwrap_or(false), + return_k_alternatives: None, + apply_chat_template: false, + response_format, + }, + stream: stream.unwrap_or(false), + }, + using_tools, + )) + } +} + pub fn default_tool_prompt() -> String { "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers 
the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string() } @@ -1001,6 +1180,7 @@ impl From for CompatGenerateRequest { fn from(req: CompletionRequest) -> Self { CompatGenerateRequest { inputs: req.prompt, + add_special_tokens: true, parameters: GenerateParameters { adapter_id: req.model.parse().ok(), adapter_source: req.adapter_source, @@ -1009,6 +1189,8 @@ impl From for CompatGenerateRequest { best_of: req.best_of.map(|x| x as usize), temperature: req.temperature, repetition_penalty: req.repetition_penalty, + frequency_penalty: req.frequency_penalty, + presence_penalty: req.presence_penalty, top_k: req.top_k, top_p: req.top_p, typical_p: None, diff --git a/router/src/server.rs b/router/src/server.rs index 011b509e1..9b9d87571 100644 --- a/router/src/server.rs +++ b/router/src/server.rs @@ -5,20 +5,19 @@ use crate::health::Health; use crate::infer::{InferError, InferResponse, InferStreamResponse}; use crate::tool_grammar::ToolGrammar; use crate::validation::ValidationError; +use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig}; use crate::{ - default_json_schema, default_tool_prompt, AdapterParameters, AlternativeToken, - BatchClassifyRequest, BestOfSequence, ChatCompletionRequest, ChatCompletionResponse, - ChatCompletionResponseChoice, ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, - ChatMessage, ClassifyRequest, CompatGenerateRequest, CompletionFinishReason, CompletionRequest, - CompletionResponse, CompletionResponseChoice, CompletionResponseStreamChoice, - CompletionStreamResponse, Details, EmbedParameters, EmbedRequest, EmbedResponse, Entity, - ErrorResponse, FinishReason, FunctionDefinition, GenerateParameters, GenerateRequest, - GenerateResponse, HubModelInfo, Infer, Info, JsonSchema, LogProbs, Message, - OpenAiResponseFormat, PrefillToken, ResponseFormat, ResponseFormatType, SimpleToken, - StreamDetails, StreamResponse, Token, TokenizeRequest, TokenizeResponse, Tool, ToolCall, - ToolChoice, UsageInfo, Validation, + AdapterParameters, AlternativeToken, BatchClassifyRequest, BestOfSequence, + ChatCompletionRequest, ChatCompletionResponse, ChatCompletionResponseChoice, + ChatCompletionStreamResponse, ChatCompletionStreamResponseChoice, ChatMessage, ClassifyRequest, + CompatGenerateRequest, CompletionFinishReason, CompletionRequest, CompletionResponse, + CompletionResponseChoice, CompletionResponseStreamChoice, CompletionStreamResponse, Details, + EmbedParameters, EmbedRequest, EmbedResponse, Entity, ErrorResponse, FinishReason, + FunctionDefinition, GenerateParameters, GenerateRequest, GenerateResponse, HubModelInfo, Infer, + Info, JsonSchema, LogProbs, Message, OpenAiResponseFormat, PrefillToken, ResponseFormat, + ResponseFormatType, SimpleToken, StreamDetails, StreamResponse, Token, TokenizeRequest, + TokenizeResponse, Tool, ToolCall, ToolChoice, UsageInfo, Validation, }; -use crate::{json, HubPreprocessorConfig, HubProcessorConfig, HubTokenizerConfig}; use axum::extract::Extension; use axum::http::{HeaderMap, Method, StatusCode}; use axum::response::sse::{Event, KeepAlive, Sse}; @@ -246,107 +245,9 @@ async fn chat_completions_v1( req.model = "".to_string(); } - let mut adapter_id = Some(req.model.clone()); - if req.model == info.model_id.as_str() { - // Allow user to specify the base model, but treat it as an empty adapter_id - tracing::debug!("Replacing base model {0} with empty adapter_id", req.model); - adapter_id = None; - } - let 
system_fingerprint = format!("{}-{}", info.version, info.docker_label.unwrap_or("native")); - - // Modify input values to ResponseFormat to be OpenAI API compatible - let response_format: Option = match req.response_format { - None => None, - Some(openai_format) => { - let response_format_type = openai_format.response_format_type.clone(); - match response_format_type { - // Ignore when type is text - ResponseFormatType::Text => None, - - // For json_object, use the fixed schema. - // For backwards compatibility, also support non-standard `schema` field - ResponseFormatType::JsonObject => openai_format.schema.map_or_else( - || { - Some(ResponseFormat { - r#type: response_format_type.clone(), - schema: default_json_schema(), - }) - }, - |schema_value: serde_json::Value| { - Some(ResponseFormat { - r#type: response_format_type.clone(), - schema: Some(schema_value), - }) - }, - ), - - // For json_schema, use schema_value if available, otherwise fallback to the fixed schema - ResponseFormatType::JsonSchema => openai_format - .json_schema - .and_then(|schema| schema.schema) - .map_or_else( - || { - Some(ResponseFormat { - r#type: response_format_type.clone(), - schema: default_json_schema(), - }) - }, - |schema_value: serde_json::Value| { - Some(ResponseFormat { - r#type: response_format_type.clone(), - schema: Some(schema_value), - }) - }, - ), - } - } - }; - - let tool_prompt = req - .tool_prompt - .filter(|s| !s.is_empty()) - .unwrap_or_else(default_tool_prompt); - - let (inputs, response_format, using_tools) = prepare_chat_input( - &infer, - response_format, - req.tools, - req.tool_choice, - &tool_prompt, - req.guideline, - req.messages, - )?; - - let mut gen_req = CompatGenerateRequest { - inputs: inputs.to_string(), - parameters: GenerateParameters { - adapter_id: adapter_id, - adapter_source: req.adapter_source, - adapter_parameters: None, - api_token: req.api_token, - best_of: req.n.map(|x| x as usize), - temperature: req.temperature, - repetition_penalty: req.repetition_penalty, - top_k: req.top_k, - top_p: req.top_p, - typical_p: None, - do_sample: !req.n.is_none(), - max_new_tokens: req.max_tokens.map(|x| x as u32), - ignore_eos_token: req.ignore_eos_token.unwrap_or(false), - return_full_text: None, - stop: req.stop, - truncate: None, - watermark: false, - details: true, - decoder_input_details: false, - return_k_alternatives: None, - apply_chat_template: false, - seed: req.seed, - response_format: response_format, - }, - stream: req.stream.unwrap_or(false), - }; + let (mut gen_req, using_tools): (CompatGenerateRequest, bool) = + req.try_into_generate(&infer)?; // default return_full_text given the pipeline_tag if gen_req.parameters.return_full_text.is_none() { @@ -616,6 +517,8 @@ async fn health( do_sample: false, seed: None, repetition_penalty: None, + frequency_penalty: None, + presence_penalty: None, watermark: false, return_full_text: None, stop: vec![], @@ -628,6 +531,7 @@ async fn health( max_new_tokens: Some(1), ignore_eos_token: false, }, + add_special_tokens: true, }; match infer.generate(generate_request).await { Ok(response) => { diff --git a/router/src/validation.rs b/router/src/validation.rs index 4a7a91edc..2950f63f4 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -85,6 +85,7 @@ impl Validation { pub async fn tokenize( &self, inputs: String, + add_special_tokens: bool, truncate: Option, ) -> Result)>, ValidationError> { // If we have a fast tokenizer @@ -94,7 +95,11 @@ impl Validation { // Send request to the background validation task // Unwrap 
is safe here sender - .send(((inputs, truncate), response_sender, Span::current())) + .send(( + (inputs, add_special_tokens, truncate), + response_sender, + Span::current(), + )) .unwrap(); // Await on response channel @@ -110,11 +115,15 @@ impl Validation { pub(crate) async fn validate_input( &self, inputs: String, + add_special_tokens: bool, truncate: Option, max_new_tokens: Option, ) -> Result<(Option, usize), ValidationError> { // If we have a fast tokenizer - if let Some((encoding, chunks)) = self.tokenize(inputs.clone(), truncate).await? { + if let Some((encoding, chunks)) = self + .tokenize(inputs.clone(), add_special_tokens, truncate) + .await? + { // Create response channel let input_length = encoding.len(); @@ -193,6 +202,8 @@ impl Validation { best_of, temperature, repetition_penalty, + frequency_penalty, + presence_penalty, top_k, top_p, typical_p, @@ -252,6 +263,16 @@ impl Validation { return Err(ValidationError::RepetitionPenalty); } + let frequency_penalty = frequency_penalty.unwrap_or(0.0); + if !(-2.0..=2.0).contains(&frequency_penalty) { + return Err(ValidationError::FrequencyPenalty); + } + + let presence_penalty = presence_penalty.unwrap_or(0.0); + if !(-2.0..=2.0).contains(&presence_penalty) { + return Err(ValidationError::PresencePenalty); + } + // Different because the proto default value is not a valid value // for the user let top_p = top_p @@ -339,12 +360,19 @@ impl Validation { // Validate inputs let inputs = request.inputs.clone(); let (tokenized_inputs, input_length) = self - .validate_input(request.inputs, truncate, max_new_tokens) + .validate_input( + request.inputs, + request.add_special_tokens, + truncate, + max_new_tokens, + ) .await?; let parameters = NextTokenChooserParameters { temperature, repetition_penalty, + frequency_penalty, + presence_penalty, top_k, top_p, typical_p, @@ -419,12 +447,15 @@ fn tokenizer_worker( mut receiver: mpsc::UnboundedReceiver, ) { // Loop over requests - while let Some(((inputs, truncate), response_tx, parent_span)) = receiver.blocking_recv() { + while let Some(((inputs, add_special_tokens, truncate), response_tx, parent_span)) = + receiver.blocking_recv() + { parent_span.in_scope(|| { response_tx .send(prepare_input( inputs, truncate, + add_special_tokens, &tokenizer, config.as_ref(), preprocessor_config.as_ref(), @@ -561,6 +592,7 @@ fn image_tokens_fixup(config: &Config, text: String) -> String { fn prepare_input( inputs: String, _truncate: Option, + add_special_tokens: bool, tokenizer: &Tokenizer, config: Option<&Config>, preprocessor_config: Option<&HubPreprocessorConfig>, @@ -599,14 +631,14 @@ fn prepare_input( // Get the number of tokens in the input let encoding = tokenizer - .encode(tokenizer_query, true) + .encode(tokenizer_query, add_special_tokens) .map_err(|err| ValidationError::Tokenizer(err.to_string()))?; Ok((encoding, input_chunks)) } type TokenizerRequest = ( - (String, Option), + (String, bool, Option), oneshot::Sender), ValidationError>>, Span, ); @@ -663,6 +695,10 @@ pub enum ValidationError { Temperature, #[error("`repetition_penalty` must be strictly positive")] RepetitionPenalty, + #[error("`frequency_penalty` must be >= -2.0 and <= 2.0")] + FrequencyPenalty, + #[error("`presence_penalty` must be >= -2.0 and <= 2.0")] + PresencePenalty, #[error("`top_p` must be > 0.0 and < 1.0")] TopP, #[error("`top_k` must be strictly positive")] @@ -739,7 +775,7 @@ mod tests { let max_new_tokens = Some(10); match validation - .validate_input("Hello".to_string(), None, max_new_tokens) + 
.validate_input("Hello".to_string(), true, None, max_new_tokens) .await { Err(ValidationError::MaxNewTokens(1, 10)) => (), @@ -769,7 +805,7 @@ mod tests { let max_new_tokens = Some(10); match validation - .validate_input("Hello".to_string(), None, max_new_tokens) + .validate_input("Hello".to_string(), true, None, max_new_tokens) .await { Err(ValidationError::MaxTotalTokens(5, 1, 10)) => (), @@ -805,6 +841,7 @@ mod tests { do_sample: false, ..default_parameters() }, + add_special_tokens: true, }, Adapter::new( AdapterParameters { @@ -850,6 +887,7 @@ mod tests { top_p: Some(1.0), ..default_parameters() }, + add_special_tokens: true, }, Adapter::new( AdapterParameters { @@ -876,6 +914,7 @@ mod tests { max_new_tokens: Some(1), ..default_parameters() }, + add_special_tokens: true, }, Adapter::new( AdapterParameters { @@ -902,6 +941,7 @@ mod tests { max_new_tokens: Some(1), ..default_parameters() }, + add_special_tokens: true, }, Adapter::new( AdapterParameters { diff --git a/server/lorax_server/utils/layers.py b/server/lorax_server/utils/layers.py index 0feaae609..5a37c79f2 100644 --- a/server/lorax_server/utils/layers.py +++ b/server/lorax_server/utils/layers.py @@ -79,7 +79,8 @@ def forward_layer_type( # Triton Punica kernels key = (layer_type, self.layer_id) if ( - adapter_data.punica_wrapper is not None and adapter_data.punica_wrapper.enabled + adapter_data.punica_wrapper is not None + and adapter_data.punica_wrapper.enabled and key in adapter_data.layer_to_lora_weights and input.shape[0] <= adapter_data.punica_wrapper.max_batch_size and can_vectorize diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 6091c94d5..dc5b880fb 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -86,6 +86,28 @@ def static_warper( return StaticWarper(temperature=temperature, top_k=top_k, top_p=top_p, typical_p=typical_p) +class FrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI + + Args: + penalty (`float`): + The parameter for frequency penalty. 0.0 means no penalty. + """ + + def __init__(self, penalty: float): + self.penalty = penalty + + def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor: + score = torch.gather(scores, 1, input_ids) + # if score < 0 then penalty has to be multiplied to reduce the previous token probability + score = -torch.where(score < 0, score * self.penalty, score / self.penalty) + # set score to 0 where input_ids is a padding token + score *= input_ids.ne(0) + + return scores.scatter_add_(1, input_ids, score) + + class HeterogeneousRepetitionPenaltyLogitsProcessor(LogitsProcessor): r""" [`LogitsProcessor`] enforcing an exponential penalty on repeated sequences. @@ -119,6 +141,53 @@ def filter(self, indices): return None +class HeterogeneousFrequencyPenaltyLogitsProcessor(LogitsProcessor): + r""" + Frequency penalty as defined by OpenAI in + https://platform.openai.com/docs/guides/text-generation/parameter-details + + Args: + frequency_penalty (`List[float]`): + The parameter for frequency penalty. 0.0 means no penalty. + presence_penalty (`List[float]`): + The parameter for presence penalty. 0.0 means no penalty. 
+ """ + + def __init__( + self, frequency_penalty: List[float], presence_penalty: List[float], dtype: torch.dtype, device: torch.device + ): + self.frequency_penalty = frequency_penalty + self.frequency_penalty_tensor = torch.tensor(frequency_penalty, dtype=dtype, device=device).unsqueeze(1) + + self.presence_penalty = presence_penalty + self.presence_penalty_tensor = torch.tensor(presence_penalty, dtype=dtype, device=device).unsqueeze(1) + + def __call__(self, input_ids: torch.Tensor, scores: torch.Tensor) -> torch.Tensor: + batch_size, input_size = input_ids.size() + vocab_size = scores.size(1) + + # Calculate the frequency for each token so far + token_freq = torch.zeros(batch_size, vocab_size, device=input_ids.device) + token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float)) + mask = token_freq > 0 + token_freq /= input_size + + # Apply the frequency and presence penalties to logits + scores -= token_freq * self.frequency_penalty_tensor + scores -= mask * self.presence_penalty_tensor + + return scores + + def filter(self, indices): + self.frequency_penalty = [self.frequency_penalty[i] for i in indices] + self.presence_penalty = [self.presence_penalty[i] for i in indices] + if any([x != 0.0 for x in self.frequency_penalty]) or any([x != 0.0 for x in self.presence_penalty]): + self.frequency_penalty_tensor = self.frequency_penalty_tensor[indices] + self.presence_penalty_tensor = self.presence_penalty_tensor[indices] + return self + return None + + class HeterogeneousTemperatureLogitsWarper(LogitsWarper): r""" [`LogitsWarper`] for temperature (exponential scaling output probability distribution). diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index c3231caea..c08bb1467 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -12,6 +12,7 @@ from lorax_server.pb import generate_pb2 from lorax_server.pb.generate_pb2 import FinishReason from lorax_server.utils.logits_process import ( + HeterogeneousFrequencyPenaltyLogitsProcessor, HeterogeneousProcessorWrapper, HeterogeneousRepetitionPenaltyLogitsProcessor, HeterogeneousSchemaLogitsProcessor, @@ -232,6 +233,8 @@ class HeterogeneousNextTokenChooser: watermark (List[bool]): A list of booleans indicating whether watermark processing should be applied for each token. temperature (List[float]): A list of temperature values for temperature-based logits warping. repetition_penalty (List[float]): A list of repetition penalty values for repetition penalty-based logits warping. + frequency_penalty (List[float]): A list of frequency penalty values for frequency penalty-based logits warping. + presence_penalty (List[float]): A list of presence penalty values for presence penalty-based logits warping. schemas (List[str]): A list of JSON schema strings for Outlines logits warping. top_k (List[int]): A list of top-k values for top-k-based logits warping. top_p (List[float]): A list of top-p values for top-p-based logits warping. @@ -243,6 +246,7 @@ class HeterogeneousNextTokenChooser: Attributes: watermark_processor (HeterogeneousProcessorWrapper): The watermark logits processor. repetition_processor (HeterogeneousRepetitionPenaltyLogitsProcessor): The repetition penalty logits processor. + frequency_processor (HeterogeneousFrequencyPenaltyLogitsProcessor): The frequency penalty logits processor. schema_processor (HeterogeneousSchemaLogitsProcessor): The JSON schema logits processor. warpers (List[HeterogeneousLogitsWarper]): The list of logits warpers. 
choice (HeterogeneousSampling or Greedy): The token choice strategy. @@ -259,6 +263,8 @@ def __init__( watermark: List[bool], temperature: List[float], repetition_penalty: List[float], + frequency_penalty: List[float], + presence_penalty: List[float], schemas: List[str], top_k: List[int], top_p: List[float], @@ -284,6 +290,12 @@ def __init__( else None ) + self.frequency_processor = ( + HeterogeneousFrequencyPenaltyLogitsProcessor(frequency_penalty, presence_penalty, dtype, device) + if any([x != 0.0 for x in frequency_penalty]) or any([x != 0.0 for x in presence_penalty]) + else None + ) + if sequence_processors is not None: # Reuse the state from the previous generation steps self.schema_processor = ( @@ -359,10 +371,12 @@ def __call__( with self.schema_processor.restore_state() if self.schema_processor is not None else nullcontext(): for j in range(S): scores_j = scores[:, j] if self.watermark_processor is not None: scores_j = self.watermark_processor(input_ids, scores_j) if self.repetition_processor is not None: scores_j = self.repetition_processor(input_ids, scores_j) + if self.frequency_processor is not None: + scores_j = self.frequency_processor(input_ids, scores_j) if self.schema_processor is not None: scores_j = self.schema_processor(input_ids, scores_j) @@ -380,7 +398,8 @@ def __call__( self.schema_processor.next_state(batch_idx, next_ids_j[batch_idx].item()) next_ids = next_ids.view(B * S) - scores = scores.view(B * S, -1) + allscores = scores.view(B * S, -1) + alllogprobs = torch.log_softmax(allscores, -1) if speculated_ids is not None: accepted_ids = [] @@ -406,14 +425,15 @@ def __call__( accepted_ids = torch.tensor(accepted_ids, device=input_ids.device, dtype=input_ids.dtype) next_ids = next_ids[indices] - scores = scores[indices] + logprobs = alllogprobs[indices] indices = torch.arange(B, device=input_ids.device) * S if speculative_scores is not None: speculative_scores = speculative_scores[indices + accepted_ids - 1] else: accepted_ids = torch.ones_like(next_ids) + logprobs = alllogprobs - next_logprobs = torch.gather(torch.log_softmax(scores, -1), 1, next_ids.view(-1, 1)).view(-1) + next_logprobs = torch.gather(logprobs, 1, next_ids.view(-1, 1)).view(-1) speculative_ids = None if speculate > 0: @@ -441,6 +461,9 @@ def filter(self, indices): if self.repetition_processor is not None: self.repetition_processor = self.repetition_processor.filter(indices) + if self.frequency_processor is not None: + self.frequency_processor = self.frequency_processor.filter(indices) + if self.schema_processor is not None: self.schema_processor = self.schema_processor.filter(indices) @@ -491,6 +514,8 @@ def from_pb( watermark=[pb_.watermark for pb_ in pb], temperature=[pb_.temperature for pb_ in pb], repetition_penalty=[pb_.repetition_penalty for pb_ in pb], + frequency_penalty=[pb_.frequency_penalty for pb_ in pb], + presence_penalty=[pb_.presence_penalty for pb_ in pb], schemas=[pb_.schema for pb_ in pb], top_k=[pb_.top_k for pb_ in pb], top_p=[pb_.top_p for pb_ in pb],
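
Reviewer note on the penalty math: the HeterogeneousFrequencyPenaltyLogitsProcessor added in logits_process.py counts how often each vocabulary id has appeared per request via scatter_add_, normalizes by the sequence length seen so far, and subtracts the frequency and presence terms from the logits. Below is a minimal, standalone sketch of that arithmetic, not part of the patch; the batch size, token ids, penalty values, and zero logits are made up purely for illustration.

```python
# Standalone illustration of the arithmetic in HeterogeneousFrequencyPenaltyLogitsProcessor.__call__.
# All values below are invented for the example.
import torch

batch_size, vocab_size = 2, 5
input_ids = torch.tensor([[1, 1, 3], [2, 4, 4]])   # tokens seen so far, per request
scores = torch.zeros(batch_size, vocab_size)       # pretend logits
frequency_penalty = torch.tensor([[0.5], [0.0]])   # one value per request, shaped (batch, 1)
presence_penalty = torch.tensor([[0.2], [0.0]])    # like the processor's unsqueezed tensors

# Count occurrences of each vocab id, then normalize by the sequence length,
# mirroring token_freq.scatter_add_(...) and token_freq /= input_size above.
token_freq = torch.zeros(batch_size, vocab_size)
token_freq.scatter_add_(1, input_ids, torch.ones_like(input_ids, dtype=torch.float))
mask = token_freq > 0
token_freq /= input_ids.size(1)

scores -= token_freq * frequency_penalty   # scales with how often a token appeared
scores -= mask * presence_penalty          # flat penalty for any token that appeared at all

print(scores)
# Request 0: id 1 gets -0.5 * 2/3 - 0.2, id 3 gets -0.5 * 1/3 - 0.2; request 1 is untouched.
```

For context, the validation added in router/src/validation.rs above rejects values outside [-2.0, 2.0] for both penalties, and 0.0 (the default) leaves the logits unchanged, matching the OpenAI parameter semantics referenced in the processor docstring.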