Skip to content

Commit

Permalink
Add docstring comments for TypeChat for Python (#191)
Browse files Browse the repository at this point in the history
* Add docstring comments, rework some logic around `create_language_model`.

* Rename `TypeChatModel` to `TypeChatLanguageModel`.
  • Loading branch information
DanielRosenwasser authored Feb 26, 2024
1 parent aeb4d89 commit 497f3c9
Show file tree
Hide file tree
Showing 12 changed files with 113 additions and 31 deletions.
4 changes: 2 additions & 2 deletions python/examples/healthData/translator.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
from typing_extensions import TypeVar, Any, override, TypedDict, Literal

from typechat import TypeChatValidator, TypeChatModel, TypeChatTranslator, Result, Failure
from typechat import TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator, Result, Failure

from datetime import datetime

Expand All @@ -19,7 +19,7 @@ class TranslatorWithHistory(TypeChatTranslator[T]):
_additional_agent_instructions: str

def __init__(
self, model: TypeChatModel, validator: TypeChatValidator[T], target_type: type[T], additional_agent_instructions: str
self, model: TypeChatLanguageModel, validator: TypeChatValidator[T], target_type: type[T], additional_agent_instructions: str
):
super().__init__(model=model, validator=validator, target_type=target_type)
self._chat_history = []
Expand Down
4 changes: 2 additions & 2 deletions python/examples/math/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@
Failure,
Result,
Success,
TypeChatModel,
TypeChatLanguageModel,
TypeChatValidator,
TypeChatTranslator,
python_type_to_typescript_schema,
Expand Down Expand Up @@ -149,7 +149,7 @@ def validate(self, json_text: str) -> Result[JsonProgram]:
class TypeChatProgramTranslator(TypeChatTranslator[JsonProgram]):
_api_declaration_str: str

def __init__(self, model: TypeChatModel, validator: TypeChatProgramValidator, api_type: type):
def __init__(self, model: TypeChatLanguageModel, validator: TypeChatProgramValidator, api_type: type):
super().__init__(model=model, validator=validator, target_type=api_type)
# TODO: the conversion result here has errors!
conversion_result = python_type_to_typescript_schema(api_type)
Expand Down
8 changes: 4 additions & 4 deletions python/examples/multiSchema/agents.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
import json

from typing_extensions import TypeVar, Generic
from typechat import Failure, TypeChatTranslator, TypeChatValidator, TypeChatModel
from typechat import Failure, TypeChatTranslator, TypeChatValidator, TypeChatLanguageModel

import examples.math.schema as math_schema
from examples.math.program import (
Expand All @@ -29,7 +29,7 @@ class JsonPrintAgent(Generic[T]):
_validator: TypeChatValidator[T]
_translator: TypeChatTranslator[T]

def __init__(self, model: TypeChatModel, target_type: type[T]):
def __init__(self, model: TypeChatLanguageModel, target_type: type[T]):
super().__init__()
self._validator = TypeChatValidator(target_type)
self._translator = TypeChatTranslator(model, self._validator, target_type)
Expand All @@ -47,7 +47,7 @@ class MathAgent:
_validator: TypeChatProgramValidator
_translator: TypeChatProgramTranslator

def __init__(self, model: TypeChatModel):
def __init__(self, model: TypeChatLanguageModel):
super().__init__()
self._validator = TypeChatProgramValidator()
self._translator = TypeChatProgramTranslator(model, self._validator, math_schema.MathAPI)
Expand Down Expand Up @@ -95,7 +95,7 @@ class MusicAgent:
_client_context: ClientContext | None
_authentication_vals: dict[str, str | None]

def __init__(self, model: TypeChatModel, authentication_vals: dict[str, str | None]):
def __init__(self, model: TypeChatLanguageModel, authentication_vals: dict[str, str | None]):
super().__init__()
self._validator = TypeChatValidator(music_schema.PlayerActions)
self._translator = TypeChatTranslator(model, self._validator, music_schema.PlayerActions)
Expand Down
4 changes: 2 additions & 2 deletions python/examples/multiSchema/router.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import json
from typing_extensions import Any, Callable, Awaitable, TypedDict, Annotated
from typechat import Failure, TypeChatValidator, TypeChatModel, TypeChatTranslator
from typechat import Failure, TypeChatValidator, TypeChatLanguageModel, TypeChatTranslator


class AgentInfo(TypedDict):
Expand All @@ -18,7 +18,7 @@ class TextRequestRouter:
_validator: TypeChatValidator[TaskClassification]
_translator: TypeChatTranslator[TaskClassification]

def __init__(self, model: TypeChatModel):
def __init__(self, model: TypeChatLanguageModel):
super().__init__()
self._validator = TypeChatValidator(TaskClassification)
self._translator = TypeChatTranslator(model, self._validator, TaskClassification)
Expand Down
4 changes: 2 additions & 2 deletions python/src/typechat/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,15 @@
#
# SPDX-License-Identifier: MIT

from typechat._internal.model import TypeChatModel, create_language_model
from typechat._internal.model import TypeChatLanguageModel, create_language_model
from typechat._internal.result import Failure, Result, Success
from typechat._internal.translator import TypeChatTranslator
from typechat._internal.ts_conversion import python_type_to_typescript_schema
from typechat._internal.validator import TypeChatValidator
from typechat._internal.interactive import process_requests

__all__ = [
"TypeChatModel",
"TypeChatLanguageModel",
"TypeChatTranslator",
"TypeChatValidator",
"Success",
Expand Down
10 changes: 10 additions & 0 deletions python/src/typechat/_internal/interactive.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,16 @@
from typing import Callable, Awaitable

async def process_requests(interactive_prompt: str, input_file_name: str | None, process_request: Callable[[str], Awaitable[None]]):
"""
A request processor for interactive input or input from a text file. If an input file name is specified,
the callback function is invoked for each line in file. Otherwise, the callback function is invoked for
each line of interactive input until the user types "quit" or "exit".
Args:
interactive_prompt: Prompt to present to user.
input_file_name: Input text file name, if any.
process_request: Async callback function that is invoked for each interactive input or each line in text file.
"""
if input_file_name is not None:
with open(input_file_name, "r") as file:
lines = filter(str.rstrip, file)
Expand Down
53 changes: 41 additions & 12 deletions python/src/typechat/_internal/model.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,23 @@
from typing_extensions import Protocol, override
import os
import openai

from typechat._internal.result import Failure, Result, Success


class TypeChatModel(Protocol):
class TypeChatLanguageModel(Protocol):
async def complete(self, input: str) -> Result[str]:
"""
Represents a AI language model that can complete prompts.
TypeChat uses an implementation of this protocol to communicate
with an AI service that can translate natural language requests to JSON
instances according to a provided schema.
The `create_language_model` function can create an instance.
"""
...


class DefaultOpenAIModel(TypeChatModel):
class DefaultOpenAIModel(TypeChatLanguageModel):
model_name: str
client: openai.AsyncOpenAI | openai.AsyncAzureOpenAI

Expand All @@ -34,24 +41,46 @@ async def complete(self, input: str) -> Result[str]:
except Exception as e:
return Failure(str(e))

def create_language_model(vals: dict[str,str|None]) -> TypeChatModel:
model: TypeChatModel
def create_language_model(vals: dict[str, str | None]) -> TypeChatLanguageModel:
"""
Creates a language model encapsulation of an OpenAI or Azure OpenAI REST API endpoint
chosen by a dictionary of variables (typically just `os.environ`).
If an `OPENAI_API_KEY` environment variable exists, an OpenAI model is constructed.
The `OPENAI_ENDPOINT` and `OPENAI_MODEL` environment variables must also be defined or an error will be raised.
If an `AZURE_OPENAI_API_KEY` environment variable exists, an Azure OpenAI model is constructed.
The `AZURE_OPENAI_ENDPOINT` environment variable must also be defined or an exception will be thrown.
If none of these key variables are defined, an exception is thrown.
@returns An instance of `TypeChatLanguageModel`.
Args:
vals: A dictionary of variables. Typically just `os.environ`.
"""
model: TypeChatLanguageModel
client: openai.AsyncOpenAI | openai.AsyncAzureOpenAI

def required_var(name: str) -> str:
val = vals.get(name, None)
if val is None:
raise ValueError(f"Missing environment variable {name}.")
return val

if "OPENAI_API_KEY" in vals:
client = openai.AsyncOpenAI(api_key=vals["OPENAI_API_KEY"])
model = DefaultOpenAIModel(model_name=vals.get("OPENAI_MODEL", None) or "gpt-35-turbo", client=client)
client = openai.AsyncOpenAI(api_key=required_var("OPENAI_API_KEY"))
model = DefaultOpenAIModel(model_name=required_var("OPENAI_MODEL"), client=client)

elif "AZURE_OPENAI_API_KEY" in vals and "AZURE_OPENAI_ENDPOINT" in vals:
os.environ["OPENAI_API_TYPE"] = "azure"
elif "AZURE_OPENAI_API_KEY" in vals:
openai.api_type = "azure"
client = openai.AsyncAzureOpenAI(
azure_endpoint=vals.get("AZURE_OPENAI_ENDPOINT", None) or "",
api_key=vals["AZURE_OPENAI_API_KEY"],
api_key=required_var("AZURE_OPENAI_API_KEY"),
azure_endpoint=required_var("AZURE_OPENAI_ENDPOINT"),
api_version="2023-03-15-preview",
)
model = DefaultOpenAIModel(model_name=vals.get("AZURE_OPENAI_MODEL", None) or "gpt-35-turbo", client=client)

else:
raise ValueError("Missing environment variables for Open AI or Azure OpenAI model")
raise ValueError("Missing environment variables for OPENAI_API_KEY or AZURE_OPENAI_API_KEY.")

return model
9 changes: 7 additions & 2 deletions python/src/typechat/_internal/result.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,21 @@
from dataclasses import dataclass
from typing_extensions import Generic, TypeVar
from typing_extensions import Generic, TypeAlias, TypeVar

T = TypeVar("T", covariant=True)

@dataclass
class Success(Generic[T]):
"An object representing a successful operation with a result of type `T`."
value: T


@dataclass
class Failure:
"An object representing an operation that failed for the reason given in `message`."
message: str


Result = Success[T] | Failure
"""
An object representing a successful or failed operation of type `T`.
"""
Result: TypeAlias = Success[T] | Failure
31 changes: 27 additions & 4 deletions python/src/typechat/_internal/translator.py
Original file line number Diff line number Diff line change
@@ -1,30 +1,53 @@
from typing_extensions import Generic, TypeVar

from typechat._internal.model import TypeChatModel
from typechat._internal.model import TypeChatLanguageModel
from typechat._internal.result import Failure, Result, Success
from typechat._internal.ts_conversion import python_type_to_typescript_schema
from typechat._internal.validator import TypeChatValidator

T = TypeVar("T", covariant=True)

class TypeChatTranslator(Generic[T]):
model: TypeChatModel
"""
Represents an object that can translate natural language requests in JSON objects of the given type.
"""

model: TypeChatLanguageModel
validator: TypeChatValidator[T]
target_type: type[T]
_type_name: str
_schema_str: str
_max_repair_attempts = 1

def __init__(self, model: TypeChatModel, validator: TypeChatValidator[T], target_type: type[T]):
def __init__(self, model: TypeChatLanguageModel, validator: TypeChatValidator[T], target_type: type[T]):
"""
Args:
model: The associated `TypeChatLanguageModel`.
validator: The associated `TypeChatValidator[T]`.
target_type: A runtime type object describing `T` - the expected shape of JSON data.
"""
super().__init__()
self.model = model
self.target_type = target_type
self.validator = validator
self.target_type = target_type

conversion_result = python_type_to_typescript_schema(target_type)
# TODO: Examples may not work here!
# if conversion_result.errors:
# raise ValueError(f"Could not convert Python type to TypeScript schema: {conversion_result.errors}")
self._type_name = conversion_result.typescript_type_reference
self._schema_str = conversion_result.typescript_schema_str

async def translate(self, request: str) -> Result[T]:
"""
Translates a natural language request into an object of type `T`. If the JSON object returned by
the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`.
The prompt for the subsequent attempts will include the diagnostics produced for the prior attempt.
This often helps produce a valid instance.
Args:
request: A natural language request.
"""
request = self._create_request_prompt(request)
num_repairs_attempted = 0
while True:
Expand Down
3 changes: 2 additions & 1 deletion python/src/typechat/_internal/ts_conversion/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
from dataclasses import dataclass
from typing_extensions import TypeAliasType

from typechat._internal.ts_conversion.python_type_to_ts_nodes import python_type_to_typescript_nodes
from typechat._internal.ts_conversion.ts_node_to_string import ts_declaration_to_str
Expand All @@ -19,7 +20,7 @@ class TypeScriptSchemaConversionResult:
errors: list[str]
"""Any errors that occurred during conversion."""

def python_type_to_typescript_schema(py_type: object) -> TypeScriptSchemaConversionResult:
def python_type_to_typescript_schema(py_type: type | TypeAliasType) -> TypeScriptSchemaConversionResult:
"""Converts a Python type to a TypeScript schema."""

node_conversion_result = python_type_to_typescript_nodes(py_type)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ class GenericDeclarationish(Protocol):
class GenericAliasish(Protocol):
__origin__: object
__args__: tuple[object, ...]
__name__: str


class Annotatedish(Protocol):
Expand Down
13 changes: 13 additions & 0 deletions python/src/typechat/_internal/validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,15 +8,28 @@
T = TypeVar("T", covariant=True)

class TypeChatValidator(Generic[T]):
"""
Validates JSON text against a given Python type.
"""
_type: type[T]
_adapted_type: pydantic.TypeAdapter[T]

def __init__(self, py_type: type[T]):
"""
Args:
py_type: The schema type to validate against.
"""
super().__init__()
self._type = py_type
self._adapted_type = pydantic.TypeAdapter(py_type)

def validate(self, json_text: str) -> Result[T]:
"""
Validates the given JSON object according to the associated TypeScript schema. Returns a
`Success[T]` object containing the JSON object if validation was successful. Otherwise, returns
a `Failure` object with a `message` property describing the error.
"""
try:
typed_dict = self._adapted_type.validate_json(json_text, strict=True)
return Success(typed_dict)
Expand Down

0 comments on commit 497f3c9

Please sign in to comment.