From 92bbc377f65e4d9288569164e6e95471f26ec2d2 Mon Sep 17 00:00:00 2001 From: Florents Tselai Date: Sat, 22 Jun 2024 23:54:21 +0300 Subject: [PATCH] basic tests --- LICENSE | 4 +- README.md | 20 +++++-- setup.py | 10 ++-- tests/conftest.py | 35 +++++++++++ tests/test_tsellm.py | 30 +++++++++- tsellm/__init__.py | 18 +++++- tsellm/__main__.py | 3 +- tsellm/cli.py | 139 ++++++++++++++++++++++++++++++++++++++----- 8 files changed, 225 insertions(+), 34 deletions(-) create mode 100644 tests/conftest.py diff --git a/LICENSE b/LICENSE index 3eeb18d..31bcd0b 100644 --- a/LICENSE +++ b/LICENSE @@ -1,8 +1,6 @@ - - BSD License -Copyright (c) 2024, Florents Tselai +Copyright (c) 2024, Florents Tselai All rights reserved. Redistribution and use in source and binary forms, with or without modification, diff --git a/README.md b/README.md index 2c4b4f1..f496301 100644 --- a/README.md +++ b/README.md @@ -6,10 +6,8 @@

tsellm

- LLM support in SQLite + SQLite with LLM Superpowers

