Skip to content

Commit

Permalink
Add end2end tests.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Sep 7, 2024
1 parent fb0213b commit 6b6fcdf
Show file tree
Hide file tree
Showing 4 changed files with 287 additions and 29 deletions.
245 changes: 245 additions & 0 deletions tests/test_ai_function_integration.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,245 @@
from dataclasses import dataclass
from datetime import date, datetime, time, timedelta
from typing import List, Literal, Optional

import pytest

from wanga.function import ai_function
from wanga.models.openai import OpenaAIModel
from wanga.runtime import Runtime

# Shared model instance for every test in this module (cheap "mini" tier).
# NOTE(review): class name is spelled "OpenaAIModel" in wanga.models.openai —
# presumably an upstream typo; keep the spelling consistent with the import.
model = OpenaAIModel("gpt-4o-mini")


@pytest.fixture(scope="module")
def runtime():
    """Module-scoped runtime bound to the shared model.

    Yields inside the ``with`` block so the Runtime context is torn down
    only after every test in the module has run.
    """
    with Runtime(model) as rt:
        yield rt


def test_basic_ai_function(runtime):
    """Smoke test: a minimal templated @ai_function yields a string reply."""

    @ai_function()
    def greet(name: str) -> str:
        """
        [|system|]
        You are a friendly assistant.
        [|user|]
        Greet {{ name }} in a friendly manner.
        """
        raise NotImplementedError

    greeting = greet("Alice")
    assert isinstance(greeting, str)
    # The reply should mention the person we asked the model to greet.
    assert "Alice" in greeting


def test_ai_function_with_tools(runtime):
    """A single tool (current-time provider) can be exposed to the model."""

    def get_current_time() -> str:
        return datetime.now().strftime("%H:%M:%S")

    @ai_function(tools=[get_current_time])
    def time_greeting(name: str) -> str:
        """
        [|system|]
        You are a helpful assistant who can tell the time.
        [|user|]
        Greet {{ name }} and tell them the current time.
        """
        raise NotImplementedError

    reply = time_greeting("Bob")
    assert isinstance(reply, str)
    assert "Bob" in reply
    # A colon is the cheapest proxy for an HH:MM:SS timestamp in the reply.
    assert ":" in reply


def test_ai_function_with_complex_return_type(runtime):
    """The return annotation drives structured-output parsing.

    Verifies the model reply is deserialized into the local WeatherInfo
    dataclass with correctly typed fields.
    """
    # Fix: dropped the redundant local `from dataclasses import dataclass`;
    # `dataclass` is already imported at module level and the local import
    # merely shadowed it.

    @dataclass
    class WeatherInfo:
        temperature: float  # numeric temperature; unit is not pinned by the prompt
        conditions: str  # free-form description, e.g. "sunny"

    @ai_function()
    def get_weather(city: str) -> WeatherInfo:
        """
        [|system|]
        You are a weather information provider.
        [|user|]
        Provide the current weather for {{ city }}. Use realistic but made-up data.
        """
        raise NotImplementedError

    result = get_weather("New York")
    assert isinstance(result, WeatherInfo)
    assert isinstance(result.temperature, float)
    assert isinstance(result.conditions, str)


def test_ai_function_with_complex_prompt(runtime):
    """A template with several substituted parameters renders and executes."""

    @ai_function()
    def analyze_text(text: str, focus: str, word_limit: int) -> str:
        """
        [|system|]
        You are a text analysis expert.
        [|user|]
        Analyze the following text:
        {{ text }}
        Focus on aspects related to {{ focus }}.
        Provide your analysis in {{ word_limit }} words or less.
        """
        raise NotImplementedError

    analysis = analyze_text(
        text="The quick brown fox jumps over the lazy dog.",
        focus="animal behavior",
        word_limit=50,
    )
    assert isinstance(analysis, str)
    # The word limit is part of the prompt contract, so enforce it here.
    assert len(analysis.split()) <= 50


