Add top-k processor to multinomial sampler
rlouf committed Feb 12, 2024
1 parent e36065c commit 441e533
Showing 3 changed files with 69 additions and 5 deletions.
7 changes: 7 additions & 0 deletions docs/reference/samplers.md
@@ -52,6 +52,13 @@ print(answer)
 # [[4, 4, 4], [6, 6, 6]]
 ```
 
+### Top-k sampling
+
+You can ask Outlines to only consider the top-k logits at each step by passing the `top_k` keyword argument when initializing the sampler:
+
+```python
+sampler = samplers.multinomial(3, top_k=10)
+```
 
 ## Greedy sampler
 
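For intuition, here is what the masking in this commit does to a raw logits vector: everything below the k-th largest value is sent to `-inf`, so it receives zero probability after the softmax. A minimal standalone sketch in plain PyTorch, with toy values, independent of Outlines:

```python
import math

import torch

logits = torch.tensor([[2.0, 0.5, -1.0, 3.0, 0.0]])

# The cutoff is the k-th largest logit (k=2 here)
cutoff = torch.topk(logits, 2)[0][..., -1, None]
masked = logits.masked_fill(logits < cutoff, -math.inf)

print(masked)
# tensor([[2., -inf, -inf, 3., -inf]])
print(torch.softmax(masked, dim=-1))
# tensor([[0.2689, 0.0000, 0.0000, 0.7311, 0.0000]])
```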
38 changes: 33 additions & 5 deletions outlines/samplers.py
@@ -1,4 +1,5 @@
-from typing import Protocol, Tuple
+import math
+from typing import Callable, Optional, Protocol, Tuple
 
 import torch
@@ -89,9 +90,13 @@ class MultinomialSampler:
     """
 
-    def __init__(self, samples: int = 1):
+    def __init__(self, samples: int = 1, *, top_k: Optional[int] = None):
         self.samples = samples
 
+        self.logits_processor = lambda x: x
+        if top_k is not None:
+            self.logits_processor = keep_top_k_logits(top_k)
+
     def __call__(
         self,
         next_token_logits: torch.DoubleTensor,
@@ -119,11 +124,12 @@ def __call__(
         cumulative weights of each sequence of shape ``(n_seqs,)``.
 
         """
-        probs = torch.nn.functional.softmax(next_token_logits, dim=-1)
+        altered_next_token_logits = self.logits_processor(next_token_logits)
+        probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
 
-        logprobs = torch.nn.functional.log_softmax(next_token_logits, dim=-1)
-        ancestors = torch.arange(next_token_logits.shape[0])
+        logprobs = torch.nn.functional.log_softmax(altered_next_token_logits, dim=-1)
+        ancestors = torch.arange(altered_next_token_logits.shape[0])
         weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
 
         return next_token_ids, ancestors, weights
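The `weights` line above is the sequence-weight bookkeeping: each step adds the log-probability of the sampled token, computed on the masked logits, to the running total for its sequence. A small numeric sketch of that update, using the same toy logits as the tests below:

```python
import torch

logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
logprobs = torch.nn.functional.log_softmax(logits, dim=-1)

sequence_weights = torch.tensor([0.0])  # running log-probability, one sequence
next_token_ids = torch.tensor([[3]])    # suppose token 3 was drawn

weights = sequence_weights + torch.gather(logprobs, 1, next_token_ids).squeeze()
print(weights)  # tensor([-0.4402]), i.e. log P(token 3)
```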
@@ -132,6 +138,28 @@ def __call__(
 multinomial = MultinomialSampler
 
 
+def keep_top_k_logits(k: int) -> Callable[[torch.Tensor], torch.Tensor]:
+    """Build a function that masks logit values smaller than the top `k` ones.
+
+    Parameters
+    ----------
+    k
+        The ranking below which logit values are replaced by `-math.inf`.
+
+    """
+    if not isinstance(k, int) or k < 1:
+        raise ValueError(
+            f"`top_k` must be a strictly positive integer, got {k} instead."
+        )
+
+    def logits_processor(logits: torch.Tensor) -> torch.Tensor:
+        num_to_keep = min(k, logits.size(-1))
+        mask_idx = logits < torch.topk(logits, num_to_keep)[0][..., -1, None]
+        return logits.masked_fill(mask_idx, -math.inf)
+
+    return logits_processor
+
+
 class BeamSearchSampler:
     """Beam Search sampling algorithm.
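End to end, the new keyword is used as below. This is a minimal sketch, assuming the `__call__` signature suggested by the diff above (`next_token_logits`, then `sequence_weights`, then `rng`):

```python
import torch

from outlines.samplers import multinomial

sampler = multinomial(top_k=2)  # multinomial is an alias for MultinomialSampler

rng = torch.Generator()
rng.manual_seed(0)

next_token_logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
sequence_weights = torch.tensor([0.0])

next_token_ids, ancestors, weights = sampler(next_token_logits, sequence_weights, rng)
# Only tokens 2 and 3 can ever be drawn; the other logits were masked to -inf
```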
29 changes: 29 additions & 0 deletions tests/test_samplers.py
@@ -1,5 +1,6 @@
 import math
 
+import pytest
 import torch
 
 from outlines.samplers import (
@@ -8,6 +9,7 @@
     MultinomialSampler,
     beam_search,
     greedy,
+    keep_top_k_logits,
     multinomial,
 )
 
@@ -69,6 +71,33 @@ def test_multinomial():
     assert weights.equal(torch.tensor([logprobs[0, 0], logprobs[1, 2]]))
 
 
+def test_topk():
+    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
+
+    logits_processor = keep_top_k_logits(1)
+    result = logits_processor(logits)
+    assert result.equal(torch.tensor([[-math.inf, -math.inf, -math.inf, 4.0]]))
+
+    logits_processor = keep_top_k_logits(10)
+    result = logits_processor(logits)
+    assert result.equal(torch.tensor([[1.0, 2.0, 3.0, 4.0]]))
+
+    with pytest.raises(ValueError, match="`top_k` must be a strictly"):
+        keep_top_k_logits(-1)
+
+    with pytest.raises(ValueError, match="`top_k` must be a strictly"):
+        keep_top_k_logits(0.1)
+
+    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0], [5.0, 6.0, 7.0, 8.0]])
+    logits_processor = keep_top_k_logits(2)
+    result = logits_processor(logits)
+    assert result.equal(
+        torch.tensor(
+            [[-math.inf, -math.inf, 3.0, 4.0], [-math.inf, -math.inf, 7.0, 8.0]]
+        )
+    )
+
+
 def test_beam_search():
     # Two beams, single sequence
     sampler = BeamSearchSampler(2)
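One subtlety the tests do not cover: the comparison in `keep_top_k_logits` is strict (`<`), so when several logits tie with the k-th largest value, all of them survive the mask and more than `k` tokens can remain candidates. A quick standalone check of that behavior, not part of this commit:

```python
import math

import torch

logits = torch.tensor([[3.0, 3.0, 3.0, 1.0]])
cutoff = torch.topk(logits, 2)[0][..., -1, None]  # tensor([[3.]])
masked = logits.masked_fill(logits < cutoff, -math.inf)
print(masked)  # tensor([[3., 3., 3., -inf]]); three survivors for k=2
```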
