From af2cf1b17fce9452bd275b737b86a0c7a622be59 Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Sat, 20 Apr 2024 01:07:49 +0000 Subject: [PATCH 1/3] Add a test for the built-in translator and add the snapshots. --- .../tests/__snapshots__/test_translator.ambr | 84 +++++++++++++++++++ python/tests/test_translator.py | 53 ++++++++++++ 2 files changed, 137 insertions(+) create mode 100644 python/tests/__snapshots__/test_translator.ambr create mode 100644 python/tests/test_translator.py diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr new file mode 100644 index 00000000..7ca089bd --- /dev/null +++ b/python/tests/__snapshots__/test_translator.ambr @@ -0,0 +1,84 @@ +# serializer version: 1 +# name: test_translator_with_immediate_pass + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_single_failure + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py new file mode 100644 index 00000000..1245ef34 --- /dev/null +++ b/python/tests/test_translator.py @@ -0,0 +1,53 @@ + +import asyncio +from dataclasses import dataclass +from typing_extensions import Any, Iterator, Literal, TypedDict, override +import typechat + +class ConvoRecord(TypedDict): + kind: Literal["CLIENT REQUEST", "MODEL RESPONSE"] + payload: str | list[typechat.PromptSection] + +class FixedModel(typechat.TypeChatLanguageModel): + responses: Iterator[str] + conversation: list[ConvoRecord] + + "A model which responds with one of a series of responses." + def __init__(self, responses: list[str]) -> None: + super().__init__() + self.responses = iter(responses) + self.conversation = [] + + @override + async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]: + self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt }) + response = next(self.responses) + self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response }) + return typechat.Success(response) + +@dataclass +class ExampleABC: + a: str + b: bool + c: int + +v = typechat.TypeChatValidator(ExampleABC) + +def test_translator_with_immediate_pass(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.")) + + assert m.conversation == snapshot + +def test_translator_with_single_failure(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true }', + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.")) + + assert m.conversation == snapshot \ No newline at end of file From 48a4ab38d8a28c7a1d44eb33fa394fcba286b1bd Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Sat, 20 Apr 2024 01:20:50 +0000 Subject: [PATCH 2/3] Ensure the repair prompt is actually appended to our messages and update snapshots. --- python/examples/healthData/translator.py | 4 +- .../tests/__snapshots__/test_translator.ambr | 155 ++++++++++++------ 2 files changed, 106 insertions(+), 53 deletions(-) diff --git a/python/examples/healthData/translator.py b/python/examples/healthData/translator.py index a1565334..d9c1c46b 100644 --- a/python/examples/healthData/translator.py +++ b/python/examples/healthData/translator.py @@ -27,8 +27,8 @@ def __init__( self._additional_agent_instructions = additional_agent_instructions @override - async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]: - result = await super().translate(request=request, prompt_preamble=prompt_preamble) + async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]: + result = await super().translate(input=input, prompt_preamble=prompt_preamble) if not isinstance(result, Failure): self._chat_history.append(ChatMessage(source="assistant", body=result.value)) return result diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 7ca089bd..594786fb 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -3,24 +3,33 @@ list([ dict({ 'kind': 'CLIENT REQUEST', - 'payload': ''' - - You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: - ``` - interface ExampleABC { - a: string; - b: boolean; - c: number; - } - - ``` - The following is a user request: - ''' - Get me stuff. - ''' - The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: - ''', + ''', + 'role': 'user', + }), + ]), }), dict({ 'kind': 'MODEL RESPONSE', @@ -32,24 +41,46 @@ list([ dict({ 'kind': 'CLIENT REQUEST', - 'payload': ''' - - You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: - ``` - interface ExampleABC { - a: string; - b: boolean; - c: number; - } - - ``` - The following is a user request: - ''' - Get me stuff. - ''' - The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: - ''', + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), }), dict({ 'kind': 'MODEL RESPONSE', @@ -57,24 +88,46 @@ }), dict({ 'kind': 'CLIENT REQUEST', - 'payload': ''' - - You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: - ``` - interface ExampleABC { - a: string; - b: boolean; - c: number; - } - - ``` - The following is a user request: - ''' - Get me stuff. - ''' - The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + 'payload': list([ + dict({ + 'content': 'Get me stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: - ''', + ''', + 'role': 'user', + }), + ]), }), dict({ 'kind': 'MODEL RESPONSE', From 83f1435b76cfc58123fecc15e1fa74a8f0cc284b Mon Sep 17 00:00:00 2001 From: Daniel Rosenwasser Date: Sat, 20 Apr 2024 01:27:34 +0000 Subject: [PATCH 3/3] Add uncommitted file. --- python/src/typechat/_internal/translator.py | 22 +++++++++++---------- 1 file changed, 12 insertions(+), 10 deletions(-) diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index 141a4772..6c649c1e 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -49,7 +49,7 @@ def __init__( self._type_name = conversion_result.typescript_type_reference self._schema_str = conversion_result.typescript_schema_str - async def translate(self, request: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]: + async def translate(self, input: str, *, prompt_preamble: str | list[PromptSection] | None = None) -> Result[T]: """ Translates a natural language request into an object of type `T`. If the JSON object returned by the language model fails to validate, repair attempts will be made up until `_max_repair_attempts`. @@ -57,23 +57,25 @@ async def translate(self, request: str, *, prompt_preamble: str | list[PromptSec This often helps produce a valid instance. Args: - request: A natural language request. + input: A natural language request. prompt_preamble: An optional string or list of prompt sections to prepend to the generated prompt.\ If a string is given, it is converted to a single "user" role prompt section. """ - request = self._create_request_prompt(request) - prompt: str | list[PromptSection] - if prompt_preamble is None: - prompt = request - else: + messages: list[PromptSection] = [] + + messages.append({"role": "user", "content": input}) + if prompt_preamble: if isinstance(prompt_preamble, str): prompt_preamble = [{"role": "user", "content": prompt_preamble}] - prompt = [*prompt_preamble, {"role": "user", "content": request}] + else: + messages.extend(prompt_preamble) + + messages.append({"role": "user", "content": self._create_request_prompt(input)}) num_repairs_attempted = 0 while True: - completion_response = await self.model.complete(prompt) + completion_response = await self.model.complete(messages) if isinstance(completion_response, Failure): return completion_response @@ -93,7 +95,7 @@ async def translate(self, request: str, *, prompt_preamble: str | list[PromptSec if num_repairs_attempted >= self._max_repair_attempts: return Failure(error_message) num_repairs_attempted += 1 - request = f"{text_response}\n{self._create_repair_prompt(error_message)}" + messages.append({"role": "user", "content": self._create_repair_prompt(error_message)}) def _create_request_prompt(self, intent: str) -> str: prompt = f"""