Skip to content

Commit

Permalink
Implement first version of runtime.
Browse files Browse the repository at this point in the history
  • Loading branch information
norpadon committed Sep 7, 2024
1 parent fa2cf3b commit 6506a70
Show file tree
Hide file tree
Showing 16 changed files with 876 additions and 34 deletions.
13 changes: 12 additions & 1 deletion poetry.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ jinja2 = "^3.1.4"
python = ">3.10,<4.0"
tiktoken = "^0.7.0"
tenacity = "^8.5.0"
json5 = "^0.9.25"

[tool.poetry.extras]
openai = ["openai", "openai-function-tokens"]
Expand Down
172 changes: 172 additions & 0 deletions tests/test_functions.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
from dataclasses import dataclass

import pytest
from jinja2 import Template

from wanga.common import JSON
from wanga.function import AIFunction, ai_function
from wanga.models import GenerationParams
from wanga.schema import CallableSchema
from wanga.schema.schema import SchemaValidationError


def test_basic_ai_function():
@ai_function()
def add_numbers(a: int, b: int) -> int:
"""
[|system|]
You are a helpful math assistant.
[|user|]
Please add the following numbers: {{ a }} and {{ b }}
"""
raise

assert isinstance(add_numbers.__ai_function, AIFunction) # type: ignore
assert list(add_numbers.__ai_function.signature.parameters.keys()) == ["a", "b"] # type: ignore
assert add_numbers.__ai_function.return_schema.call_schema.name == "submit_response" # type: ignore
assert isinstance(add_numbers.__ai_function.prompt_template, Template) # type: ignore


def test_ai_function_with_tools():
def multiply(x: int, y: int) -> int:
return x * y

@ai_function(tools=[multiply])
def complex_math(a: int, b: int) -> int:
"""
[|system|]
You are a math assistant capable of addition and multiplication.
[|user|]
I need help with a math problem. First, add {{ a }} and {{ b }}, then multiply the result by 2.
"""
raise

assert isinstance(complex_math.__ai_function, AIFunction) # type: ignore
assert len(complex_math.__ai_function.tools) == 1 # type: ignore
assert isinstance(complex_math.__ai_function.tools[0], CallableSchema) # type: ignore
assert complex_math.__ai_function.tools[0].call_schema.name == "multiply" # type: ignore


def test_ai_function_with_preferred_models():
@ai_function(preferred_models=["gpt-4", "claude-v1"])
def simple_task(task: str) -> str:
"""
[|system|]
You are a helpful assistant.
[|user|]
Please help me with the following task: {{ task }}
"""
raise

assert isinstance(simple_task.__ai_function, AIFunction) # type: ignore
assert simple_task.__ai_function.preferred_models == ["gpt-4", "claude-v1"] # type: ignore


def test_ai_function_with_generation_params():
@ai_function(generation_params=GenerationParams(max_tokens=100, temperature=0.7))
def creative_writing(topic: str) -> str:
"""
[|system|]
You are a creative writing assistant.
[|user|]
Write a short story about {{ topic }} in exactly 3 sentences.
"""
raise

assert isinstance(creative_writing.__ai_function, AIFunction) # type: ignore
assert creative_writing.__ai_function.generation_params.max_tokens == 100 # type: ignore
assert creative_writing.__ai_function.generation_params.temperature == 0.7 # type: ignore


def test_ai_function_missing_return_annotation():
with pytest.raises(ValueError, match="Function must have a concrete return type annotation."):

@ai_function()
def missing_return():
"""
[|system|]
You are a helpful assistant.
"""
raise


def test_ai_function_missing_docstring():
with pytest.raises(ValueError, match="Prompt is missing."):

@ai_function()
def missing_docstring() -> str:
raise


def test_ai_function_invalid_template():
with pytest.raises(ValueError, match="Variable invalid_var is not a parameter of the function."):

@ai_function()
def invalid_template(name: str) -> str:
"""
[|system|]
You are a helpful assistant.
[|user|]
Please greet {{ invalid_var }}
"""
raise


def test_ai_function_return_schema():
@ai_function()
def greet(name: str) -> str:
"""
[|system|]
You are a friendly greeter.
[|user|]
Please greet {{ name }}.
"""
raise NotImplementedError

ai_func = greet.__ai_function # type: ignore
assert isinstance(ai_func, AIFunction)
assert ai_func.return_schema is None # For string return type, return_schema is None


def test_ai_function_eval_return_schema():
@dataclass
class RectangleArea:
area: float
perimeter: float

@ai_function()
def calculate_area(length: float, width: float) -> RectangleArea:
"""
[|system|]
You are a helpful assistant that calculates the area of rectangles.
[|user|]
Please calculate the area of a rectangle with length {{ length }} and width {{ width }}.
"""
raise NotImplementedError

ai_func = calculate_area.__ai_function # type: ignore
# Test with valid input
valid_input: JSON = {"response": {"area": 50.0, "perimeter": 30.0}}
result = ai_func.return_schema.eval(valid_input)
assert isinstance(result, RectangleArea)
assert result.area == 50.0
assert result.perimeter == 30.0

# Test with invalid input (missing field)
with pytest.raises(SchemaValidationError):
ai_func.return_schema.eval({"response": {"area": 50.0}})

# Test with invalid input (wrong type)
with pytest.raises(SchemaValidationError):
ai_func.return_schema.eval({"response": {"area": "50.0", "perimeter": 30.0}})

# Test with invalid input (extra field)
with pytest.raises(SchemaValidationError):
ai_func.return_schema.eval({"response": {"area": 50.0, "perimeter": 30.0, "extra": 10}})
4 changes: 2 additions & 2 deletions tests/test_models/test_openai.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from wanga.models.model import FinishReason, ModelResponse, ToolParams
from wanga.models.model import RateLimitError as WangaRateLimitError
from wanga.models.openai import OpenaAIModel
from wanga.schema import default_schema_extractor
from wanga.schema import DEFAULT_SCHEMA_EXTRACTOR


def test_reply():
Expand Down Expand Up @@ -52,7 +52,7 @@ def test_num_tokens():
def tool(x: int, y: str):
pass

tool_schema = default_schema_extractor.extract_schema(tool)
tool_schema = DEFAULT_SCHEMA_EXTRACTOR.extract_schema(tool)
prompt = dedent(prompt.removeprefix("\n"))
messages = parse_messages(prompt)
tools = ToolParams(tools=[tool_schema])
Expand Down
Loading

0 comments on commit 6506a70

Please sign in to comment.