From 67579cc1e27ba86863a1128d8b3f4abd6fa0b94a Mon Sep 17 00:00:00 2001 From: Jeffrey Tang <810895+jeffreyftang@users.noreply.github.com> Date: Tue, 19 Mar 2024 16:09:14 -0500 Subject: [PATCH] enh: Expose ignore_eos_token option in generate requests (#340) --- clients/python/lorax/client.py | 8 ++++++++ clients/python/lorax/types.py | 2 ++ docs/reference/openapi.json | 5 +++++ router/src/lib.rs | 6 ++++++ router/src/validation.rs | 3 ++- 5 files changed, 23 insertions(+), 1 deletion(-) diff --git a/clients/python/lorax/client.py b/clients/python/lorax/client.py index 82e72ef1f..ec9d32e02 100644 --- a/clients/python/lorax/client.py +++ b/clients/python/lorax/client.py @@ -68,6 +68,7 @@ def generate( api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, + ignore_eos_token: bool = False, best_of: Optional[int] = None, repetition_penalty: Optional[float] = None, return_full_text: bool = False, @@ -102,6 +103,8 @@ def generate( Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + ignore_eos_token (`bool`): + Whether to ignore EOS tokens during generation best_of (`int`): Generate best_of sequences and return the one if the highest token logprobs repetition_penalty (`float`): @@ -158,6 +161,7 @@ def generate( details=details, do_sample=do_sample, max_new_tokens=max_new_tokens, + ignore_eos_token=ignore_eos_token, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, @@ -198,6 +202,7 @@ def generate_stream( api_token: Optional[str] = None, do_sample: bool = False, max_new_tokens: int = 20, + ignore_eos_token: bool = False, repetition_penalty: Optional[float] = None, return_full_text: bool = False, seed: Optional[int] = None, @@ -229,6 +234,8 @@ def generate_stream( Activate logits sampling max_new_tokens (`int`): Maximum number of generated tokens + ignore_eos_token (`bool`): + Whether to ignore EOS tokens during generation repetition_penalty (`float`): The parameter for repetition penalty. 1.0 means no penalty. See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. @@ -280,6 +287,7 @@ def generate_stream( decoder_input_details=False, do_sample=do_sample, max_new_tokens=max_new_tokens, + ignore_eos_token=ignore_eos_token, repetition_penalty=repetition_penalty, return_full_text=return_full_text, seed=seed, diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index 32503ff7c..2b9aa32fd 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -80,6 +80,8 @@ class Parameters(BaseModel): do_sample: bool = False # Maximum number of generated tokens max_new_tokens: int = 20 + # Whether to ignore the EOS token during generation + ignore_eos_token: bool = False # The parameter for repetition penalty. 1.0 means no penalty. # See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details. repetition_penalty: Optional[float] = None diff --git a/docs/reference/openapi.json b/docs/reference/openapi.json index 574ad4d8c..51130c4e7 100644 --- a/docs/reference/openapi.json +++ b/docs/reference/openapi.json @@ -750,6 +750,11 @@ "exclusiveMaximum": 512.0, "exclusiveMinimum": 0.0 }, + "ignore_eos_token": { + "type": "boolean", + "default": "false", + "example": true + }, "repetition_penalty": { "type": "number", "format": "float", diff --git a/router/src/lib.rs b/router/src/lib.rs index 61f627f77..3f7ba1650 100644 --- a/router/src/lib.rs +++ b/router/src/lib.rs @@ -224,6 +224,9 @@ pub(crate) struct GenerateParameters { #[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")] pub max_new_tokens: u32, #[serde(default)] + #[schema(default = "false", example = true)] + pub ignore_eos_token: bool, + #[serde(default)] #[schema(nullable = true, default = "null", example = false)] pub return_full_text: Option, #[serde(default)] @@ -282,6 +285,7 @@ fn default_parameters() -> GenerateParameters { typical_p: None, do_sample: false, max_new_tokens: default_max_new_tokens(), + ignore_eos_token: false, return_full_text: None, stop: Vec::new(), truncate: None, @@ -619,6 +623,7 @@ impl From for CompatGenerateRequest { .max_tokens .map(|x| x as u32) .unwrap_or(default_max_new_tokens()), + ignore_eos_token: false, return_full_text: req.echo, stop: req.stop, truncate: None, @@ -655,6 +660,7 @@ impl From for CompatGenerateRequest { .max_tokens .map(|x| x as u32) .unwrap_or(default_max_new_tokens()), + ignore_eos_token: false, return_full_text: None, stop: req.stop, truncate: None, diff --git a/router/src/validation.rs b/router/src/validation.rs index f9d7b6b50..1cf9130b5 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -139,6 +139,7 @@ impl Validation { typical_p, do_sample, max_new_tokens, + ignore_eos_token, stop: stop_sequences, truncate, seed, @@ -296,7 +297,7 @@ impl Validation { let stopping_parameters = StoppingCriteriaParameters { max_new_tokens, stop_sequences, - ignore_eos_token: false, + ignore_eos_token, }; metrics::histogram!("lorax_request_max_new_tokens", max_new_tokens as f64);