def test_ai_function_with_multiple_tools(runtime):
    """Two arithmetic tools should each be invoked once to finish the task."""
    call_count = 0

    def add(a: int, b: int) -> int:
        nonlocal call_count
        call_count += 1
        return a + b

    def multiply(a: int, b: int) -> int:
        nonlocal call_count
        call_count += 1
        return a * b

    @ai_function(tools=[add, multiply])
    def complex_calculation(x: int, y: int, z: int) -> int:
        """
        [|system|]
        You are a math assistant with addition and multiplication capabilities.
        [|user|]
        Perform the following calculation:
        1. Add {{ x }} and {{ y }}
        2. Multiply the result by {{ z }}
        """
        raise NotImplementedError

    assert complex_calculation(5, 3, 2) == 16  # (5 + 3) * 2
    # One call to `add`, one to `multiply`.
    assert call_count == 2


@dataclass
class Address:
    # Plain mailing-address record; used as a nested field of Person below.
    street: str
    city: str
    country: str
    postal_code: str


@dataclass
class Person:
    # Structured-output target exercising a nested dataclass and a date field.
    name: str
    age: int
    address: Address  # nested dataclass
    birth_date: date  # exercises date (de)serialization


def test_ai_function_with_nested_dataclasses(runtime):
    """Structured output is parsed into nested dataclasses, dates included."""

    @ai_function()
    def generate_person_info() -> Person:
        """
        [|system|]
        You are an AI assistant capable of generating realistic person information.
        [|user|]
        Generate information for a fictional person, including their name, age, address, and birth date.
        """
        raise NotImplementedError

    person = generate_person_info()
    assert isinstance(person, Person)
    assert isinstance(person.address, Address)
    assert isinstance(person.birth_date, date)
    # Loose sanity bound on the generated age.
    assert 0 <= person.age <= 120


# Closed set of employment states; exercises Literal support in schemas.
JobStatus = Literal["Employed", "Unemployed", "Student"]


@dataclass
class WorkExperience:
    # Single job entry; an ongoing position is modelled by end_date=None.
    company: str
    position: str
    start_date: date
    end_date: Optional[date] = None  # None means "current position"


@dataclass
class ComplexPerson:
    # Combines Literal, a list of nested dataclasses, and datetime in one target.
    name: str
    age: int
    status: JobStatus  # Literal-typed field
    experiences: List[WorkExperience]  # nested list of dataclasses
    last_update: datetime  # exercises datetime (de)serialization


def test_ai_function_with_complex_structures(runtime):
    """Deeply structured output: Literal, nested dataclass list, datetime."""

    @ai_function()
    def generate_complex_person() -> ComplexPerson:
        """
        [|system|]
        You are an AI assistant capable of generating detailed person profiles.
        [|user|]
        Generate a complex person profile with name, age, job status, work experiences, and last update time.
        Include at least two work experiences.
        """
        raise NotImplementedError

    profile = generate_complex_person()
    assert isinstance(profile, ComplexPerson)
    assert len(profile.experiences) >= 2
    for experience in profile.experiences:
        assert isinstance(experience, WorkExperience)
        assert isinstance(experience.start_date, date)
        # An ongoing job is represented by a missing end date.
        assert experience.end_date is None or isinstance(experience.end_date, date)
    assert isinstance(profile.last_update, datetime)


@dataclass
class TimeRange:
    # One schedule slot; duration is expected to match end - start.
    start: time
    end: time
    duration: timedelta


def test_ai_function_with_time_structures(runtime):
    """time/timedelta fields round-trip through structured output."""

    @ai_function()
    def generate_work_schedule() -> List[TimeRange]:
        """
        [|system|]
        You are an AI assistant capable of generating work schedules.
        [|user|]
        Generate a work schedule for a typical day, with at least 3 time ranges.
        Each time range should have a start time, end time, and duration.
        """
        raise NotImplementedError

    schedule = generate_work_schedule()
    assert isinstance(schedule, list)
    assert len(schedule) >= 3
    for entry in schedule:
        assert isinstance(entry, TimeRange)
        assert isinstance(entry.start, time)
        assert isinstance(entry.end, time)
        assert isinstance(entry.duration, timedelta)
        # Duration must agree with the hour/minute gap between start and end
        # (timedelta normalizes a negative minutes component automatically).
        expected = timedelta(
            hours=entry.end.hour - entry.start.hour,
            minutes=entry.end.minute - entry.start.minute,
        )
        assert entry.duration == expected
