Skip to content

Commit

Permalink
Add validation for model_name parameter (#476)
Browse files Browse the repository at this point in the history
This PR introduces a validation check for the `model_name` parameter
within the `openai` method of the `outlines.models` module. The aim is
to ensure that only valid model names are used when calling this
function. As per the current implementation, any string passed as the
`model_name` would be accepted.

The change enforces the function to only accept `gpt-4` or
`gpt-3.5-turbo` as valid model names. If any other model name is passed,
a ValueError is raised, informing the user about the invalid input and
suggesting the correct model names.

This change also adds tests to `tests/models/test_openai.py`. To ensure
an invalid model raises a value error.

Note: I added `gpt-4-1106-preview` in the tests, so when or if it's
added to the suite of models supported by outlines, that will need to be
updated, along with the mode in `openai.py`

Co-authored-by: Stephen Witkowski <[email protected]>
  • Loading branch information
smwitkowski and Stephen Witkowski authored Dec 23, 2023
1 parent 6084f4c commit 7355c4b
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 0 deletions.
5 changes: 5 additions & 0 deletions outlines/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,11 @@ def __init__(
parameters that cannot be set by calling this class' methods.
"""
if model_name not in ["gpt-4", "gpt-3.5-turbo"]:
raise ValueError(
"Invalid model_name. It must be either 'gpt-4' or 'gpt-3.5-turbo'."
)

try:
import openai
except ImportError:
Expand Down
9 changes: 9 additions & 0 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pytest

from outlines.models.openai import (
OpenAI,
build_optimistic_mask,
find_longest_intersection,
find_response_choices_intersection,
Expand Down Expand Up @@ -48,3 +49,11 @@ def test_find_longest_common_prefix(response, choice, expected_prefix):
def test_build_optimistic_mask(transposed, mask_size, expected_mask):
mask = build_optimistic_mask(transposed, mask_size)
assert mask == expected_mask


def test_model_name_validation():
with pytest.raises(ValueError):
OpenAI(model_name="invalid_model_name")

with pytest.raises(ValueError):
OpenAI(model_name="gpt-4-1106-preview")

0 comments on commit 7355c4b

Please sign in to comment.