def test_embed_json_recursive(self):
    """json_embed() embeds every string leaf of a JSON document.

    Uses the deterministic "hazo" test embedding model, so the expected
    embedding vectors can be written out literally: each vector encodes
    the lengths of the words in the embedded string, zero-padded to 16
    entries. Non-string leaves (the integer 30) must pass through
    untouched.
    """
    # A single-quoted, single-line literal instead of a hand-escaped
    # triple-quoted one: no \" escapes and no whitespace fragility.
    # The SQL JSON functions parse the document, so its formatting is
    # irrelevant to the expected results.
    example_json = (
        '{"name": "Alice", '
        '"details": {"age": 30, '
        '"hobbies": ["reading", "cycling"], '
        '"location": "Wonderland"}, '
        '"greeting": "Hello, World!"}'
    )

    # Sanity check: plain json_extract works on the document.
    out = self.expect_success(
        *self.path_args,
        f"select json_extract('{example_json}', '$.name')",
    )
    self.assertEqual(
        "('Alice',)\n",
        out,
    )

    # json_embed replaces each string leaf with its hazo embedding.
    out = self.expect_success(
        *self.path_args,
        f"select json_embed('{example_json}', 'hazo')",
    )
    self.assertEqual(
        ('(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
         '0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
         '0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, '
         '0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], '
         '"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
         '0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
         "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"),
        out,
    )
def json_recurse_apply(json_obj, f):
    """Recursively map *f* over every string leaf of a parsed JSON value.

    Dicts and lists are rebuilt with the function applied to their
    contents; each string is replaced by ``f(string)``; every other leaf
    (numbers, booleans, ``None``) is returned unchanged.
    """
    if isinstance(json_obj, dict):
        return {key: json_recurse_apply(value, f) for key, value in json_obj.items()}
    if isinstance(json_obj, list):
        return [json_recurse_apply(item, f) for item in json_obj]
    if isinstance(json_obj, str):
        return f(json_obj)
    # Neither a dict, a list, nor a string: return the leaf as-is.
    return json_obj


def _json_embed_model(js: str, model: str) -> str:
    """Embed every string leaf of the JSON document *js* with *model*.

    Returns the transformed document serialized back to JSON, with each
    string replaced by its embedding vector (a list of floats).
    """
    # Resolve the model once and call .embed directly instead of
    # round-tripping every leaf through _embed_model's
    # json.dumps/json.loads pair (same result, one serialization
    # instead of one per string leaf).
    embedding_model = llm.get_embedding_model(model)
    return json.dumps(json_recurse_apply(json.loads(js), embedding_model.embed))