From 29bd1fe0bb1ebb66dffcca1a60900d4135684c5b Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?R=C3=A9mi=20Louf?=
Date: Mon, 12 Feb 2024 14:11:49 +0100
Subject: [PATCH] Add temperature rescaling to the multinomial sampler

The multinomial sampler now stores a list of logits processors instead
of a single callable and accepts an optional `temperature` keyword.

The processors are applied as a chain: each one receives the output of
the previous one, so that top-k / top-p filtering and temperature
rescaling compose instead of the last processor overriding the others.

`rescale_logits` divides the logits by the temperature. It rejects
non-float and negative values, and points users to the greedy sampler
when the temperature is exactly 0.

---
 outlines/samplers.py   | 40 ++++++++++++++++++++++++++++++++++++----
 tests/test_samplers.py | 38 ++++++++++++++++++++++++++++++++++++++
 2 files changed, 74 insertions(+), 4 deletions(-)

diff --git a/outlines/samplers.py b/outlines/samplers.py
index 187318264..393b16af2 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -96,14 +96,18 @@ def __init__(
         *,
         top_k: Optional[int] = None,
         top_p: Optional[float] = None,
+        temperature: Optional[float] = None,
     ):
         self.samples = samples
-        self.logits_processor = lambda x: x
+        self.logits_processors = []
 
         if top_k is not None:
-            self.logits_processor = keep_top_k_logits(top_k)
+            self.logits_processors.append(keep_top_k_logits(top_k))
         elif top_p is not None:
-            self.logits_processor = keep_top_p_logits(top_p)
+            self.logits_processors.append(keep_top_p_logits(top_p))
+
+        if temperature is not None:
+            self.logits_processors.append(rescale_logits(temperature))
 
     def __call__(
         self,
@@ -132,7 +136,10 @@ def __call__(
         cumulative weights of each sequence of shape ``(n_seqs,)``.
 
         """
-        altered_next_token_logits = self.logits_processor(next_token_logits)
+        altered_next_token_logits = next_token_logits
+        for logit_processor in self.logits_processors:
+            altered_next_token_logits = logit_processor(altered_next_token_logits)
+
         probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
 
@@ -196,6 +203,31 @@ def logits_processor(logits: torch.Tensor) -> torch.Tensor:
     return logits_processor
 
 
+def rescale_logits(temperature: float) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Build a function that rescales the token probabilities exponentially.
+
+    Parameters
+    ----------
+    temperature
+        The value by which we rescale the logits.
+
+    """
+
+    if not isinstance(temperature, float) or temperature < 0.0:
+        raise ValueError(
+            f"`temperature` must be a strictly positive floating point number, got {temperature} instead."
+        )
+    elif temperature == 0.0:
+        raise ValueError(
+            "Please use the greedy sampler instead of setting the temperature to 0."
+        )
+
+    def logits_processor(logits: torch.Tensor) -> torch.Tensor:
+        return logits / temperature
+
+    return logits_processor
+
+
 class BeamSearchSampler:
     """Beam Search sampling algorithm.
 
diff --git a/tests/test_samplers.py b/tests/test_samplers.py
index 3578fcece..88cdb0fbc 100644
--- a/tests/test_samplers.py
+++ b/tests/test_samplers.py
@@ -12,6 +12,7 @@
     keep_top_k_logits,
     keep_top_p_logits,
     multinomial,
+    rescale_logits,
 )
 
 
@@ -72,6 +73,32 @@ def test_multinomial():
     assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))
 
 
+def test_multinomial_init():
+    sampler = MultinomialSampler()
+    assert sampler.logits_processors == []
+
+    sampler = MultinomialSampler(3)
+    assert sampler.logits_processors == []
+
+    sampler = MultinomialSampler(top_k=1)
+    assert len(sampler.logits_processors) == 1
+
+    sampler = MultinomialSampler(top_p=0.95)
+    assert len(sampler.logits_processors) == 1
+
+    sampler = MultinomialSampler(top_k=1, top_p=0.95)
+    assert len(sampler.logits_processors) == 1
+
+    sampler = MultinomialSampler(temperature=1.0)
+    assert len(sampler.logits_processors) == 1
+
+    sampler = MultinomialSampler(top_k=1, temperature=1.0)
+    assert len(sampler.logits_processors) == 2
+
+    sampler = MultinomialSampler(top_p=0.95, temperature=1.0)
+    assert len(sampler.logits_processors) == 2
+
+
 def test_top_k():
     with pytest.raises(ValueError, match="`k` must be a strictly"):
         keep_top_k_logits(-1)
@@ -159,6 +186,17 @@ def test_top_p():
     )
 
 
+def test_rescale():
+    with pytest.raises(ValueError, match="`temperature` must"):
+        rescale_logits(1)
+
+    with pytest.raises(ValueError, match="`temperature` must"):
+        rescale_logits(-0.1)
+
+    with pytest.raises(ValueError, match="Please use the greedy sampler"):
+        rescale_logits(0.0)
+
+
 def test_beam_search():
     # Two beams, single sequence
     sampler = BeamSearchSampler(2)