Skip to content

Commit

Permalink
enh: Expose ignore_eos_token option in generate requests (#340)
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffreyftang authored Mar 19, 2024
1 parent 431ae61 commit 67579cc
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 1 deletion.
8 changes: 8 additions & 0 deletions clients/python/lorax/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ def generate(
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
best_of: Optional[int] = None,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
Expand Down Expand Up @@ -102,6 +103,8 @@ def generate(
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
best_of (`int`):
Generate best_of sequences and return the one with the highest token logprobs
repetition_penalty (`float`):
Expand Down Expand Up @@ -158,6 +161,7 @@ def generate(
details=details,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
ignore_eos_token=ignore_eos_token,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
Expand Down Expand Up @@ -198,6 +202,7 @@ def generate_stream(
api_token: Optional[str] = None,
do_sample: bool = False,
max_new_tokens: int = 20,
ignore_eos_token: bool = False,
repetition_penalty: Optional[float] = None,
return_full_text: bool = False,
seed: Optional[int] = None,
Expand Down Expand Up @@ -229,6 +234,8 @@ def generate_stream(
Activate logits sampling
max_new_tokens (`int`):
Maximum number of generated tokens
ignore_eos_token (`bool`):
Whether to ignore EOS tokens during generation
repetition_penalty (`float`):
The parameter for repetition penalty. 1.0 means no penalty. See [this
paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
Expand Down Expand Up @@ -280,6 +287,7 @@ def generate_stream(
decoder_input_details=False,
do_sample=do_sample,
max_new_tokens=max_new_tokens,
ignore_eos_token=ignore_eos_token,
repetition_penalty=repetition_penalty,
return_full_text=return_full_text,
seed=seed,
Expand Down
2 changes: 2 additions & 0 deletions clients/python/lorax/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -80,6 +80,8 @@ class Parameters(BaseModel):
do_sample: bool = False
# Maximum number of generated tokens
max_new_tokens: int = 20
# Whether to ignore the EOS token during generation
ignore_eos_token: bool = False
# The parameter for repetition penalty. 1.0 means no penalty.
# See [this paper](https://arxiv.org/pdf/1909.05858.pdf) for more details.
repetition_penalty: Optional[float] = None
Expand Down
5 changes: 5 additions & 0 deletions docs/reference/openapi.json
Original file line number Diff line number Diff line change
Expand Up @@ -750,6 +750,11 @@
"exclusiveMaximum": 512.0,
"exclusiveMinimum": 0.0
},
"ignore_eos_token": {
"type": "boolean",
"default": "false",
"example": true
},
"repetition_penalty": {
"type": "number",
"format": "float",
Expand Down
6 changes: 6 additions & 0 deletions router/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -224,6 +224,9 @@ pub(crate) struct GenerateParameters {
#[schema(exclusive_minimum = 0, exclusive_maximum = 512, default = "20")]
pub max_new_tokens: u32,
#[serde(default)]
#[schema(default = "false", example = true)]
pub ignore_eos_token: bool,
#[serde(default)]
#[schema(nullable = true, default = "null", example = false)]
pub return_full_text: Option<bool>,
#[serde(default)]
Expand Down Expand Up @@ -282,6 +285,7 @@ fn default_parameters() -> GenerateParameters {
typical_p: None,
do_sample: false,
max_new_tokens: default_max_new_tokens(),
ignore_eos_token: false,
return_full_text: None,
stop: Vec::new(),
truncate: None,
Expand Down Expand Up @@ -619,6 +623,7 @@ impl From<CompletionRequest> for CompatGenerateRequest {
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
ignore_eos_token: false,
return_full_text: req.echo,
stop: req.stop,
truncate: None,
Expand Down Expand Up @@ -655,6 +660,7 @@ impl From<ChatCompletionRequest> for CompatGenerateRequest {
.max_tokens
.map(|x| x as u32)
.unwrap_or(default_max_new_tokens()),
ignore_eos_token: false,
return_full_text: None,
stop: req.stop,
truncate: None,
Expand Down
3 changes: 2 additions & 1 deletion router/src/validation.rs
Original file line number Diff line number Diff line change
Expand Up @@ -139,6 +139,7 @@ impl Validation {
typical_p,
do_sample,
max_new_tokens,
ignore_eos_token,
stop: stop_sequences,
truncate,
seed,
Expand Down Expand Up @@ -296,7 +297,7 @@ impl Validation {
let stopping_parameters = StoppingCriteriaParameters {
max_new_tokens,
stop_sequences,
ignore_eos_token: false,
ignore_eos_token,
};

metrics::histogram!("lorax_request_max_new_tokens", max_new_tokens as f64);
Expand Down

0 comments on commit 67579cc

Please sign in to comment.