
Commit

Fix MultinomialSampler hyperparameter bug
Uses all logit_processors instead of just the last logit_processor
aidansan committed Sep 14, 2024
1 parent 0b9a3f1 commit 50feaa4
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion outlines/samplers.py
@@ -148,7 +148,7 @@ def __call__(
 
         altered_next_token_logits = next_token_logits
         for logit_processor in self.logits_processors:
-            altered_next_token_logits = logit_processor(next_token_logits)
+            altered_next_token_logits = logit_processor(altered_next_token_logits)
 
         probs = torch.nn.functional.softmax(altered_next_token_logits, dim=-1)
         next_token_ids = torch.multinomial(probs, num_samples=1, generator=rng)
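
Below is a minimal sketch, not part of the commit, of why the loop must thread the altered logits through each processor: with the old line, every processor received the original next_token_logits, so only the last processor's output survived the loop. The two processors and the example tensor here are hypothetical and only illustrate the chaining behavior.

import torch

# Hypothetical processors for illustration; not the ones used by outlines.
def temperature_processor(logits):
    # Sharpen the distribution by scaling the logits.
    return logits / 0.5

def ban_token_zero(logits):
    # Mask out token id 0.
    masked = logits.clone()
    masked[..., 0] = float("-inf")
    return masked

logits_processors = [temperature_processor, ban_token_zero]
next_token_logits = torch.tensor([[1.0, 2.0, 3.0]])

# Old behavior: each processor reads the original logits, so the
# temperature scaling is discarded and only the ban survives.
buggy = next_token_logits
for logit_processor in logits_processors:
    buggy = logit_processor(next_token_logits)

# Fixed behavior: each processor receives the previous processor's output.
altered = next_token_logits
for logit_processor in logits_processors:
    altered = logit_processor(altered)

print(buggy)    # tensor([[-inf, 2., 3.]])  -> scaling lost
print(altered)  # tensor([[-inf, 4., 6.]])  -> both processors applied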
