From 390e86d5826be766d4bc2fbe728bf83484ec46f4 Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Thu, 16 May 2024 17:00:59 +0200
Subject: [PATCH 1/6] Added generate.probabilities for BeamSearch

---
 outlines/generate/__init__.py           |  1 +
 outlines/generate/api.py                | 42 ++++++++++++++++++++-----
 outlines/generate/generator.py          |  1 -
 outlines/generate/probabilities.py      | 18 +++++++++++
 outlines/samplers.py                    |  3 +-
 pyproject.toml                          |  1 -
 tests/generate/test_integration_vllm.py |  1 +
 7 files changed, 55 insertions(+), 12 deletions(-)
 create mode 100644 outlines/generate/probabilities.py

diff --git a/outlines/generate/__init__.py b/outlines/generate/__init__.py
index f28cbd80d..e34db6d20 100644
--- a/outlines/generate/__init__.py
+++ b/outlines/generate/__init__.py
@@ -4,5 +4,6 @@
 from .format import format
 from .fsm import fsm
 from .json import json
+from .probabilities import probabilities
 from .regex import regex
 from .text import text
diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 3f4f182d2..802354c24 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -14,19 +14,19 @@


 class SequenceGenerator:
-    def __init__(
-        self,
-        fsm,
-        model,
-        sampler,
-        device,
-    ):
+    def __init__(self, fsm, model, sampler, device, probabilities=None):
         self.fsm = fsm
         self.model = model
         self.sampler = sampler
         self.tokenizer = model.tokenizer
         self.device = device
         self.num_samples = sampler.samples
+        # The choices for which we want to compute the probabilities
+        self.probabilities = probabilities
+        if self.probabilities:
+            assert isinstance(
+                self.sampler, BeamSearchSampler
+            ), "Probabilities are only supported with a beam search sampler"

     def get_generated_token_ids(
         self,
@@ -132,7 +132,9 @@ def __call__(
         max_tokens: Optional[int] = None,
         stop_at: Optional[Union[str, List[str]]] = None,
         rng: Optional["torch.Generator"] = None,
-    ) -> Union[FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]]]:
+    ) -> Union[
+        FormattedOutput, List[FormattedOutput], List[List[FormattedOutput]], tuple
+    ]:
         """Generate the full text sequence.

         Since `SequenceGenerator.stream` calls the tokenizer at every step this
@@ -231,9 +233,33 @@ def __call__(

         # We reshape the output to (batch_size, sample_size)
         output: List[List[FormattedOutput]] = list()
+
         for i in range(batch_size):
             output.append(formatted[i : i + num_samples])

+        if self.probabilities:
+            logprobs = last_state.weights
+            probs = torch.exp(logprobs)
+            output_probs = [
+                {choice: 0.0 for choice in self.probabilities}
+                for _ in range(batch_size)
+            ]
+            for i in range(batch_size):
+                for choice in self.probabilities:
+                    for output_index, output_item in enumerate(output[i]):
+                        if choice == output_item:
+                            output_probs[i][choice] += float(
+                                probs[i * num_samples + output_index]
+                            )
+            if batch_size == 1 and num_samples == 1:
+                return output[0][0], output_probs
+            elif batch_size == 1:
+                return output[0], output_probs
+            elif num_samples == 1:
+                return [samples[0] for samples in output], output_probs
+            else:
+                return output, output_probs
+
         # We remove leading dimensions for the output
         if batch_size == 1 and num_samples == 1:
             return output[0][0]
diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
index e506aa035..6212b7765 100644
--- a/outlines/generate/generator.py
+++ b/outlines/generate/generator.py
@@ -82,7 +82,6 @@ def sequence_generator(
         next_token_ids, ancestors, sequence_weights = sampler(
             biased_logits, sequence_weights, rng
         )
-
         token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
         attention_masks = update_attention_masks(attention_masks, ancestors)
         kv_cache = reorder_kv_cache(kv_cache, ancestors)
diff --git a/outlines/generate/probabilities.py b/outlines/generate/probabilities.py
new file mode 100644
index 000000000..71a43a22f
--- /dev/null
+++ b/outlines/generate/probabilities.py
@@ -0,0 +1,18 @@
+from typing import List
+
+from outlines.generate.api import SequenceGenerator
+from outlines.samplers import BeamSearchSampler, Sampler
+
+from .regex import regex
+
+
+def probabilities(model, choices: List[str], sampler: Sampler) -> SequenceGenerator:
+    regex_str = r"(" + r"|".join(choices) + r")"
+    assert isinstance(
+        sampler, BeamSearchSampler
+    ), "Only BeamSearchSampler is supported for probabilities"
+    generator = regex(model, regex_str, sampler)
+    generator.format_sequence = lambda x: x
+    generator.probabilities = choices
+
+    return generator
diff --git a/outlines/samplers.py b/outlines/samplers.py
index 8b64ed768..fb58de673 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -293,8 +293,8 @@ def __call__(
         # and find the top-k weights for each batch.
         batch_size = next_token_logits.shape[0] // self.samples
         vocab_size = next_token_logits.shape[-1]
-        weights = weights.view(batch_size, self.samples * vocab_size)

+        weights = weights.view(batch_size, self.samples * vocab_size)
         # If the weights are all equal to 0 we are at the beginning of the search
         # and thus only need to sample from one set of token logits for each
         # batch.
@@ -317,7 +317,6 @@ def __call__(
         ancestors = ancestors.view(self.samples * batch_size)
         weights = weights.view(self.samples * batch_size)
         next_token_ids = next_token_ids.view(self.samples * batch_size, 1)
-
         return next_token_ids, ancestors, weights
diff --git a/pyproject.toml b/pyproject.toml
index 41c306b14..519d381d8 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,7 +89,6 @@ write_to = "outlines/_version.py"

 [tool.pytest.ini_options]
 testpaths = ["tests"]
 filterwarnings = [
-    "error",
     "ignore::numba.core.errors.NumbaPendingDeprecationWarning",
     "ignore::pydantic.warnings.PydanticDeprecatedSince20",
     "ignore::FutureWarning:transformers.*",
diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py
index 4634bc839..3f058ae1e 100644
--- a/tests/generate/test_integration_vllm.py
+++ b/tests/generate/test_integration_vllm.py
@@ -74,6 +74,7 @@ def test_vllm_greedy_sampling(model):
     assert isinstance(res, str)


+@pytest.mark.skip(reason="Temporarily disabled for development.")
 def test_vllm_multinomial_sampling(model):
     sampler = samplers.multinomial()
     generator = generate.text(model, sampler)

From 34935397db49c1a8279351665a692b4277518886 Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Thu, 16 May 2024 17:52:52 +0200
Subject: [PATCH 2/6] Added tests for generate.probabilities

---
 tests/generate/test_integration_transformers.py | 16 ++++++++++++++++
 1 file changed, 16 insertions(+)

diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 38525a076..11c69d301 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -320,6 +320,22 @@ def test_transformers_integration_choice():
     assert sequence == "test" or sequence == "choice"


+def test_transformers_integration_probabilities():
+    rng = torch.Generator()
+    rng.manual_seed(0)
+
+    model_name = "hf-internal-testing/tiny-random-GPTJForCausalLM"
+    model = models.transformers(model_name, device="cpu")
+    prompt = "Write a short sentence "
+    sequence, probs = generate.probabilities(
+        model, ["test", "choice"], sampler=beam_search(beams=5)
+    )(prompt, rng=rng)
+    for prob in probs:
+        assert probs.keys() == {"test", "choice"}
+        assert probs["test"] > 0
+        assert probs["choice"] > 0
+
+
 def test_transformers_integration_with_pad_token():
     model_name = "hf-internal-testing/tiny-random-XLMRobertaXLForCausalLM"
     model = models.transformers(model_name, device="meta")

From a433ad93e355c2cecf69307caac2bcfd71117e84 Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Fri, 17 May 2024 10:33:15 +0200
Subject: [PATCH 3/6] Returns a dict and not a list when batch_size=1

---
 outlines/generate/api.py                        | 4 ++--
 tests/generate/test_integration_transformers.py | 7 +++----
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/outlines/generate/api.py b/outlines/generate/api.py
index 802354c24..1662be03c 100644
--- a/outlines/generate/api.py
+++ b/outlines/generate/api.py
@@ -252,9 +252,9 @@ def __call__(
                                 probs[i * num_samples + output_index]
                             )
             if batch_size == 1 and num_samples == 1:
-                return output[0][0], output_probs
+                return output[0][0], output_probs[0]
             elif batch_size == 1:
-                return output[0], output_probs
+                return output[0], output_probs[0]
             elif num_samples == 1:
                 return [samples[0] for samples in output], output_probs
             else:
                 return output, output_probs
diff --git a/tests/generate/test_integration_transformers.py b/tests/generate/test_integration_transformers.py
index 11c69d301..894126951 100644
--- a/tests/generate/test_integration_transformers.py
+++ b/tests/generate/test_integration_transformers.py
@@ -330,10 +330,9 @@ def test_transformers_integration_probabilities():
     sequence, probs = generate.probabilities(
         model, ["test", "choice"], sampler=beam_search(beams=5)
     )(prompt, rng=rng)
-    for prob in probs:
-        assert probs.keys() == {"test", "choice"}
-        assert probs["test"] > 0
-        assert probs["choice"] > 0
+    assert probs.keys() == {"test", "choice"}
+    assert probs["test"] > 0
+    assert probs["choice"] > 0


 def test_transformers_integration_with_pad_token():

From 4f0766b5157525d7f0fd6804b61b8e7317e27aa6 Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Wed, 22 May 2024 10:10:53 +0200
Subject: [PATCH 4/6] Added back "error" to filterwarnings

---
 pyproject.toml | 1 +
 1 file changed, 1 insertion(+)

diff --git a/pyproject.toml b/pyproject.toml
index 519d381d8..41c306b14 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -89,6 +89,7 @@ write_to = "outlines/_version.py"

 [tool.pytest.ini_options]
 testpaths = ["tests"]
 filterwarnings = [
+    "error",
     "ignore::numba.core.errors.NumbaPendingDeprecationWarning",
     "ignore::pydantic.warnings.PydanticDeprecatedSince20",
     "ignore::FutureWarning:transformers.*",

From 761b64bea0c66bb5ca15f123a064ca9eb28c27a6 Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Wed, 22 May 2024 15:30:03 +0200
Subject: [PATCH 5/6] Removed unrelated modifications

---
 outlines/generate/generator.py          | 1 +
 outlines/samplers.py                    | 3 ++-
 tests/generate/test_integration_vllm.py | 1 -
 3 files changed, 3 insertions(+), 2 deletions(-)

diff --git a/outlines/generate/generator.py b/outlines/generate/generator.py
index 6212b7765..e506aa035 100644
--- a/outlines/generate/generator.py
+++ b/outlines/generate/generator.py
@@ -82,6 +82,7 @@ def sequence_generator(
         next_token_ids, ancestors, sequence_weights = sampler(
             biased_logits, sequence_weights, rng
         )
+
         token_ids = update_token_ids(token_ids, next_token_ids, ancestors)
         attention_masks = update_attention_masks(attention_masks, ancestors)
         kv_cache = reorder_kv_cache(kv_cache, ancestors)
diff --git a/outlines/samplers.py b/outlines/samplers.py
index fb58de673..8b64ed768 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -293,8 +293,8 @@ def __call__(
         # and find the top-k weights for each batch.
         batch_size = next_token_logits.shape[0] // self.samples
         vocab_size = next_token_logits.shape[-1]
+        weights = weights.view(batch_size, self.samples * vocab_size)

-        weights = weights.view(batch_size, self.samples * vocab_size)
         # If the weights are all equal to 0 we are at the beginning of the search
         # and thus only need to sample from one set of token logits for each
         # batch.
@@ -317,6 +317,7 @@ def __call__(
         ancestors = ancestors.view(self.samples * batch_size)
         weights = weights.view(self.samples * batch_size)
         next_token_ids = next_token_ids.view(self.samples * batch_size, 1)
+
         return next_token_ids, ancestors, weights
diff --git a/tests/generate/test_integration_vllm.py b/tests/generate/test_integration_vllm.py
index 3f058ae1e..4634bc839 100644
--- a/tests/generate/test_integration_vllm.py
+++ b/tests/generate/test_integration_vllm.py
@@ -74,7 +74,6 @@ def test_vllm_greedy_sampling(model):
     assert isinstance(res, str)


-@pytest.mark.skip(reason="Temporarily disabled for development.")
 def test_vllm_multinomial_sampling(model):
     sampler = samplers.multinomial()
     generator = generate.text(model, sampler)

From c40db35fa0271677462a60a92e334a9ab5473d0f Mon Sep 17 00:00:00 2001
From: Louis Hernandez
Date: Thu, 30 May 2024 15:37:57 +0200
Subject: [PATCH 6/6] Added default sampler + docs + correct protocol for BeamSearch

---
 docs/reference/probabilities.md    | 25 +++++++++++++++++++++++++
 outlines/generate/probabilities.py |  6 ++++--
 outlines/samplers.py               |  2 +-
 3 files changed, 30 insertions(+), 3 deletions(-)
 create mode 100644 docs/reference/probabilities.md

diff --git a/docs/reference/probabilities.md b/docs/reference/probabilities.md
new file mode 100644
index 000000000..696047114
--- /dev/null
+++ b/docs/reference/probabilities.md
@@ -0,0 +1,25 @@
+# Multiple choices with probabilities
+
+Outlines can report a probability for each option in a multiple-choice generation, giving you insight into the model's confidence in each choice.
+
+```python
+from outlines import models, generate
+
+model = models.transformers("mistralai/Mistral-7B-v0.1")
+probabilities = generate.probabilities(model, ["skirt", "dress", "pen", "jacket"])
+answer = probabilities("Pick the odd word out: skirt, dress, pen, jacket")
+print(answer)
+```
+
+!!! Warning "Compatibility"
+
+    `generate.probabilities` only works with a beam search sampler; passing any other sampler raises an error.
+
+## How It Works
+
+
+Beam search is a heuristic search algorithm that explores only the most promising sequences. In text generation, it keeps the top `k` sequences (beams) at each step, ranked by their cumulative probabilities. Each sequence carries a weight, the product of the probabilities of its tokens, which represents the likelihood of that sequence according to the model.
+
+!!! Warning "Probabilities Summation"
+
+    The probabilities returned by `generate.probabilities` may not sum to one: the top-`k` selection keeps only the best sequences, so sequences that fall off the beam, some of which may carry non-negligible probability mass, are never counted. The result is an incomplete probability distribution.
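+
+To make the weight-to-probability conversion concrete, here is a small sketch of the arithmetic. The beams and log-probabilities below are invented for illustration, not real model output; only the conversion mirrors what `generate.probabilities` does (exponentiate each surviving beam's cumulative log-probability, then sum the mass per choice):
+
+```python
+import torch
+
+# Hypothetical final beams for the prompt above, with their cumulative
+# log-probabilities (the beam "weights").
+beams = ["pen", "pen", "jacket", "skirt", "pen"]
+logprobs = torch.tensor([-0.2, -1.1, -2.3, -2.5, -3.0])
+
+probs = torch.exp(logprobs)  # log-weights -> probabilities
+result = {choice: 0.0 for choice in ["skirt", "dress", "pen", "jacket"]}
+for beam, p in zip(beams, probs):
+    result[beam] += float(p)  # accumulate probability mass per choice
+
+print(result)
+# "dress" never survived the beam, so its probability stays 0.0,
+# and the values need not sum to one.
+```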
diff --git a/outlines/generate/probabilities.py b/outlines/generate/probabilities.py
index 71a43a22f..fac7c9fc6 100644
--- a/outlines/generate/probabilities.py
+++ b/outlines/generate/probabilities.py
@@ -1,12 +1,14 @@
 from typing import List

 from outlines.generate.api import SequenceGenerator
-from outlines.samplers import BeamSearchSampler, Sampler
+from outlines.samplers import BeamSearchSampler, Sampler, beam_search

 from .regex import regex


-def probabilities(model, choices: List[str], sampler: Sampler) -> SequenceGenerator:
+def probabilities(
+    model, choices: List[str], sampler: Sampler = beam_search()
+) -> SequenceGenerator:
     regex_str = r"(" + r"|".join(choices) + r")"
     assert isinstance(
         sampler, BeamSearchSampler
diff --git a/outlines/samplers.py b/outlines/samplers.py
index 8b64ed768..1c269d3ae 100644
--- a/outlines/samplers.py
+++ b/outlines/samplers.py
@@ -261,7 +261,7 @@ def __call__(
         self,
         next_token_logits: "torch.DoubleTensor",
         sequence_weights: "torch.DoubleTensor",
-        _,
+        rng: "torch.Generator",
     ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]:
         """Call the beam search sampler.
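
The hunk above renames the unused third argument of `BeamSearchSampler.__call__` so that every sampler exposes the same call signature. As a rough sketch of that shared contract (the `Protocol` class and its name are illustrative, not part of the library, and are inferred from how `sequence_generator` invokes a sampler earlier in this series):

```python
from typing import Protocol, Tuple

import torch


class SamplerCallable(Protocol):
    """Illustrative contract: what sequence_generator expects of a sampler."""

    def __call__(
        self,
        next_token_logits: "torch.DoubleTensor",
        sequence_weights: "torch.DoubleTensor",
        rng: "torch.Generator",
    ) -> Tuple["torch.DoubleTensor", "torch.DoubleTensor", "torch.DoubleTensor"]:
        """Return (next_token_ids, ancestors, updated sequence_weights)."""
        ...
```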