Skip to content

Commit

Permalink
Don't re-use logits processors in SequenceGeneratorAdapter, copy them
Browse files Browse the repository at this point in the history
  • Loading branch information
lapp0 committed Sep 22, 2024
1 parent e07f550 commit a09944c
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
9 changes: 5 additions & 4 deletions outlines/generate/api.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import datetime
from copy import copy
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union

Expand Down Expand Up @@ -503,7 +504,7 @@ def __call__(
completions = self.model.generate(
prompts,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand All @@ -525,7 +526,7 @@ def stream(
return self.model.stream(
prompts,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand Down Expand Up @@ -556,7 +557,7 @@ def __call__( # type: ignore
prompts,
media,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand All @@ -581,7 +582,7 @@ def stream( # type: ignore
prompts,
media,
generation_params,
self.logits_processor,
copy(self.logits_processor),
self.sampling_params,
**model_specific_params,
)
Expand Down
10 changes: 10 additions & 0 deletions tests/generate/test_generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,16 @@ def test_generate_choice(request, model_fixture, sample_choices):
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_choice_twice(request, model_fixture, sample_choices):
model = request.getfixturevalue(model_fixture)
generator = generate.choice(model, sample_choices)
res = generator(**get_inputs(model_fixture))
assert res in sample_choices
res = generator(**get_inputs(model_fixture))
assert res in sample_choices


@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES)
def test_generate_format_bool(request, model_fixture):
model = request.getfixturevalue(model_fixture)
Expand Down

0 comments on commit a09944c

Please sign in to comment.