From c69c469cb2a18221119de3febe6386e20b28dbec Mon Sep 17 00:00:00 2001 From: Travis Addair Date: Mon, 13 May 2024 21:47:47 -0700 Subject: [PATCH] Allow setting temperature=0 --- clients/python/lorax/types.py | 4 +- router/src/validation.rs | 4 +- server/lorax_server/utils/logits_process.py | 2 +- server/lorax_server/utils/tokens.py | 20 +++++- server/tests/conftest.py | 44 ++++++++++++ server/tests/utils/test_tokens.py | 79 +++++++++++++++++++++ 6 files changed, 145 insertions(+), 8 deletions(-) diff --git a/clients/python/lorax/types.py b/clients/python/lorax/types.py index fadfc0823..bd6f5ff49 100644 --- a/clients/python/lorax/types.py +++ b/clients/python/lorax/types.py @@ -163,8 +163,8 @@ def valid_seed(cls, v): @field_validator("temperature") def valid_temp(cls, v): - if v is not None and v <= 0: - raise ValidationError("`temperature` must be strictly positive") + if v is not None and v < 0: + raise ValidationError("`temperature` must be non-negative") return v @field_validator("top_k") diff --git a/router/src/validation.rs b/router/src/validation.rs index 8264ea0a7..33514b84e 100644 --- a/router/src/validation.rs +++ b/router/src/validation.rs @@ -203,7 +203,7 @@ impl Validation { } let temperature = temperature.unwrap_or(1.0); - if temperature <= 0.0 { + if temperature < 0.0 { return Err(ValidationError::Temperature); } @@ -422,7 +422,7 @@ pub enum ValidationError { BestOfStream, #[error("`decoder_input_details` == true is not supported when streaming tokens")] PrefillDetailsStream, - #[error("`temperature` must be strictly positive")] + #[error("`temperature` must be non-negative")] Temperature, #[error("`repetition_penalty` must be strictly positive")] RepetitionPenalty, diff --git a/server/lorax_server/utils/logits_process.py b/server/lorax_server/utils/logits_process.py index 0ee178d5e..6091c94d5 100644 --- a/server/lorax_server/utils/logits_process.py +++ b/server/lorax_server/utils/logits_process.py @@ -35,7 +35,7 @@ def __init__( ): self.warpers = [] - if temperature is not None and temperature != 1.0: + if temperature is not None and temperature != 1.0 and temperature != 0: temperature = float(temperature) self.warpers.append(TemperatureLogitsWarper(temperature)) if top_k is not None and top_k != 0: diff --git a/server/lorax_server/utils/tokens.py b/server/lorax_server/utils/tokens.py index 160fff5d0..199c1f107 100644 --- a/server/lorax_server/utils/tokens.py +++ b/server/lorax_server/utils/tokens.py @@ -3,6 +3,7 @@ from typing import List, Optional, Set, Tuple, Union import torch +import warnings from transformers import ( PreTrainedTokenizerBase, RepetitionPenaltyLogitsProcessor, @@ -67,8 +68,10 @@ def __init__( self.schema_processor = OutlinesLogitsProcessor(schema, tokenizer) if schema and tokenizer else None + # Temperature = 1 does not change logits; do not use warper + # Temperature = 0 invokes determinstic token choosing; do not warp has_warpers = ( - (temperature is not None and temperature != 1.0) + (temperature is not None and temperature != 1.0 and temperature != 0) or (top_k is not None and top_k != 0) or (top_p is not None and top_p < 1.0) or (typical_p is not None and typical_p < 1.0) @@ -79,6 +82,13 @@ def __init__( self.static_warper = None sampling = do_sample or has_warpers + + # do not sample if temperature is 0, even if do_sample flag is set True + # warn user about deterministic sampling + if sampling and temperature == 0: + sampling = False + warnings.warn("Temperature is set to 0, token sampling will be disabled") + self.choice = Sampling(seed, device) if sampling else Greedy() def __call__(self, input_ids, scores): @@ -283,8 +293,10 @@ def __init__( HeterogeneousSchemaLogitsProcessor.from_schemas(schemas, tokenizers) if any(schemas) else None ) - if any([x != 1.0 for x in temperature]): - do_sample = [sample or x != 1.0 for x, sample in zip(temperature, do_sample)] + if any([(x != 1.0 and x != 0) for x in temperature]): + # set sample flags for each index + # do not sample this index if temperature is 0 or 1 + do_sample = [sample or (x != 1.0 and x != 0) for x, sample in zip(temperature, do_sample)] warpers.append(HeterogeneousTemperatureLogitsWarper(temperature, dtype, device)) if any([x != 0 for x in top_k]): @@ -302,8 +314,10 @@ def __init__( self.warpers = warpers if any(do_sample): + # sample tokens from distribution if any sample flags are set True self.choice = HeterogeneousSampling(do_sample, seeds, device) else: + # sampling for all requests is set false, do Greedy / deterministic sampling self.choice = Greedy() self.seeds = seeds diff --git a/server/tests/conftest.py b/server/tests/conftest.py index 4173ade2f..6d5464a20 100644 --- a/server/tests/conftest.py +++ b/server/tests/conftest.py @@ -1,6 +1,10 @@ import pytest +import torch +from transformers import AutoTokenizer +from lorax_server.models.causal_lm import CausalLM, CausalLMBatch from lorax_server.pb import generate_pb2 +from lorax_server.utils.tokenizer import TokenizerManager @pytest.fixture @@ -53,3 +57,43 @@ def schema_constrained_pb_parameters(default_json_schema): do_sample=False, schema=default_json_schema, ) + + +@pytest.fixture(scope="session") +def default_causal_lm(): + return CausalLM("gpt2") + + +@pytest.fixture(scope="session") +def gpt2_tokenizer(): + tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="left") + tokenizer.pad_token_id = 50256 + return tokenizer + + +@pytest.fixture +def default_pb_request(default_pb_parameters, default_pb_stop_parameters): + return generate_pb2.Request( + id=0, + inputs="Test", + prefill_logprobs=True, + truncate=100, + parameters=default_pb_parameters, + stopping_parameters=default_pb_stop_parameters, + ) + + +@pytest.fixture +def default_pb_batch(default_pb_request): + return generate_pb2.Batch(id=0, requests=[default_pb_request], size=1) + + +@pytest.fixture +def default_causal_lm_batch(default_pb_batch, gpt2_tokenizer): + return CausalLMBatch.from_pb( + default_pb_batch, + gpt2_tokenizer, + TokenizerManager(), + torch.float32, + torch.device("cpu"), + ) diff --git a/server/tests/utils/test_tokens.py b/server/tests/utils/test_tokens.py index e68956ffe..c0c901f93 100644 --- a/server/tests/utils/test_tokens.py +++ b/server/tests/utils/test_tokens.py @@ -1,4 +1,7 @@ +from lorax_server.adapters.weights import AdapterBatchData +from lorax_server.models.causal_lm import CausalLMBatch from lorax_server.utils.tokens import ( + NextTokenChooser, StopSequenceCriteria, StoppingCriteria, FinishReason, @@ -42,3 +45,79 @@ def test_stopping_criteria_max(): assert criteria(1, "") == (False, None) assert criteria(1, "") == (False, None) assert criteria(1, "") == (True, FinishReason.FINISH_REASON_LENGTH) + + +# check generations work normally with temperature = 0 +def test_generate_token_temperature_zero(default_causal_lm, default_causal_lm_batch): + sequence_length = len(default_causal_lm_batch.all_input_ids[0]) + batch = default_causal_lm_batch + + # set all token choosers in batch to be deterministic with Temperature = 0 + determ_token_choosers = [NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))] + batch.next_token_choosers = determ_token_choosers + # generate tokens from next batch + generations, next_batch = default_causal_lm.generate_token(default_causal_lm_batch) + + # same assertions as testing generate token, causal lm + assert len(generations) == len(next_batch) + assert isinstance(next_batch, CausalLMBatch) + + assert len(next_batch.all_input_ids) == len(next_batch) + assert len(next_batch.all_input_ids[0]) == sequence_length + 1 + + +# generates tokens with determinstic choosers, +# checks that output tokens have highest probability in distribution +def test_deterministic_tokens_temperature_zero(default_causal_lm, default_causal_lm_batch): + # Inside of CausalLM.generate_token, used to access + # logit distribution and compare log prob + batch = default_causal_lm_batch + + # set all token choosers in batch to be deterministic with Temperature = 0 + determ_token_choosers = [NextTokenChooser(temperature=0) for _ in range(len(batch.next_token_choosers))] + batch.next_token_choosers = determ_token_choosers + + attention_mask = batch.attention_mask[:, : -batch.padding_right_offset] + + adapter_data = AdapterBatchData.from_meta(batch.adapter_meta, default_causal_lm.batched_lora_weights) + + logits, _ = default_causal_lm.forward( + batch.input_ids, + attention_mask, + batch.position_ids, + batch.past_key_values, + adapter_data, + ) + + # Zipped iterator + iterator = zip( + batch.requests, + batch.input_lengths, + batch.prefix_offsets, + batch.read_offsets, + logits, + batch.next_token_choosers, + batch.stopping_criterias, + batch.all_input_ids, + ) + + # For each member of the batch + for i, ( + request, + input_length, + prefix_offset, + read_offset, + logits, + next_token_chooser, + stopping_criteria, + all_input_ids, + ) in enumerate(iterator): + # Select next token + next_token_id, logprobs = next_token_chooser(all_input_ids.view(1, -1), logits[-1:, :]) + + # Generated token + next_token_logprob = logprobs[-1, next_token_id] + + # A deterministic model with Temperature = 0 should always choose + # the highest logprob token + assert next_token_logprob == max(logprobs[-1])