- Status | - Why | How | Installation | Usage | @@ -25,12 +23,22 @@ -## Status - -## Why +**tsellm** is SQLite wrapper with LLM superpowers. +It's available as Python package and as a SQLite shell wrapper. ## How +**tsellm** relies on three facts: + +* SQLite is bundled with the standard Python library (`import sqlite3`) +* one can create Python-written user-defined functions to be used in SQLite + queries (see [create_function](https://github.com/simonw/llm)) +* [Simon Willison](https://github.com/simonw/) has gone through the process of + creating the beautiful [llm](https://github.com/simonw/llm) Python + library and CLI + +With that in mind, one can bring the whole `llm` library into SQLite. +**tsellm** attempts to do just that ## Installation diff --git a/setup.py b/setup.py index 1d19011..7bbd395 100644 --- a/setup.py +++ b/setup.py @@ -6,8 +6,8 @@ def get_long_description(): with open( - os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"), - encoding="utf8", + os.path.join(os.path.dirname(os.path.abspath(__file__)), "README.md"), + encoding="utf8", ) as fp: return fp.read() @@ -32,6 +32,8 @@ def get_long_description(): version=VERSION, packages=["tsellm"], install_requires=["click", "llm", "setuptools", "pip"], - extras_require={"test": ["pytest", "pytest-cov", "black", "ruff", "click"]}, - python_requires=">=3.7" + extras_require={ + "test": ["pytest", "pytest-cov", "black", "ruff", "click", "sqlite_utils"] + }, + python_requires=">=3.7", ) diff --git a/tests/conftest.py b/tests/conftest.py new file mode 100644 index 0000000..ffc8573 --- /dev/null +++ b/tests/conftest.py @@ -0,0 +1,35 @@ +from sqlite_utils import Database +from sqlite_utils.utils import sqlite3 +import pytest + + +def pytest_configure(config): + import sys + + sys._called_from_test = True + + +@pytest.fixture +def fresh_db(): + return Database(memory=True) + + +@pytest.fixture +def existing_db(db_path): + database = Database(db_path) + database.executescript( + """ + CREATE TABLE foo (text TEXT); + INSERT INTO foo (text) values ("one"); + INSERT INTO foo (text) values ("two"); + INSERT INTO foo (text) values ("three"); + """ + ) + return database + + +@pytest.fixture +def db_path(tmpdir): + path = str(tmpdir / "test.db") + db = sqlite3.connect(path) + return path diff --git a/tests/test_tsellm.py b/tests/test_tsellm.py index 73cb0a6..261b1ad 100644 --- a/tests/test_tsellm.py +++ b/tests/test_tsellm.py @@ -1,5 +1,29 @@ -from tsellm import example_function +from sqlite_utils import Database +from tsellm.cli import cli +import pytest +import datetime +from click.testing import CliRunner -def test_example_function(): - assert example_function() == 2 + +def test_cli(db_path): + db = Database(db_path) + assert [] == db.table_names() + table = db.create_table( + "prompts", + { + "prompt": str, + "generated": str, + "model": str, + "embedding": dict, + }, + ) + + assert ["prompts"] == db.table_names() + + table.insert({"prompt": "hello"}) + table.insert({"prompt": "world"}) + + assert db.execute( + "select prompt from prompts" + ).fetchall() == [("hello",), ("world",)] diff --git a/tsellm/__init__.py b/tsellm/__init__.py index 923334e..3da56b1 100644 --- a/tsellm/__init__.py +++ b/tsellm/__init__.py @@ -1,2 +1,16 @@ -def example_function(): - return 1 + 1 +def _prompt(p): + return p * 2 + + +TSELLM_CONFIG_SQL = """ +CREATE TABLE IF NOT EXISTS __tsellm ( +data +); + +""" + + +def _tsellm_init(con): + """Entry-point for tsellm initialization.""" + con.execute(TSELLM_CONFIG_SQL) + con.create_function("prompt", 1, _prompt) diff --git a/tsellm/__main__.py b/tsellm/__main__.py index 7e34ccd..189a4fd 100644 --- a/tsellm/__main__.py +++ b/tsellm/__main__.py @@ -1,4 +1,5 @@ +import sys from .cli import cli if __name__ == "__main__": - cli() \ No newline at end of file + cli(sys.argv[1:]) diff --git a/tsellm/cli.py b/tsellm/cli.py index c8a5d15..9d01710 100644 --- a/tsellm/cli.py +++ b/tsellm/cli.py @@ -1,15 +1,124 @@ -import click -@click.group() -@click.version_option() -def cli(): - """ CLI for tsellm """ - pass - -@cli.command() -@click.argument( - "name", - type=str, - required=True, -) -def hello(name): - print(f"Hello, {name}") \ No newline at end of file +import sqlite3 +import sys + +from argparse import ArgumentParser +from code import InteractiveConsole +from textwrap import dedent +from . import _prompt, _tsellm_init + + +def execute(c, sql, suppress_errors=True): + """Helper that wraps execution of SQL code. + + This is used both by the REPL and by direct execution from the CLI. + + 'c' may be a cursor or a connection. + 'sql' is the SQL string to execute. + """ + + try: + for row in c.execute(sql): + print(row) + except sqlite3.Error as e: + tp = type(e).__name__ + try: + print(f"{tp} ({e.sqlite_errorname}): {e}", file=sys.stderr) + except AttributeError: + print(f"{tp}: {e}", file=sys.stderr) + if not suppress_errors: + sys.exit(1) + + +class SqliteInteractiveConsole(InteractiveConsole): + """A simple SQLite REPL.""" + + def __init__(self, connection): + super().__init__() + self._con = connection + self._cur = connection.cursor() + + def runsource(self, source, filename="", symbol="single"): + """Override runsource, the core of the InteractiveConsole REPL. + + Return True if more input is needed; buffering is done automatically. + Return False is input is a complete statement ready for execution. + """ + match source: + case ".version": + print(f"{sqlite3.sqlite_version}") + case ".help": + print("Enter SQL code and press enter.") + case ".quit": + sys.exit(0) + case _: + if not sqlite3.complete_statement(source): + return True + execute(self._cur, source) + return False + + +def cli(*args): + print(args) + parser = ArgumentParser( + description="tsellm sqlite3 CLI", + prog="python -m tsellm", + ) + parser.add_argument( + "filename", type=str, default=":memory:", nargs="?", + help=( + "SQLite database to open (defaults to ':memory:'). " + "A new database is created if the file does not previously exist." + ), + ) + parser.add_argument( + "sql", type=str, nargs="?", + help=( + "An SQL query to execute. " + "Any returned rows are printed to stdout." + ), + ) + parser.add_argument( + "-v", "--version", action="version", + version=f"SQLite version {sqlite3.sqlite_version}", + help="Print underlying SQLite library version", + ) + args = parser.parse_args(*args) + + if args.filename == ":memory:": + db_name = "a transient in-memory database" + else: + db_name = repr(args.filename) + + # Prepare REPL banner and prompts. + if sys.platform == "win32" and "idlelib.run" not in sys.modules: + eofkey = "CTRL-Z" + else: + eofkey = "CTRL-D" + banner = dedent(f""" + tsellm shell, running on SQLite version {sqlite3.sqlite_version} + Connected to {db_name} + + Each command will be run using execute() on the cursor. + Type ".help" for more information; type ".quit" or {eofkey} to quit. + """).strip() + sys.ps1 = "tsellm> " + sys.ps2 = " ... " + + con = sqlite3.connect(args.filename, isolation_level=None) + _tsellm_init(con) + try: + if args.sql: + # SQL statement provided on the command-line; execute it directly. + execute(con, args.sql, suppress_errors=False) + else: + # No SQL provided; start the REPL. + console = SqliteInteractiveConsole(con) + try: + import readline + except ImportError: + pass + console.interact(banner, exitmsg="") + finally: + con.close() + + sys.exit(0)