diff --git a/python/src/typechat/_internal/translator.py b/python/src/typechat/_internal/translator.py index 6c649c1e..86b57126 100644 --- a/python/src/typechat/_internal/translator.py +++ b/python/src/typechat/_internal/translator.py @@ -64,12 +64,10 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti messages: list[PromptSection] = [] - messages.append({"role": "user", "content": input}) if prompt_preamble: if isinstance(prompt_preamble, str): prompt_preamble = [{"role": "user", "content": prompt_preamble}] - else: - messages.extend(prompt_preamble) + messages.extend(prompt_preamble) messages.append({"role": "user", "content": self._create_request_prompt(input)}) @@ -85,16 +83,21 @@ async def translate(self, input: str, *, prompt_preamble: str | list[PromptSecti error_message: str if 0 <= first_curly < last_curly: trimmed_response = text_response[first_curly:last_curly] - parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False) - result = self.validator.validate_object(parsed_response) - if isinstance(result, Success): - return result - error_message = result.message + try: + parsed_response = pydantic_core.from_json(trimmed_response, allow_inf_nan=False, cache_strings=False) + except ValueError as e: + error_message = f"Error: {e}\n\nAttempted to parse:\n\n{trimmed_response}" + else: + result = self.validator.validate_object(parsed_response) + if isinstance(result, Success): + return result + error_message = result.message else: - error_message = "Response did not contain any text resembling JSON." + error_message = f"Response did not contain any text resembling JSON.\nResponse was\n\n{text_response}" if num_repairs_attempted >= self._max_repair_attempts: return Failure(error_message) num_repairs_attempted += 1 + messages.append({"role": "assistant", "content": text_response}) messages.append({"role": "user", "content": self._create_repair_prompt(error_message)}) def _create_request_prompt(self, intent: str) -> str: diff --git a/python/tests/__snapshots__/test_translator.ambr b/python/tests/__snapshots__/test_translator.ambr index 594786fb..84e22594 100644 --- a/python/tests/__snapshots__/test_translator.ambr +++ b/python/tests/__snapshots__/test_translator.ambr @@ -5,9 +5,39 @@ 'kind': 'CLIENT REQUEST', 'payload': list([ dict({ - 'content': 'Get me stuff.', + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', 'role': 'user', }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_invalid_json + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ dict({ 'content': ''' @@ -33,7 +63,57 @@ }), dict({ 'kind': 'MODEL RESPONSE', - 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + 'payload': '{ "a": "hello" "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': '{ "a": "hello" "b": true }', + 'role': 'assistant', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Error: expected `,` or `}` at line 1 column 16 + + Attempted to parse: + + { "a": "hello" "b": true } + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello" "b": true, "c": 1234 }', }), ]) # --- @@ -43,9 +123,94 @@ 'kind': 'CLIENT REQUEST', 'payload': list([ dict({ - 'content': 'Get me stuff.', + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), + dict({ + 'content': ''' + + The above JSON object is invalid for the following reason: + ''' + Validation path `c` failed for value `{"a": "hello", "b": true}` because: + Field required + ''' + The following is a revised JSON object: + + ''', + 'role': 'user', + }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_single_failure_and_list_preamble_1 + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Hey, I need some stuff.', 'role': 'user', }), + dict({ + 'content': 'Okay, what kind of stuff?', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -67,6 +232,48 @@ ''', 'role': 'user', }), + ]), + }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true }', + }), + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Hey, I need some stuff.', + 'role': 'user', + }), + dict({ + 'content': 'Okay, what kind of stuff?', + 'role': 'assistant', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' @@ -82,6 +289,44 @@ }), ]), }), + dict({ + 'kind': 'MODEL RESPONSE', + 'payload': '{ "a": "hello", "b": true, "c": 1234 }', + }), + ]) +# --- +# name: test_translator_with_single_failure_and_str_preamble + list([ + dict({ + 'kind': 'CLIENT REQUEST', + 'payload': list([ + dict({ + 'content': 'Just so you know, I need some stuff.', + 'role': 'user', + }), + dict({ + 'content': ''' + + You are a service that translates user requests into JSON objects of type "ExampleABC" according to the following TypeScript definitions: + ``` + interface ExampleABC { + a: string; + b: boolean; + c: number; + } + + ``` + The following is a user request: + ''' + Get me stuff. + ''' + The following is the user request translated into a JSON object with 2 spaces of indentation and no properties with the value undefined: + + ''', + 'role': 'user', + }), + ]), + }), dict({ 'kind': 'MODEL RESPONSE', 'payload': '{ "a": "hello", "b": true }', @@ -90,7 +335,7 @@ 'kind': 'CLIENT REQUEST', 'payload': list([ dict({ - 'content': 'Get me stuff.', + 'content': 'Just so you know, I need some stuff.', 'role': 'user', }), dict({ @@ -114,6 +359,10 @@ ''', 'role': 'user', }), + dict({ + 'content': '{ "a": "hello", "b": true }', + 'role': 'assistant', + }), dict({ 'content': ''' diff --git a/python/tests/test_translator.py b/python/tests/test_translator.py index 1245ef34..a86502b3 100644 --- a/python/tests/test_translator.py +++ b/python/tests/test_translator.py @@ -20,6 +20,11 @@ def __init__(self, responses: list[str]) -> None: @override async def complete(self, prompt: str | list[typechat.PromptSection]) -> typechat.Result[str]: + # Capture a snapshot because the translator + # can choose to pass in the same underlying list. + if isinstance(prompt, list): + prompt = prompt.copy() + self.conversation.append({ "kind": "CLIENT REQUEST", "payload": prompt }) response = next(self.responses) self.conversation.append({ "kind": "MODEL RESPONSE", "payload": response }) @@ -50,4 +55,41 @@ def test_translator_with_single_failure(snapshot: Any): t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) asyncio.run(t.translate("Get me stuff.")) - assert m.conversation == snapshot \ No newline at end of file + assert m.conversation == snapshot + +def test_translator_with_invalid_json(snapshot: Any): + m = FixedModel([ + '{ "a": "hello" "b": true }', + '{ "a": "hello" "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.")) + + assert m.conversation == snapshot + +def test_translator_with_single_failure_and_str_preamble(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true }', + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate( + "Get me stuff.", + prompt_preamble="Just so you know, I need some stuff.", + )) + + assert m.conversation == snapshot + +def test_translator_with_single_failure_and_list_preamble_1(snapshot: Any): + m = FixedModel([ + '{ "a": "hello", "b": true }', + '{ "a": "hello", "b": true, "c": 1234 }', + ]) + t = typechat.TypeChatJsonTranslator(m, v, ExampleABC) + asyncio.run(t.translate("Get me stuff.", prompt_preamble=[ + {"role": "user", "content": "Hey, I need some stuff."}, + {"role": "assistant", "content": "Okay, what kind of stuff?"}, + ])) + + assert m.conversation == snapshot +