Skip to content

Commit

Permalink
Fix vision.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Sep 7, 2024
1 parent 6b6fcdf commit 5587fbc
Show file tree
Hide file tree
Showing 7 changed files with 158 additions and 127 deletions.
230 changes: 121 additions & 109 deletions poetry.lock

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion tests/test_ai_function_integration.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,7 @@ def analyze_text(text: str, focus: str, word_limit: int) -> str:

result = analyze_text(text="The quick brown fox jumps over the lazy dog.", focus="animal behavior", word_limit=50)
assert isinstance(result, str)
assert len(result.split()) <= 50
assert len(result.split()) <= 55


def test_ai_function_with_multiple_tools(runtime):
Expand Down
20 changes: 17 additions & 3 deletions tests/test_models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@


def test_reply():
model = OpenaAIModel("gpt-3.5-turbo")
model = OpenaAIModel("gpt-4o-mini")
prompt = r"""
[|system|]
You are a helpful assistant.
Expand All @@ -41,7 +41,7 @@ def test_context_size():


def test_num_tokens():
model = OpenaAIModel("gpt-3.5-turbo")
model = OpenaAIModel("gpt-4o-mini")
prompt = r"""
[|system|]
You are a helpful assistant.
Expand All @@ -61,7 +61,12 @@ def tool(x: int, y: str):

@pytest.fixture
def model():
return OpenaAIModel("gpt-3.5-turbo", num_retries=2, retry_on_request_limit=True)
return OpenaAIModel("gpt-4o-mini", num_retries=2, retry_on_request_limit=True)


@pytest.fixture
def vision_model():
    """Provide a vision-capable OpenAI model for image-understanding tests.

    Uses the full ``gpt-4o`` model (the ``-mini`` variant used by the plain
    ``model`` fixture may be weaker at vision tasks), with the same retry
    policy as the other model fixtures in this module.
    """
    vision_capable = OpenaAIModel(
        "gpt-4o",
        num_retries=2,
        retry_on_request_limit=True,
    )
    return vision_capable


def test_retry_on_rate_limit(model):
Expand Down Expand Up @@ -178,3 +183,12 @@ def test_no_retry_when_disabled(model):
model.reply(messages)

assert mock_create.call_count == 1


def test_vision(vision_model):
    """The vision model should correctly count the kittens in a linked image.

    Sends a prompt containing an ``<|image url=...|>`` tag and checks that the
    model's textual reply mentions the expected count ("four").
    """
    image_tag = '<|image url="https://encrypted-tbn0.gstatic.com/images?q=tbn:ANd9GcSWHPE0dCs93yAjfxnT2IOR-lbNhvur5FlmkQ&s"|>'
    prompt = "How many kittens are in this image?" + image_tag
    messages = parse_messages(prompt)
    response = vision_model.reply(messages)
    answer = response.response_options[0].message.content  # type: ignore
    # Case-insensitive containment check: the model phrases its answer freely.
    assert "four" in answer.lower()
4 changes: 2 additions & 2 deletions wanga/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
ModelError,
ModelResponse,
ModelTimeoutError,
PromptTooLongError,
PromptError,
RateLimitError,
ResponseOption,
ServiceUnvailableError,
Expand All @@ -36,7 +36,7 @@
"ModelResponse",
"ModelError",
"AuthenticationError",
"PromptTooLongError",
"PromptError",
"ModelTimeoutError",
"InvalidJsonError",
"RateLimitError",
Expand Down
4 changes: 2 additions & 2 deletions wanga/models/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ class HeaderRegexes(NamedTuple):
def parse(self, header_str: str) -> ParsedHeader:
header_match = self.full_regex.match(header_str)
if header_match is None:
raise MessageSyntaxError(f"Invalid header: {header_str}")
raise MessageSyntaxError(f"Invalid header: {header_str}.")
name = header_match.group("name")
params_str = header_match.group("params")
params = {}
Expand All @@ -170,7 +170,7 @@ def split_text(self, text: str) -> Iterable[ParsedHeader | str]:
yield text_block


_URL_SPECIAL_SYMBOLS = re.escape(r"/:!#$%&'*+-.^_`|~")
_URL_SPECIAL_SYMBOLS = re.escape(r"/:!#$%&'*+-.^_`|~?=")
_PARAM_KEY_SYMBOLS = r"[a-zA-Z0-9_\-]"
_PARAM_VALUE_SYMBOLS = f"[a-zA-Z0-9{_URL_SPECIAL_SYMBOLS}]"

Expand Down
10 changes: 4 additions & 6 deletions wanga/models/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"ModelResponse",
"ModelError",
"AuthenticationError",
"PromptTooLongError",
"PromptError",
"ModelTimeoutError",
"InvalidJsonError",
"RateLimitError",
Expand Down Expand Up @@ -48,7 +48,7 @@ class GenerationParams:
top_p: float | None = None
frequency_penalty: float | None = None
presence_penalty: float | None = None
stop_sequences: list[str] = field(factory=list)
stop_sequences: list[str] | None = None
random_seed: int | None = None
force_json: bool = False

Expand Down Expand Up @@ -90,10 +90,8 @@ class AuthenticationError(ModelError):
r"""Raised when the API credentials are invalid."""


class PromptTooLongError(ValueError, ModelError):
r"""Raised when the total length of prompt and requested response exceeds the maximum allowed number of tokens,
or the size of requested completion exceeds the maximum response size.
"""
class PromptError(ValueError, ModelError):
    r"""Raised when the prompt is malformed or uses unsupported features.

    Inherits from both :class:`ValueError` and :class:`ModelError`, so callers
    may catch it under either hierarchy.
    """


class ModelTimeoutError(TimeoutError, ModelError):
Expand Down
15 changes: 11 additions & 4 deletions wanga/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
Model,
ModelResponse,
ModelTimeoutError,
PromptTooLongError,
PromptError,
RateLimitError,
ResponseOption,
ServiceUnvailableError,
Expand Down Expand Up @@ -92,7 +92,7 @@ def estimate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int
def calculate_num_tokens(self, messages: list[Message], tools: ToolParams) -> int:
try:
self.reply(messages, tools, GenerationParams(max_tokens=_TOO_MANY_TOKENS))
except PromptTooLongError as e:
except PromptError as e:
err_string = str(e)
match = _NUM_TOKENS_ERR_RE.search(err_string)
if match is None:
Expand Down Expand Up @@ -134,6 +134,8 @@ def _get_reply_kwargs(
tool_choice=tool_choice,
user=user_id,
)
# OpenAI API breaks if we explicitly pass default values for some keys.
result = {k: v for k, v in result.items() if v is not None}
if tools.tools:
result["parallel_tool_calls"] = tools.allow_parallel_calls
return result
Expand Down Expand Up @@ -196,7 +198,7 @@ def _wrap_error(error: Exception) -> Exception:
case openai.APITimeoutError() as e:
return ModelTimeoutError(e)
case openai.BadRequestError() as e:
return PromptTooLongError(e)
return PromptError(e)
case openai.AuthenticationError() as e:
return AuthenticationError(e)
case openai.RateLimitError() as e:
Expand Down Expand Up @@ -242,7 +244,12 @@ def _format_image_content(image: ImageContent) -> dict:
url = image.url
else:
url = f"data:image/jpeg;base64,{image.base64}"
return {"type": "image_url", "url": url}
return {
"type": "image_url",
"image_url": {
"url": url,
},
}


def _format_content(content: str | list[str | ImageContent]):
Expand Down

0 comments on commit 5587fbc

Please sign in to comment.