Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

recursive json_embed(json, model) for SQLite and DuckDB #34

Merged
merged 6 commits into from
Aug 15, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
44 changes: 41 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,6 @@
[![codecov](https://codecov.io/gh/Florents-Tselai/tsellm/branch/main/graph/badge.svg)](https://codecov.io/gh/Florents-Tselai/tsellm)
[![License](https://img.shields.io/badge/BSD%20license-blue.svg)](https://github.com/Florents-Tselai/tsellm/blob/main/LICENSE)




**tsellm** is the easiest way to access LLMs from SQLite or DuckDB.

```shell
Expand Down Expand Up @@ -44,12 +41,53 @@ so you can use any of its plugins:
```shell
llm install llm-sentence-transformers
llm sentence-transformers register all-MiniLM-L12-v2
llm install llm-embed-hazo # dummy embedding model for demonstration purposes
```

```sql
tsellm prompts.sqlite3 "select embed(p, 'sentence-transformers/all-MiniLM-L12-v2')"
```

### Embedding `JSON` Recursively

If you have `JSON` columns, you can embed these objects recursively.
That is, every string value in the object is replaced by its embedding vector (a list of floats), while numbers and other non-string values are left unchanged.

```bash
cat <<EOF | tee >(sqlite3 prompts.sqlite3) | duckdb prompts.duckdb
CREATE TABLE people(d JSON);
INSERT INTO people (d) VALUES
('{"name": "John Doe", "age": 30, "hobbies": ["reading", "biking"]}'),
('{"name": "Jane Smith", "age": 25, "hobbies": ["painting", "traveling"]}');
EOF
```

#### SQLite

```sql
tsellm prompts.sqlite3 "select json_embed(d, 'hazo') from people"
```

*Output*

```
('{"name": [4.0, 3.0, ..., 0.0], "age": 30, "hobbies": [[7.0, 0.0, ..., 0.0], [6.0, 0.0, ..., 0.0]]}',)
('{"name": [4.0, 5.0, ..., 0.0], "age": 25, "hobbies": [[8.0, 0.0, ..., 0.0], [9.0, 0.0, ..., 0.0]]}',)
```

#### DuckDB

```sql
tsellm prompts.duckdb "select json_embed(d, 'hazo') from people"
```

*Output*

```
('{"name": [4.0, 3.0, ..., 0.0], "age": 30, "hobbies": [[7.0, 0.0, ..., 0.0], [6.0, 0.0, ..., 0.0]]}',)
('{"name": [4.0, 5.0, ..., 0.0], "age": 25, "hobbies": [[8.0, 0.0, ..., 0.0], [9.0, 0.0, ..., 0.0]]}',)
```

### Embeddings for binary (`BLOB`) columns

```shell
Expand Down
64 changes: 63 additions & 1 deletion tests/test_tsellm.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
import unittest
from pathlib import Path
from test.support import captured_stdout, captured_stderr, captured_stdin
from test.support.os_helper import TESTFN, unlink

import duckdb
import llm.cli
Expand Down Expand Up @@ -175,6 +174,15 @@ def test_interact_valid_multiline_sql(self):

class InMemorySQLiteTest(TsellmConsoleTest):
path_args = None
alice_json = """{
\"name\": \"Alice\",
\"details\": {
\"age\": 30,
\"hobbies\": [\"reading\", \"cycling\"],
\"location\": \"Wonderland\"
},
\"greeting\": \"Hello, World!\"
}"""

def setUp(self):
super().setUp()
Expand Down Expand Up @@ -225,6 +233,33 @@ def test_embed_hazo_binary(self):
self.assertTrue(llm.get_embedding_model("hazo").supports_binary)
self.expect_success(*self.path_args, "select embed(randomblob(16), 'hazo')")

def test_embed_json_recursive(self):
    """Check that ``json_embed`` swaps each string value for its 'hazo' embedding.

    The first query only confirms that the ``alice_json`` fixture can be read
    with ``json_extract``; the second checks the full ``json_embed`` output,
    in which the numeric ``age`` value is expected to stay as-is.
    """
    # Sanity check: the fixture is readable as JSON.
    extract_sql = f"select json_extract('{self.alice_json}', '$.name')"
    extracted = self.expect_success(*self.path_args, extract_sql)
    self.assertEqual("('Alice',)\n", extracted)

    # Actual behavior under test: recursive embedding of string values.
    embed_sql = f"select json_embed('{self.alice_json}', 'hazo')"
    embedded = self.expect_success(*self.path_args, embed_sql)
    expected = (
        '(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        '0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
        '"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        '0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
    )
    self.assertEqual(expected, embedded)

def test_embed_default_hazo(self):
self.assertEqual(llm_cli.get_default_embedding_model(), "hazo")
out = self.expect_success(*self.path_args, "select embed('hello world')")
Expand Down Expand Up @@ -290,6 +325,33 @@ def test_embed_hazo_binary(self):
# See https://github.com/Florents-Tselai/tsellm/issues/25
pass

def test_embed_json_recursive(self):
    """Check that ``json_embed`` swaps each string value for its 'hazo' embedding.

    Same scenario as the SQLite test, but the fixture is cast with ``::json``
    and read with the ``->`` operator, whose result keeps the JSON quotes
    around the extracted name.
    """
    # Sanity check: the fixture is readable as JSON.
    extract_sql = f"select '{self.alice_json}'::json -> 'name'"
    extracted = self.expect_success(*self.path_args, extract_sql)
    self.assertEqual("('\"Alice\"',)\n", extracted)

    # Actual behavior under test: recursive embedding of string values.
    embed_sql = f"select json_embed('{self.alice_json}'::json, 'hazo')"
    embedded = self.expect_success(*self.path_args, embed_sql)
    expected = (
        '(\'{"name": [5.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        '0.0, 0.0, 0.0, 0.0], "details": {"age": 30, "hobbies": [[7.0, 0.0, 0.0, 0.0, '
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0], [7.0, 0.0, 0.0, "
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]], "
        '"location": [10.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        '0.0, 0.0, 0.0, 0.0]}, "greeting": [6.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, '
        "0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0]}',)\n"
    )
    self.assertEqual(expected, embedded)


class DiskDuckDBTest(InMemoryDuckDBTest):
db_fp = None
Expand Down
2 changes: 1 addition & 1 deletion tsellm/__version__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
__title__ = "tsellm"
__description__ = "Use LLMs in SQLite and DuckDB"
__version__ = "0.1.0a10"
__version__ = "0.1.0a12"
5 changes: 4 additions & 1 deletion tsellm/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
_prompt_model,
_prompt_model_default,
_embed_model,
_json_embed_model,
_embed_model_default,
)

Expand Down Expand Up @@ -79,6 +80,7 @@ class TsellmConsole(InteractiveConsole, ABC):
("prompt", 1, _prompt_model_default, False),
("embed", 2, _embed_model, False),
("embed", 1, _embed_model_default, False),
("json_embed", 2, _json_embed_model, False),
]

error_class = None
Expand All @@ -87,7 +89,7 @@ class TsellmConsole(InteractiveConsole, ABC):

@staticmethod
def create_console(
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
fp: Union[str, Path], in_memory_type: DatabaseType = DatabaseType.UNKNOWN
):
sniffer = DBSniffer(fp)
if sniffer.is_in_memory:
Expand Down Expand Up @@ -274,6 +276,7 @@ def is_valid_db(self) -> bool:
_functions = [
("prompt", 2, _prompt_model, False),
("embed", 2, _embed_model, False),
("json_embed", 2, _json_embed_model, False),
]

def connect(self):
Expand Down
21 changes: 21 additions & 0 deletions tsellm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,21 @@
"""


def json_recurse_apply(json_obj, f):
    """Return a copy of a decoded JSON value with ``f`` applied to each string value.

    Args:
        json_obj: A decoded JSON value (dict, list, str, or any other scalar).
        f: Callable invoked with every string value found; its return value
            takes the string's place in the result.

    Returns:
        A new structure of the same shape. Dictionaries and lists are rebuilt
        recursively (dictionary keys are kept untouched), strings are replaced
        by ``f(string)``, and every other value is returned unchanged.
    """
    # Leaf case: only strings are transformed.
    if isinstance(json_obj, str):
        return f(json_obj)

    # Containers: rebuild them, descending into each child.
    if isinstance(json_obj, list):
        return [json_recurse_apply(element, f) for element in json_obj]

    if isinstance(json_obj, dict):
        transformed = {}
        for key, value in json_obj.items():
            transformed[key] = json_recurse_apply(value, f)
        return transformed

    # Numbers, booleans, None, ...: passed through untouched.
    return json_obj


def _prompt_model(prompt: str, model: str) -> str:
    """Run ``prompt`` through the ``llm`` model named ``model`` and return the response text."""
    selected_model = llm.get_model(model)
    response = selected_model.prompt(prompt)
    return response.text()

Expand All @@ -26,6 +41,12 @@ def _embed_model(text: str, model: str) -> str:
return json.dumps(llm.get_embedding_model(model).embed(text))


def _json_embed_model(js: str, model: str) -> str:
    """Embed every string value of a JSON document and return the result as JSON text.

    Args:
        js: JSON document as text. ``None`` (SQL ``NULL``) is accepted.
        model: Name of the embedding model, passed through to ``_embed_model``.

    Returns:
        The JSON-serialized document in which each string value has been
        replaced by the embedding produced by ``_embed_model``; non-string
        values are kept as they are. Returns ``None`` when ``js`` is ``None``.

    Raises:
        json.JSONDecodeError: If ``js`` is not valid JSON.
    """
    if js is None:
        # A SQL NULL reaches the function as None; json.loads(None) would raise
        # TypeError, so propagate NULL instead of failing the whole query.
        return None

    def embed_string(value):
        # _embed_model returns JSON text; decode it so the vector is nested
        # as a real list rather than as an escaped string.
        return json.loads(_embed_model(value, model))

    return json.dumps(json_recurse_apply(json.loads(js), embed_string))


def _embed_model_default(text: str) -> str:
return json.dumps(
llm.get_embedding_model(llm_cli.get_default_embedding_model()).embed(text)
Expand Down
Loading