Fix frequency_penalty and presence_penalty #672

Merged: 8 commits, Nov 11, 2024

Changes from 6 commits
12 changes: 8 additions & 4 deletions proto/generate.proto
Original file line number Diff line number Diff line change
Expand Up @@ -91,14 +91,18 @@ message NextTokenChooserParameters {
uint64 seed = 6;
/// repetition penalty
float repetition_penalty = 7;
/// frequency penalty
float frequency_penalty = 8;
/// presence penalty
float presence_penalty = 9;
/// token watermarking using "A Watermark for Large Language Models"
bool watermark = 8;
bool watermark = 10;
/// adapter to use with lora exchange
string adapter_id = 9;
string adapter_id = 11;
/// JSON schema used for constrained decoding (Outlines)
optional string schema = 10;
optional string schema = 12;
/// returning the k highest probability alternatives
uint32 return_k_alternatives = 11;
uint32 return_k_alternatives = 13;
}

message StoppingCriteriaParameters {
Expand Down
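Aside on semantics: unlike the multiplicative repetition_penalty above, the two new fields follow OpenAI-style additive penalties, which is what the -2.0..2.0 bounds in the router schema further down suggest. frequency_penalty scales with how often a token has already been sampled, while presence_penalty is a flat one-time penalty per distinct token. A minimal sketch of that logit adjustment, assuming those semantics (the server-side sampling kernel is not part of this diff):

    use std::collections::HashMap;

    /// Apply OpenAI-style penalties to raw logits in place.
    /// `counts` maps token id -> number of times it already appeared in the output.
    /// Illustrative sketch only; the real kernel lives outside this diff.
    fn apply_penalties(
        logits: &mut [f32],
        counts: &HashMap<usize, u32>,
        frequency_penalty: f32,
        presence_penalty: f32,
    ) {
        for (&token_id, &count) in counts {
            if count > 0 {
                // frequency penalty grows with each repetition...
                logits[token_id] -= count as f32 * frequency_penalty;
                // ...while presence penalty is applied once per distinct token
                logits[token_id] -= presence_penalty;
            }
        }
    }

    fn main() {
        let mut logits = vec![1.0_f32, 2.0, 3.0];
        let counts = HashMap::from([(2usize, 3u32)]); // token 2 repeated 3 times
        apply_penalties(&mut logits, &counts, 0.1, 0.1);
        assert!((logits[2] - 2.6).abs() < 1e-6); // 3.0 - 3*0.1 - 0.1
        println!("{:?}", logits);
    }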
router/client/src/client.rs (2 changes: 2 additions & 0 deletions)

@@ -154,6 +154,8 @@ impl Client {
             do_sample: false,
             seed: 0,
             repetition_penalty: 1.2,
+            frequency_penalty: 0.5,
+            presence_penalty: 0.5,
             watermark: true,
             adapter_id: "".to_string(),
             schema: None,
router/src/health.rs (2 changes: 2 additions & 0 deletions)

@@ -59,6 +59,8 @@ impl Health {
             do_sample: false,
             seed: 0,
             repetition_penalty: 1.0,
+            frequency_penalty: 0.0,
+            presence_penalty: 0.0,
             watermark: false,
             adapter_id: "".to_string(),
             schema: None,
router/src/infer.rs (11 changes: 7 additions & 4 deletions)

@@ -361,7 +361,7 @@ impl Infer {
         let truncate = request.parameters.truncate;
         let encoding = self
             .validation
-            .tokenize(inputs, truncate)
+            .tokenize(inputs, request.add_special_tokens, truncate)
             .await
             .map_err(|err| {
                 tracing::error!("Error occurred during tokenization. {err}");

@@ -530,7 +530,7 @@ impl Infer {
         let inputs = request.inputs.clone();
         let (tokenized_inputs, input_length) = self
             .validation
-            .validate_input(request.inputs, None, Some(1))
+            .validate_input(request.inputs, true, None, Some(1))
             .await?;

         let valid_request = ValidEmbedRequest {

@@ -631,7 +631,7 @@ impl Infer {
         let inputs = request.inputs.clone();
         let (tokenized_inputs, input_length) = self
             .validation
-            .validate_input(request.inputs, None, Some(1))
+            .validate_input(request.inputs, true, None, Some(1))
             .await?;

         let valid_request = ValidClassifyRequest {

@@ -755,7 +755,10 @@ impl Infer {
         let futures: Vec<_> = request
             .inputs
             .iter()
-            .map(|input| self.validation.validate_input(input.clone(), None, Some(1)))
+            .map(|input| {
+                self.validation
+                    .validate_input(input.clone(), true, None, Some(1))
+            })
             .collect();

         let all_tokenized_inputs = try_join_all(futures).await?;
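All four call sites now thread an add_special_tokens flag into validation and tokenization. The flag presumably maps onto the boolean that the Hugging Face tokenizers crate's encode already takes; a minimal standalone sketch of the distinction (the tokenizer.json path and prompts here are placeholders):

    use tokenizers::Tokenizer;

    fn main() -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Any tokenizer file works here; llama-style models prepend <s> (BOS).
        let tokenizer = Tokenizer::from_file("tokenizer.json")?;

        // Plain /generate input: add special tokens as usual.
        let with_special = tokenizer.encode("Hello world", true)?;

        // Chat-templated input: the template already placed the special
        // tokens, so adding them again would duplicate e.g. BOS.
        let without_special = tokenizer.encode("<s>[INST] Hello world [/INST]", false)?;

        println!(
            "{} vs {} tokens",
            with_special.get_ids().len(),
            without_special.get_ids().len()
        );
        Ok(())
    }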
router/src/lib.rs (184 changes: 183 additions & 1 deletion)

@@ -17,10 +17,11 @@ use lorax_client::{AdapterParameters as AdapterParametersMessage, Entity as Enti
 use lorax_client::{MajoritySignMethod, MergeStrategy};

 use batch::Entry;
-use infer::Infer;
+use infer::{Infer, InferError};
 use loader::AdapterLoader;
 use serde::{Deserialize, Serialize};
 use serde_json::json;
+use server::prepare_chat_input;
 use utoipa::ToSchema;
 use validation::Validation;

@@ -204,6 +205,27 @@ pub(crate) struct GenerateParameters {
         example = 1.03
     )]
     pub repetition_penalty: Option<f32>,
+
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        exclusive_maximum = 2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub frequency_penalty: Option<f32>,
+
+    #[serde(default)]
+    #[schema(
+        exclusive_minimum = -2.0,
+        exclusive_maximum = 2.0,
+        nullable = true,
+        default = "null",
+        example = 0.1
+    )]
+    pub presence_penalty: Option<f32>,
+
     #[serde(default)]
     #[schema(exclusive_minimum = 0, nullable = true, default = "null", example = 10)]
     pub top_k: Option<i32>,
Expand Down Expand Up @@ -285,6 +307,8 @@ fn default_parameters() -> GenerateParameters {
best_of: None,
temperature: None,
repetition_penalty: None,
frequency_penalty: None,
presence_penalty: None,
top_k: None,
top_p: None,
typical_p: None,
@@ -310,6 +334,16 @@ pub(crate) struct GenerateRequest {
     pub inputs: String,
     #[serde(default = "default_parameters")]
     pub parameters: GenerateParameters,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true", skip)]
+    pub add_special_tokens: bool,
 }
+
+fn default_true() -> bool {
+    true
+}
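The #[serde(default = "default_true", skip)] combination means add_special_tokens is invisible to API clients: serde neither reads nor writes it, so deserialization always yields true, and only internal code paths (like the chat route below) set it to false. A small self-contained sketch of that behavior, mirroring the attributes above:

    use serde::Deserialize;

    fn default_true() -> bool {
        true
    }

    #[derive(Debug, Deserialize)]
    struct Request {
        inputs: String,
        // Same attributes as GenerateRequest::add_special_tokens:
        // skipped during deserialization, so clients cannot set it.
        #[serde(default = "default_true", skip)]
        add_special_tokens: bool,
    }

    fn main() {
        // Even if a client tries to send the field, it is ignored.
        let req: Request =
            serde_json::from_str(r#"{"inputs": "hi", "add_special_tokens": false}"#).unwrap();
        assert!(req.add_special_tokens); // always defaults to true
        println!("{:?}", req);
    }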

@@ -321,13 +355,20 @@
 #[derive(Clone, Debug, Deserialize, ToSchema)]
 pub(crate) struct CompatGenerateRequest {
     #[serde(default)]
     #[schema(default = "false")]
     pub stream: bool,
+
+    /// This is used internally because some requests
+    /// already contain the templated input therefore
+    /// we shouldn't add the special tokens.
+    #[serde(default = "default_true", skip)]
+    pub add_special_tokens: bool,
 }

 impl From<CompatGenerateRequest> for GenerateRequest {
     fn from(req: CompatGenerateRequest) -> Self {
         Self {
             inputs: req.inputs,
             parameters: req.parameters,
+            add_special_tokens: req.add_special_tokens,
         }
     }
 }
@@ -665,6 +706,144 @@ struct ChatCompletionRequest {
     pub guideline: Option<String>,
 }

+impl ChatCompletionRequest {
+    fn try_into_generate(self, infer: &Infer) -> Result<(CompatGenerateRequest, bool), InferError> {
+        let ChatCompletionRequest {
+            model,
+            max_tokens,
+            messages,
+            seed,
+            stop,
+            stream,
+            tools,
+            tool_choice,
+            tool_prompt,
+            temperature,
+            response_format,
+            guideline,
+            repetition_penalty,
+            presence_penalty,
+            frequency_penalty,
+            top_p,
+            top_k,
+            n,
+            adapter_source,
+            api_token,
+            ignore_eos_token,
+            ..
+        } = self;
+
+        let mut adapter_id = Some(model.clone());
+        if model == "" {
+            adapter_id = None;
+        }
+
+        // Modify input values to ResponseFormat to be OpenAI API compatible
+        let response_format: Option<ResponseFormat> = match response_format {
+            None => None,
+            Some(openai_format) => {
+                let response_format_type = openai_format.response_format_type.clone();
+                match response_format_type {
+                    // Ignore when type is text
+                    ResponseFormatType::Text => None,
+
+                    // For json_object, use the fixed schema.
+                    // For backwards compatibility, also support non-standard `schema` field
+                    ResponseFormatType::JsonObject => openai_format.schema.map_or_else(
+                        || {
+                            Some(ResponseFormat {
+                                r#type: response_format_type.clone(),
+                                schema: default_json_schema(),
+                            })
+                        },
+                        |schema_value: serde_json::Value| {
+                            Some(ResponseFormat {
+                                r#type: response_format_type.clone(),
+                                schema: Some(schema_value),
+                            })
+                        },
+                    ),
+
+                    // For json_schema, use schema_value if available, otherwise fallback to the fixed schema
+                    ResponseFormatType::JsonSchema => openai_format
+                        .json_schema
+                        .and_then(|schema| schema.schema)
+                        .map_or_else(
+                            || {
+                                Some(ResponseFormat {
+                                    r#type: response_format_type.clone(),
+                                    schema: default_json_schema(),
+                                })
+                            },
+                            |schema_value: serde_json::Value| {
+                                Some(ResponseFormat {
+                                    r#type: response_format_type.clone(),
+                                    schema: Some(schema_value),
+                                })
+                            },
+                        ),
+                }
+            }
+        };
+
+        let tool_prompt = tool_prompt
+            .filter(|s| !s.is_empty())
+            .unwrap_or_else(default_tool_prompt);
+
+        // enable greedy only when temperature is 0
+        let (do_sample, temperature) = match temperature {
+            Some(temperature) if temperature == 0.0 => (false, None),
+            other => (true, other),
+        };
+
+        let (inputs, response_format, using_tools) = prepare_chat_input(
+            &infer,
+            response_format,
+            tools,
+            tool_choice,
+            &tool_prompt,
+            guideline,
+            messages,
+        )?;
+
+        Ok((
+            CompatGenerateRequest {
+                inputs: inputs.to_string(),
+                add_special_tokens: false,
+                parameters: GenerateParameters {
+                    adapter_id,
+                    adapter_source,
+                    adapter_parameters: None,
+                    api_token,
+                    best_of: n.map(|x| x as usize),
+                    temperature,
+                    repetition_penalty,
+                    frequency_penalty,
+                    presence_penalty,
+                    top_k,
+                    top_p,
+                    typical_p: None,
+                    do_sample,
+                    max_new_tokens: max_tokens.map(|x| x as u32),
+                    return_full_text: None,
+                    stop,
+                    truncate: None,
+                    watermark: false,
+                    details: true,
+                    decoder_input_details: false,
+                    seed,
+                    ignore_eos_token: ignore_eos_token.unwrap_or(false),
+                    return_k_alternatives: None,
+                    apply_chat_template: false,
+                    response_format,
+                },
+                stream: stream.unwrap_or(false),
+            },
+            using_tools,
+        ))
+    }
+}
+
 pub fn default_tool_prompt() -> String {
     "\nGiven the functions available, please respond with a JSON for a function call with its proper arguments that best answers the given prompt. Respond in the format {name: function name, parameters: dictionary of argument name and its value}.Do not use variables.\n".to_string()
 }
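A hypothetical sketch of how a chat completions handler might consume try_into_generate; the actual route wiring lives in router/src/server.rs, which is not part of this hunk, so the handler name and shape here are assumptions:

    // Hypothetical, crate-internal sketch; not a standalone program.
    async fn chat_completions(
        infer: &Infer,
        req: ChatCompletionRequest,
    ) -> Result<(), InferError> {
        // `using_tools` tells the response path whether to parse the model
        // output back into OpenAI-style tool calls.
        let (compat_req, using_tools) = req.try_into_generate(infer)?;

        // Chat inputs are already templated, so special tokens are not re-added.
        assert!(!compat_req.add_special_tokens);

        let generate_req: GenerateRequest = compat_req.into();
        // ... hand generate_req to the existing generate pipeline ...
        let _ = (generate_req, using_tools);
        Ok(())
    }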
@@ -1001,6 +1180,7 @@ impl From<CompletionRequest> for CompatGenerateRequest {
     fn from(req: CompletionRequest) -> Self {
         CompatGenerateRequest {
             inputs: req.prompt,
+            add_special_tokens: true,
             parameters: GenerateParameters {
                 adapter_id: req.model.parse().ok(),
                 adapter_source: req.adapter_source,
@@ -1009,6 +1189,8 @@ impl From<CompletionRequest> for CompatGenerateRequest {
                 best_of: req.best_of.map(|x| x as usize),
                 temperature: req.temperature,
                 repetition_penalty: req.repetition_penalty,
+                frequency_penalty: req.frequency_penalty,
+                presence_penalty: req.presence_penalty,
                 top_k: req.top_k,
                 top_p: req.top_p,
                 typical_p: None,