18 changes: 10 additions & 8 deletions wanga/models/openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -226,13 +226,15 @@ def _format_tools(tools: list[CallableSchema]) -> list:


def _format_tool_choice(tool_use_mode: ToolUseMode | str):
if tool_use_mode == ToolUseMode.AUTO:
return "auto"
elif tool_use_mode == ToolUseMode.FORCE:
return "required"
elif tool_use_mode == ToolUseMode.NEVER:
return "none"
return {"type": "function", "name": tool_use_mode}
match tool_use_mode:
case ToolUseMode.AUTO:
return "auto"
case ToolUseMode.FORCE:
return "required"
case ToolUseMode.NEVER:
return "none"
case _:
return {"type": "function", "name": tool_use_mode}


def _format_image_content(image: ImageContent) -> dict:
Expand Down Expand Up @@ -289,7 +291,7 @@ def _format_message(message: Message) -> dict:
case ToolMessage(content=content, invocation_id=invocation_id):
if not isinstance(content, str):
raise ValueError("OpenAI doesn't support images in tool messages.")
return {"role": "tool", "content": content, "invocation_id": invocation_id}
return {"role": "tool", "content": content, "tool_call_id": invocation_id}
case _:
raise ValueError(f"Unknown message type: {message}")
return result
Expand Down
43 changes: 22 additions & 21 deletions wanga/runtime.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,28 +73,29 @@ def call_and_use_tools(
while True:
response = model.reply(messages, tool_params, generation_params).response_options[0]
message = response.message
match response.finish_reason:
case FinishReason.STOP:
if not allow_plain_text_response:
raise RuntimeError("Model generated plain text response.")
messages.append(message)
if response.finish_reason == FinishReason.STOP:
if allow_plain_text_response:
return message.content
case FinishReason.TOOL_CALL:
if message.tool_invocations:
tool_invocation_results = invoke_tools(message, tool_params)
if isinstance(tool_invocation_results, FinalResponse):
return tool_invocation_results.value
tool_messages, errors = tool_invocation_results
messages.extend(tool_messages)
if errors:
if retries_left == 0:
raise errors[0]
retries_left -= 1
else:
retries_left = max_retries_on_invalid_output
case FinishReason.CONTENT_FILTER:
raise ContentFilterError(f"Content filter triggered: {message.content}")
case FinishReason.LENGTH:
raise OutputTooLongError(message.content)
if response.finish_reason in [FinishReason.STOP, FinishReason.TOOL_CALL]:
if message.tool_invocations:
tool_invocation_results = invoke_tools(message, tool_params)
if isinstance(tool_invocation_results, FinalResponse):
return tool_invocation_results.value
tool_messages, errors = tool_invocation_results
messages.extend(tool_messages)
if errors:
if retries_left == 0:
raise errors[0]
retries_left -= 1
else:
retries_left = max_retries_on_invalid_output
else:
raise RuntimeError("Model returned a stop response without any tool invocations.")
elif response.finish_reason == FinishReason.CONTENT_FILTER:
raise ContentFilterError(f"Content filter triggered: {message.content}")
elif response.finish_reason == FinishReason.LENGTH:
raise OutputTooLongError(message.content)


class Runtime:
Expand Down
10 changes: 10 additions & 0 deletions wanga/schema/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,10 @@ class UnionNode(SchemaNode):

options: list[SchemaNode | None]

@property
def is_optional(self) -> bool:
return None in self.options

@property
def is_primitive(self) -> bool:
if len(self.options) == 1:
Expand All @@ -208,6 +212,12 @@ def is_primitive(self) -> bool:
def json_schema(self, parent_hint: str | None = None) -> JsonSchema:
if not self.is_primitive:
raise UnsupportedSchemaError("JSON schema cannot be generated for non-trivial Union types.")
if self.is_optional:
options: list[SchemaNode | None] = [option for option in self.options if option is not None]
if len(options) == 1:
assert options[0] is not None
return options[0].json_schema(parent_hint)
return UnionNode(options).json_schema(parent_hint)
type_names = [
_type_to_jsonname[option.primitive_type] # type: ignore
for option in self.options
Expand Down

0 comments on commit 6b6fcdf

Please sign in to comment.