Cache: Refactoring, code formatting, and CI readiness
amotl committed Jan 2, 2025
1 parent: 29cb0ee · commit: 5650b61
Showing 11 changed files with 42 additions and 31 deletions.
8 changes: 4 additions & 4 deletions docs/cache.ipynb
@@ -75,7 +75,7 @@
"## Prerequisites\n",
"Because this notebook uses OpenAI's APIs, you need to supply an authentication\n",
"token. Either set the environment variable `OPENAI_API_KEY`, or optionally\n",
"configure your token here."
"configure your token here after enabling the code fragment."
]
},
{
@@ -92,9 +92,9 @@
"source": [
"import os\n",
"\n",
"_ = os.environ.setdefault(\n",
" \"OPENAI_API_KEY\", \"sk-XJZ7pfog5Gp8Kus8D--invalid--0CJ5lyAKSefZLaV1Y9S1\"\n",
")"
"# _ = os.environ.setdefault(\n",
"# \"OPENAI_API_KEY\", \"sk-XJZ7pfog5Gp8Kus8D--invalid--0CJ5lyAKSefZLaV1Y9S1\"\n",
"# )"
]
},
{
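With the hardcoded fallback now commented out, the notebook expects the key to arrive via the environment. A minimal sketch of that guard, using only the standard library (the error message is illustrative):

import os

# Fail fast if the key was not exported beforehand,
# e.g. via `export OPENAI_API_KEY=sk-...` in the shell.
if "OPENAI_API_KEY" not in os.environ:
    raise RuntimeError("Set OPENAI_API_KEY before running the cache notebook.")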
14 changes: 9 additions & 5 deletions examples/cache.py → examples/basic/cache.py
@@ -23,11 +23,11 @@
"""
Prerequisites: Because this program uses OpenAI's APIs, you need to supply an
authentication token. Either set the environment variable `OPENAI_API_KEY`,
or optionally configure your token here.
or optionally configure your token here after enabling the code fragment.
"""
_ = os.environ.setdefault(
"OPENAI_API_KEY", "sk-XJZ7pfog5Gp8Kus8D--invalid--0CJ5lyAKSefZLaV1Y9S1"
)
# _ = os.environ.setdefault(
# "OPENAI_API_KEY", "sk-XJZ7pfog5Gp8Kus8D--invalid--0CJ5lyAKSefZLaV1Y9S1"
# )


def standard_cache() -> None:
@@ -95,11 +95,15 @@ def semantic_cache() -> None:
set_llm_cache(None)


if __name__ == "__main__":
def main() -> None:
standard_cache()
semantic_cache()


if __name__ == "__main__":
main()


"""
What is the answer to everything?
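The bodies of standard_cache() and semantic_cache() sit outside this hunk. The following is only a rough sketch of the standard-cache wiring they are expected to perform, assuming the CrateDBCache API used by the tests in this commit; the connection URL is illustrative:

import sqlalchemy as sa
from langchain_core.globals import set_llm_cache
from langchain_cratedb import CrateDBCache

def standard_cache_sketch() -> None:
    # The crate:// SQLAlchemy dialect talks to CrateDB's HTTP endpoint, by default on port 4200.
    engine = sa.create_engine("crate://localhost:4200")
    set_llm_cache(CrateDBCache(engine=engine))
    # ... run LLM invocations here; repeated identical prompts are answered from the cache ...
    set_llm_cache(None)  # detach the cache again, mirroring the example's teardown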
2 changes: 1 addition & 1 deletion langchain_cratedb/cache.py
@@ -147,7 +147,7 @@ def update(self, prompt: str, llm_string: str, return_val: RETURN_VAL_TYPE) -> N
metadata = {
"llm_string": llm_string,
"prompt": prompt,
"return_val": dumps([g for g in return_val]),
"return_val": dumps(list(return_val)),
}
llm_cache.add_texts(texts=[prompt], metadatas=[metadata])

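Replacing the comprehension with list() is purely stylistic: both materialize the iterable into the same list, and the same reasoning covers the sorted(params.items()) simplifications in the test files further down. A quick check:

data = {"b": 2, "a": 1}
assert [item for item in data.items()] == list(data.items())
assert sorted([(k, v) for k, v in data.items()]) == sorted(data.items())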
6 changes: 2 additions & 4 deletions pyproject.toml
@@ -127,22 +127,20 @@ select = [

[tool.ruff.lint.per-file-ignores]
"docs/*.ipynb" = [
"ERA001", # Found commented-out code
"F401",
"F821",
"T201",
"ERA001", # Found commented-out code
]
"examples/*.py" = [
"ERA001", # Found commented-out code
"F401",
"F821",
"T20", # `print` found.
]
"tests/*" = ["S101"] # Use of `assert` detected
".github/scripts/*" = ["S101"] # Use of `assert` detected

[tool.ruff.lint.per-file-ignores]
"docs/*.ipynb" = ["F401", "F821", "T201"]

[tool.coverage.run]
omit = [
"langchain_cratedb/retrievers.py",
File renamed without changes.
File renamed without changes.
Changed file (filename not shown)
@@ -18,7 +18,7 @@
from langchain_core.outputs import ChatGeneration, Generation, LLMResult

from langchain_cratedb import CrateDBSemanticCache
from tests.integration_tests.cache.fake_embeddings import (
from tests.feature.cache.fake_embeddings import (
ConsistentFakeEmbeddings,
FakeEmbeddings,
)
@@ -50,7 +50,7 @@ def test_semantic_cache_single(engine: sa.Engine) -> None:
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
cache_output = get_llm_cache().lookup("bar", llm_string)
assert cache_output == [Generation(text="fizz")]
@@ -76,7 +76,7 @@ def test_semantic_cache_multi(engine: sa.Engine) -> None:
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
get_llm_cache().update(
"foo", llm_string, [Generation(text="fizz"), Generation(text="Buzz")]
)
@@ -106,7 +106,7 @@ def test_semantic_cache_chat(engine: sa.Engine) -> None:
llm = FakeChatModel()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
prompt: t.List[BaseMessage] = [HumanMessage(content="foo")]
llm_cache = t.cast(CrateDBSemanticCache, get_llm_cache())
llm_cache.update(
@@ -131,6 +131,7 @@ def test_semantic_cache_chat(engine: sa.Engine) -> None:
([random_string()], [[random_string(), random_string()]]),
# Single prompt, multiple generations
([random_string()], [[random_string(), random_string(), random_string()]]),
# ruff: noqa: ERA001
# Multiple prompts, multiple generations
# (
# [random_string(), random_string()],
@@ -165,7 +166,7 @@ def test_semantic_cache_hit(
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))

llm_generations = [
[
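These tests store a result under one prompt ("foo") and look it up under another ("bar"): the semantic cache matches on embedding similarity rather than exact prompt equality, and the consistent fake embeddings place the two prompts close enough together to hit. Stripped of fixtures, the round trip is the plain cache protocol; a condensed sketch, assuming a cache has already been installed globally:

from langchain_core.globals import get_llm_cache
from langchain_core.outputs import Generation

def cache_round_trip(stored_prompt: str, lookup_prompt: str, llm_string: str) -> None:
    cache = get_llm_cache()
    cache.update(stored_prompt, llm_string, [Generation(text="fizz")])
    # With CrateDBSemanticCache, lookup_prompt may differ from stored_prompt and still hit,
    # as long as its embedding lands near the stored one.
    assert cache.lookup(lookup_prompt, llm_string) == [Generation(text="fizz")]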
Changed file (filename not shown)
@@ -28,7 +28,7 @@ def test_memcached_cache(cache: BaseCache) -> None:

params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
@@ -48,7 +48,7 @@ def test_memcached_cache_flush(cache: BaseCache) -> None:

params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo"])
expected_output = LLMResult(
@@ -79,7 +79,7 @@ class FulltextLLMCache(Base): # type: ignore
__tablename__ = "llm_cache_fulltext"
# TODO: Original. Can it be converged by adding a polyfill to
# `sqlalchemy-cratedb`?
# id = Column(Integer, Sequence("cache_id"), primary_key=True)
# id = Column(Integer, Sequence("cache_id"), primary_key=True) # noqa: ERA001
id = sa.Column(sa.BigInteger, server_default=sa.func.now(), primary_key=True)
prompt = sa.Column(sa.String, nullable=False)
llm = sa.Column(sa.String, nullable=False)
@@ -92,7 +92,7 @@ class FulltextLLMCache(Base): # type: ignore
llm = FakeLLM()
params = llm.dict()
params["stop"] = None
llm_string = str(sorted([(k, v) for k, v in params.items()]))
llm_string = str(sorted(params.items()))
get_llm_cache().update("foo", llm_string, [Generation(text="fizz")])
output = llm.generate(["foo", "bar", "foo"])
expected_cache_output = [Generation(text="foo")]
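For context, these assertions rely on the standard cache short-circuiting the second identical generate call. A self-contained sketch of that behaviour, using the in-memory cache and the fake list model assumed to ship with langchain_core instead of the test suite's FakeLLM:

from langchain_core.caches import InMemoryCache
from langchain_core.globals import set_llm_cache
from langchain_core.language_models import FakeListLLM

set_llm_cache(InMemoryCache())
llm = FakeListLLM(responses=["fizz", "should-not-be-returned"])
first = llm.invoke("foo")   # computed by the fake model, then written to the cache
second = llm.invoke("foo")  # served from the cache; the second response is never consumed
assert first == second == "fizz"
set_llm_cache(None)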
Changed file (filename not shown)
@@ -33,10 +33,9 @@
def cache(request: FixtureRequest, engine: sa.Engine) -> BaseCache:
if request.param == "memory":
return InMemoryCache()
elif request.param == "cratedb":
if request.param == "cratedb":
return CrateDBCache(engine=engine)
else:
raise NotImplementedError(f"Cache type not implemented: {request.param}")
raise NotImplementedError(f"Cache type not implemented: {request.param}")


@pytest.fixture(autouse=True)
@@ -213,4 +212,4 @@ async def test_llm_cache_clear() -> None:
def create_llm_string(llm: Union[BaseLLM, BaseChatModel]) -> str:
_dict: Dict = llm.dict()
_dict["stop"] = None
return str(sorted([(k, v) for k, v in _dict.items()]))
return str(sorted(_dict.items()))
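The cache fixture dispatches on request.param, which implies the tests in this module parametrize it indirectly. A minimal sketch of that pytest pattern with an illustrative fixture and test body (the real fixture returns InMemoryCache or CrateDBCache as shown above):

import pytest

@pytest.fixture
def cache(request: pytest.FixtureRequest) -> str:
    if request.param == "memory":
        return "in-memory backend"
    if request.param == "cratedb":
        return "cratedb backend"
    raise NotImplementedError(f"Cache type not implemented: {request.param}")

@pytest.mark.parametrize("cache", ["memory", "cratedb"], indirect=True)
def test_cache_backend(cache: str) -> None:
    # Each parameter value is routed through the fixture above before reaching the test.
    assert cache.endswith("backend")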
14 changes: 12 additions & 2 deletions tests/test_docs.py
@@ -1,5 +1,6 @@
from pathlib import Path

import nbclient
import pytest
from _pytest.python import Metafunc

@@ -31,7 +32,16 @@ def pytest_generate_tests(metafunc: Metafunc) -> None:
def test_notebook(notebook: Path) -> None:
"""
Execute Jupyter Notebook, one test case per .ipynb file.
Skip test cases that trip when no OpenAI API key is configured.
"""
if notebook.name in SKIP_NOTEBOOKS:
raise pytest.skip(f"FIXME: Skipping notebook: {notebook.name}")
run_notebook(notebook)
raise pytest.skip(f"FIXME: Excluding notebook: {notebook.name}")
try:
run_notebook(notebook)
except nbclient.exceptions.CellExecutionError as ex:
if "The api_key client option must be set" not in str(ex):
raise
raise pytest.skip(
"Skipping test because `OPENAI_API_KEY` is not defined"
) from ex
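run_notebook() itself is defined elsewhere in the test suite. A plausible sketch, assuming it simply executes the notebook through nbclient (the timeout is illustrative):

from pathlib import Path

import nbclient
import nbformat

def run_notebook(notebook: Path) -> None:
    # Execute all cells in-process; a failing cell raises CellExecutionError,
    # which test_notebook() above inspects for the missing-API-key message.
    nb = nbformat.read(notebook, as_version=4)
    nbclient.NotebookClient(nb, timeout=120).execute()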
3 changes: 1 addition & 2 deletions tests/utils.py
@@ -49,8 +49,7 @@ def _call(
return self.queries[prompt]
if stop is None:
return "foo"
else:
return "bar"
return "bar"

@property
def _identifying_params(self) -> Dict[str, Any]:
