Skip to content

Commit

Permalink
Update the OpenAI model tests
Browse files Browse the repository at this point in the history
  • Loading branch information
RobinPicard committed Feb 5, 2024
1 parent d630720 commit 2d8f284
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 1 deletion.
3 changes: 2 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,7 +57,8 @@ test = [
"datasets",
"responses",
"llama-cpp-python",
"huggingface_hub"
"huggingface_hub",
"openai>=1.0.0"
]
serve = [
"vllm>=0.3.0",
Expand Down
63 changes: 63 additions & 0 deletions tests/models/test_openai.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,75 @@
from unittest.mock import MagicMock, patch

import pytest
from openai import AsyncAzureOpenAI, AsyncOpenAI

from outlines.models.openai import (
OpenAI,
OpenAIConfig,
build_optimistic_mask,
find_longest_intersection,
find_response_choices_intersection,
)


def test_openai_init():
with patch.object(OpenAI, "check_model_validity") as mocked_check_model_validity:
mocked_check_model_validity.return_value = None

async_client = MagicMock(spec=AsyncOpenAI, api_key="key")

model = OpenAI(async_client, "gpt-4")
assert isinstance(model.config, OpenAIConfig)

model = OpenAI(async_client, "gpt-4", config=OpenAIConfig(n=2))
assert model.config.n == 2

model = OpenAI(async_client, "foo", encoding="gpt-4")
assert model.config.model == "foo"
assert model.model_name == "gpt-4"

async_azure_client = MagicMock(spec=AsyncAzureOpenAI, api_key="key")
model = OpenAI(async_azure_client, "gpt-4", deployment_name="foo")
assert model.config.model == "foo"
assert model.model_name == "gpt-4"

wrong_client = object
with pytest.raises(ValueError):
model = OpenAI(wrong_client, "gpt-4")


def test_openai_call():
with patch("outlines.models.openai.generate_chat") as mocked_generate_chat:
mocked_generate_chat.return_value = ["foo"], 1, 2
async_client = MagicMock(spec=AsyncOpenAI, api_key="key")

model = OpenAI(async_client, "text-curie-001")
with pytest.raises(NotImplementedError):
model("bar")

model = OpenAI(
async_client,
"gpt-4",
OpenAIConfig(max_tokens=10, temperature=0.5, n=2, stop=["."]),
)

assert model("bar")[0] == "foo"
assert model.prompt_tokens == 1
assert model.completion_tokens == 2
mocked_generate_chat_args = mocked_generate_chat.call_args
mocked_generate_chat_arg_config = mocked_generate_chat_args[3]
assert isinstance(mocked_generate_chat_arg_config, OpenAIConfig)
assert mocked_generate_chat_arg_config.max_tokens == 10
assert mocked_generate_chat_arg_config.temperature == 0.5
assert mocked_generate_chat_arg_config.n == 2
assert mocked_generate_chat_arg_config.stop == ["."]

model("bar", samples=3)
mocked_generate_chat_args = mocked_generate_chat.call_args
mocked_generate_chat_arg_config = mocked_generate_chat_args[3]
assert mocked_generate_chat_arg_config.n == 3


@pytest.mark.parametrize(
"response,choice,expected_intersection,expected_choices_left",
(
Expand Down

0 comments on commit 2d8f284

Please sign in to comment.