diff --git a/outlines/models/openai.py b/outlines/models/openai.py index 8809177e9..2e01445cb 100644 --- a/outlines/models/openai.py +++ b/outlines/models/openai.py @@ -104,6 +104,11 @@ def __init__( parameters that cannot be set by calling this class' methods. """ + if model_name not in ["gpt-4", "gpt-3.5-turbo"]: + raise ValueError( + "Invalid model_name. It must be either 'gpt-4' or 'gpt-3.5-turbo'." + ) + try: import openai except ImportError: diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py index c2e885eb1..2801673fe 100644 --- a/tests/models/test_openai.py +++ b/tests/models/test_openai.py @@ -1,6 +1,7 @@ import pytest from outlines.models.openai import ( + OpenAI, build_optimistic_mask, find_longest_intersection, find_response_choices_intersection, @@ -48,3 +49,11 @@ def test_find_longest_common_prefix(response, choice, expected_prefix): def test_build_optimistic_mask(transposed, mask_size, expected_mask): mask = build_optimistic_mask(transposed, mask_size) assert mask == expected_mask + + +def test_model_name_validation(): + with pytest.raises(ValueError): + OpenAI(model_name="invalid_model_name") + + with pytest.raises(ValueError): + OpenAI(model_name="gpt-4-1106-preview")