From a09944c2dfd0c077b74f524f71d66039bab1284b Mon Sep 17 00:00:00 2001 From: Andrew Lapp Date: Tue, 17 Sep 2024 16:50:08 -0400 Subject: [PATCH] Don't re-use logits processors in SequenceGeneratorAdapter, copy them --- outlines/generate/api.py | 9 +++++---- tests/generate/test_generate.py | 10 ++++++++++ 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/outlines/generate/api.py b/outlines/generate/api.py index ad01377c0..4919f2090 100644 --- a/outlines/generate/api.py +++ b/outlines/generate/api.py @@ -1,4 +1,5 @@ import datetime +from copy import copy from dataclasses import dataclass from typing import TYPE_CHECKING, Any, Iterator, List, Optional, Union @@ -503,7 +504,7 @@ def __call__( completions = self.model.generate( prompts, generation_params, - self.logits_processor, + copy(self.logits_processor), self.sampling_params, **model_specific_params, ) @@ -525,7 +526,7 @@ def stream( return self.model.stream( prompts, generation_params, - self.logits_processor, + copy(self.logits_processor), self.sampling_params, **model_specific_params, ) @@ -556,7 +557,7 @@ def __call__( # type: ignore prompts, media, generation_params, - self.logits_processor, + copy(self.logits_processor), self.sampling_params, **model_specific_params, ) @@ -581,7 +582,7 @@ def stream( # type: ignore prompts, media, generation_params, - self.logits_processor, + copy(self.logits_processor), self.sampling_params, **model_specific_params, ) diff --git a/tests/generate/test_generate.py b/tests/generate/test_generate.py index ff247b0f4..a96ce8673 100644 --- a/tests/generate/test_generate.py +++ b/tests/generate/test_generate.py @@ -253,6 +253,16 @@ def test_generate_choice(request, model_fixture, sample_choices): assert res in sample_choices +@pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) +def test_generate_choice_twice(request, model_fixture, sample_choices): + model = request.getfixturevalue(model_fixture) + generator = generate.choice(model, sample_choices) + res = generator(**get_inputs(model_fixture)) + assert res in sample_choices + res = generator(**get_inputs(model_fixture)) + assert res in sample_choices + + @pytest.mark.parametrize("model_fixture", ALL_MODEL_FIXTURES) def test_generate_format_bool(request, model_fixture): model = request.getfixturevalue(model_fixture)