diff --git a/.gitignore b/.gitignore index 7700e23..8fed5dc 100644 --- a/.gitignore +++ b/.gitignore @@ -163,7 +163,7 @@ cython_debug/ #.idea/ # default dataset download location -data/ +datasets/ # database files *.sqlite diff --git a/Makefile b/Makefile index 3372a7b..9f82344 100644 --- a/Makefile +++ b/Makefile @@ -1,26 +1,3 @@ -# Define variables to extract the package name and version from pyproject.toml -PACKAGE_INFO = $(shell python scripts/get_project_info.py) -PACKAGE_NAME = $(shell echo $(PACKAGE_INFO) | awk '{print $$1}') -PACKAGE_VERSION = $(shell echo $(PACKAGE_INFO) | awk '{print $$2}') -WHEEL_FILE = dist/$(PACKAGE_NAME)-$(PACKAGE_VERSION)-py3-none-any.whl - -.PHONY: build uninstall install build_and_reinstall fmt - -# Target to build the package using poetry -build: - poetry build - -# Target to uninstall the package using pip -uninstall: - pip uninstall -y $(PACKAGE_NAME) - -# Target to install the package using pip -install: - pip install -q $(WHEEL_FILE) - -# Target to run all steps: build, uninstall, and install -build_and_reinstall: build uninstall install - # Sort imports and format python files fmt: isort --profile black . diff --git a/README.md b/README.md index 72a7236..d59c518 100644 --- a/README.md +++ b/README.md @@ -14,7 +14,7 @@ pip install ragulate ## Usage ```sh -usage: ragulate [-h] {download-llamadataset,ingest,query} ... +usage: ragulate [-h] {download,ingest,query,compare} ... RAGu-late CLI tool. @@ -22,9 +22,7 @@ options: -h, --help show this help message and exit commands: - {download-llamadataset,ingest,query} - download-llamadataset - Download a llama-dataset + download Download a dataset ingest Run an ingest pipeline query Run an query pipeline compare Compare results from 2 (or more) recipes @@ -33,7 +31,7 @@ commands: ### Download Dataset Example ``` -ragulate download-llamadataset BraintrustCodaHelpDesk +ragulate download -k llama BraintrustCodaHelpDesk ``` ### Ingest Example diff --git a/poetry.lock b/poetry.lock index 0f0eb28..dcf99c5 100644 --- a/poetry.lock +++ b/poetry.lock @@ -187,6 +187,25 @@ doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphin test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] trio = ["trio (>=0.23)"] +[[package]] +name = "astrapy" +version = "1.2.1" +description = "AstraPy is a Pythonic SDK for DataStax Astra and its Data API" +optional = false +python-versions = "<4.0.0,>=3.8.0" +files = [ + {file = "astrapy-1.2.1-py3-none-any.whl", hash = "sha256:0d7ca1e6f18a6a4e9a41ffaf2aa4cc585d36de3e983b5c5ce0bbb30a1595e30b"}, + {file = "astrapy-1.2.1.tar.gz", hash = "sha256:c4ba88ef16ac1e990ccba322d376b6ea256513a3004a0894c14bfa2403f1d646"}, +] + +[package.dependencies] +bson = ">=0.5.10,<0.6.0" +cassio = ">=0.1.4,<0.2.0" +deprecation = ">=2.1.0,<2.2.0" +httpx = {version = ">=0.25.2,<1", extras = ["http2"]} +toml = ">=0.10.2,<0.11.0" +uuid6 = ">=2024.1.12,<2024.2.0" + [[package]] name = "async-timeout" version = "4.0.3" @@ -295,6 +314,20 @@ files = [ {file = "blinker-1.8.2.tar.gz", hash = "sha256:8f77b09d3bf7c795e969e9486f39c2c5e9c39d4ee07424be2bc594ece9642d83"}, ] +[[package]] +name = "bson" +version = "0.5.10" +description = "BSON codec for Python" +optional = false +python-versions = "*" +files = [ + {file = "bson-0.5.10.tar.gz", hash = "sha256:d6511b2ab051139a9123c184de1a04227262173ad593429d21e443d6462d6590"}, +] + +[package.dependencies] +python-dateutil = ">=2.4.0" 
+six = ">=1.9.0" + [[package]] name = "cachetools" version = "5.3.3" @@ -306,6 +339,69 @@ files = [ {file = "cachetools-5.3.3.tar.gz", hash = "sha256:ba29e2dfa0b8b556606f097407ed1aa62080ee108ab0dc5ec9d6a723a007d105"}, ] +[[package]] +name = "cassandra-driver" +version = "3.29.1" +description = "DataStax Driver for Apache Cassandra" +optional = false +python-versions = "*" +files = [ + {file = "cassandra-driver-3.29.1.tar.gz", hash = "sha256:38e9c2a2f2a9664bb03f1f852d5fccaeff2163942b5db35dffcf8bf32a51cfe5"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a8f175c7616a63ca48cb8bd4acc443e2a3d889964d5157cead761f23cc8db7bd"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:7d66398952b9cd21c40edff56e22b6d3bce765edc94b207ddb5896e7bc9aa088"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5bbc6f575ef109ce5d4abfa2033bf36c394032abd83e32ab671159ce68e7e17b"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78f241af75696adb3e470209e2fbb498804c99e2b197d24d74774eee6784f283"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-win32.whl", hash = "sha256:54d9e651a742d6ca3d874ef8d06a40fa032d2dba97142da2d36f60c5675e39f8"}, + {file = "cassandra_driver-3.29.1-cp310-cp310-win_amd64.whl", hash = "sha256:630dc5423cd40eba0ee9db31065e2238098ff1a25a6b1bd36360f85738f26e4b"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:0b841d38c96bb878d31df393954863652d6d3a85f47bcc00fd1d70a5ea73023f"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:19cc7375f673e215bd4cbbefae2de9f07830be7dabef55284a2d2ff8d8691efe"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2b74b355be3dcafe652fffda8f14f385ccc1a8dae9df28e6080cc660da39b45f"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0e6dac7eddd3f4581859f180383574068a3f113907811b4dad755a8ace4c3fbd"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-win32.whl", hash = "sha256:293a79dba417112b56320ed0013d71fd7520f5fc4a5fd2ac8000c762c6dd5b07"}, + {file = "cassandra_driver-3.29.1-cp311-cp311-win_amd64.whl", hash = "sha256:7c2374fdf1099047a6c9c8329c79d71ad11e61d9cca7de92a0f49655da4bdd8a"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:4431a0c836f33a33c733c84997fbdb6398be005c4d18a8c8525c469fdc29393c"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:d23b08381b171a9e42ace483a82457edcddada9e8367e31677b97538cde2dc34"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4beb29a0139e63a10a5b9a3c7b72c30a4e6e20c9f0574f9d22c0d4144fe3d348"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1b206423cc454a78f16b411e7cb641dddc26168ac2e18f2c13665f5f3c89868c"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-win32.whl", hash = "sha256:ac898cca7303a3a2a3070513eee12ef0f1be1a0796935c5b8aa13dae8c0a7f7e"}, + {file = "cassandra_driver-3.29.1-cp312-cp312-win_amd64.whl", hash = "sha256:4ad0c9fb2229048ad6ff8c6ddbf1fdc78b111f2b061c66237c2257fcc4a31b14"}, + {file = "cassandra_driver-3.29.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:4282c5deac462e4bb0f6fd0553a33d514dbd5ee99d0812594210080330ddd1a2"}, + {file = 
"cassandra_driver-3.29.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:41ca7eea069754002418d3bdfbd3dfd150ea12cb9db474ab1a01fa4679a05bcb"}, + {file = "cassandra_driver-3.29.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6639ccb268c4dc754bc45e03551711780d0e02cb298ab26cde1f42b7bcc74f8"}, + {file = "cassandra_driver-3.29.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1a9d7d3b1be24a7f113b5404186ccccc977520401303a8fe78ba34134cad2482"}, + {file = "cassandra_driver-3.29.1-cp38-cp38-win32.whl", hash = "sha256:81c8fd556c6e1bb93577e69c1f10a3fadf7ddb93958d226ccbb72389396e9a92"}, + {file = "cassandra_driver-3.29.1-cp38-cp38-win_amd64.whl", hash = "sha256:cfe70ed0f27af949de2767ea9cef4092584e8748759374a55bf23c30746c7b23"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:a2c03c1d834ac1a0ae39f9af297a8cd38829003ce910b08b324fb3abe488ce2b"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:9a3e1e2b01f3b7a5cf75c97401bce830071d99c42464352087d7475e0161af93"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:90c42006665a4e490b0766b70f3d637f36a30accbef2da35d6d4081c0e0bafc3"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c1aca41f45772f9759e8246030907d92bc35fbbdc91525a3cb9b49939b80ad7"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-win32.whl", hash = "sha256:ce4a66245d4a0c8b07fdcb6398698c2c42eb71245fb49cff39435bb702ff7be6"}, + {file = "cassandra_driver-3.29.1-cp39-cp39-win_amd64.whl", hash = "sha256:4cae69ceb1b1d9383e988a1b790115253eacf7867ceb15ed2adb736e3ce981be"}, +] + +[package.dependencies] +geomet = ">=0.1,<0.3" + +[package.extras] +cle = ["cryptography (>=35.0)"] +graph = ["gremlinpython (==3.4.6)"] + +[[package]] +name = "cassio" +version = "0.1.8" +description = "A framework-agnostic Python library to seamlessly integrate Apache Cassandra(R) with ML/LLM/genAI workloads." 
+optional = false +python-versions = "<4.0,>=3.8" +files = [ + {file = "cassio-0.1.8-py3-none-any.whl", hash = "sha256:c09e7c884ba7227ff5277c86f3b0f31c523672ea407f56d093c7227e69c54d94"}, + {file = "cassio-0.1.8.tar.gz", hash = "sha256:4e09929506cb3dd6fad217e89846d0a1a59069afd24b82c72526ef6f2e9271af"}, +] + +[package.dependencies] +cassandra-driver = ">=3.28.0,<4.0.0" +numpy = ">=1.0" +requests = ">=2.31.0,<3.0.0" + [[package]] name = "certifi" version = "2024.2.2" @@ -551,6 +647,20 @@ wrapt = ">=1.10,<2" [package.extras] dev = ["PyTest", "PyTest-Cov", "bump2version (<1)", "sphinx (<2)", "tox"] +[[package]] +name = "deprecation" +version = "2.1.0" +description = "A library to handle automated deprecations" +optional = false +python-versions = "*" +files = [ + {file = "deprecation-2.1.0-py2.py3-none-any.whl", hash = "sha256:a10811591210e1fb0e768a8c25517cabeabcba6f0bf96564f8ff45189f90b14a"}, + {file = "deprecation-2.1.0.tar.gz", hash = "sha256:72b3bde64e5d778694b0cf68178aed03d15e15477116add3fb773e581f9518ff"}, +] + +[package.dependencies] +packaging = "*" + [[package]] name = "dill" version = "0.3.8" @@ -913,6 +1023,21 @@ test-downstream = ["aiobotocore (>=2.5.4,<3.0.0)", "dask-expr", "dask[dataframe, test-full = ["adlfs", "aiohttp (!=4.0.0a0,!=4.0.0a1)", "cloudpickle", "dask", "distributed", "dropbox", "dropboxdrivefs", "fastparquet", "fusepy", "gcsfs", "jinja2", "kerchunk", "libarchive-c", "lz4", "notebook", "numpy", "ocifs", "pandas", "panel", "paramiko", "pyarrow", "pyarrow (>=1)", "pyftpdlib", "pygit2", "pytest", "pytest-asyncio (!=0.22.0)", "pytest-benchmark", "pytest-cov", "pytest-mock", "pytest-recording", "pytest-rerunfailures", "python-snappy", "requests", "smbprotocol", "tqdm", "urllib3", "zarr", "zstandard"] tqdm = ["tqdm"] +[[package]] +name = "geomet" +version = "0.2.1.post1" +description = "GeoJSON <-> WKT/WKB conversion utilities" +optional = false +python-versions = ">2.6, !=3.3.*, <4" +files = [ + {file = "geomet-0.2.1.post1-py3-none-any.whl", hash = "sha256:a41a1e336b381416d6cbed7f1745c848e91defaa4d4c1bdc1312732e46ffad2b"}, + {file = "geomet-0.2.1.post1.tar.gz", hash = "sha256:91d754f7c298cbfcabd3befdb69c641c27fe75e808b27aa55028605761d17e95"}, +] + +[package.dependencies] +click = "*" +six = "*" + [[package]] name = "gitdb" version = "4.0.11" @@ -1027,6 +1152,32 @@ files = [ {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] +[[package]] +name = "h2" +version = "4.1.0" +description = "HTTP/2 State-Machine based protocol implementation" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "h2-4.1.0-py3-none-any.whl", hash = "sha256:03a46bcf682256c95b5fd9e9a99c1323584c3eec6440d379b9903d709476bc6d"}, + {file = "h2-4.1.0.tar.gz", hash = "sha256:a83aca08fbe7aacb79fec788c9c0bac936343560ed9ec18b82a13a12c28d2abb"}, +] + +[package.dependencies] +hpack = ">=4.0,<5" +hyperframe = ">=6.0,<7" + +[[package]] +name = "hpack" +version = "4.0.0" +description = "Pure-Python HPACK header compression" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hpack-4.0.0-py3-none-any.whl", hash = "sha256:84a076fad3dc9a9f8063ccb8041ef100867b1878b25ef0ee63847a5d53818a6c"}, + {file = "hpack-4.0.0.tar.gz", hash = "sha256:fc41de0c63e687ebffde81187a948221294896f6bdc0ae2312708df339430095"}, +] + [[package]] name = "htbuilder" version = "0.6.2" @@ -1076,6 +1227,7 @@ files = [ [package.dependencies] anyio = "*" certifi = "*" +h2 = {version = ">=3,<5", optional = true, markers = "extra == \"http2\""} httpcore = 
"==1.*" idna = "*" sniffio = "*" @@ -1100,6 +1252,17 @@ files = [ [package.extras] tests = ["freezegun", "pytest", "pytest-cov"] +[[package]] +name = "hyperframe" +version = "6.0.1" +description = "HTTP/2 framing layer for Python" +optional = false +python-versions = ">=3.6.1" +files = [ + {file = "hyperframe-6.0.1-py3-none-any.whl", hash = "sha256:0ec6bafd80d8ad2195c4f03aacba3a8265e57bc4cff261e802bf39970ed02a15"}, + {file = "hyperframe-6.0.1.tar.gz", hash = "sha256:ae510046231dc8e9ecb1a6586f63d2347bf4c8905914aa84ba585ae85f28a914"}, +] + [[package]] name = "idna" version = "3.7" @@ -1226,12 +1389,17 @@ referencing = ">=0.31.0" [[package]] name = "kaleido" -version = "0.2.1.post1" +version = "0.2.1" description = "Static image export for web-based visualization libraries with zero dependencies" optional = false python-versions = "*" files = [ - {file = "kaleido-0.2.1.post1-py2.py3-none-manylinux2014_armv7l.whl", hash = "sha256:d313940896c24447fc12c74f60d46ea826195fc991f58569a6e73864d53e5c20"}, + {file = "kaleido-0.2.1-py2.py3-none-macosx_10_11_x86_64.whl", hash = "sha256:ca6f73e7ff00aaebf2843f73f1d3bacde1930ef5041093fe76b83a15785049a7"}, + {file = "kaleido-0.2.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bb9a5d1f710357d5d432ee240ef6658a6d124c3e610935817b4b42da9c787c05"}, + {file = "kaleido-0.2.1-py2.py3-none-manylinux1_x86_64.whl", hash = "sha256:aa21cf1bf1c78f8fa50a9f7d45e1003c387bd3d6fe0a767cfbbf344b95bdc3a8"}, + {file = "kaleido-0.2.1-py2.py3-none-manylinux2014_aarch64.whl", hash = "sha256:845819844c8082c9469d9c17e42621fbf85c2b237ef8a86ec8a8527f98b6512a"}, + {file = "kaleido-0.2.1-py2.py3-none-win32.whl", hash = "sha256:ecc72635860be616c6b7161807a65c0dbd9b90c6437ac96965831e2e24066552"}, + {file = "kaleido-0.2.1-py2.py3-none-win_amd64.whl", hash = "sha256:4670985f28913c2d063c5734d125ecc28e40810141bdb0a46f15b76c1d45f23c"}, ] [[package]] @@ -1387,6 +1555,22 @@ openai = ["openai (<2)", "tiktoken (>=0.3.2,<0.6.0)"] qdrant = ["qdrant-client (>=1.3.1,<2.0.0)"] text-helpers = ["chardet (>=5.1.0,<6.0.0)"] +[[package]] +name = "langchain-astradb" +version = "0.3.3" +description = "An integration package connecting Astra DB and LangChain" +optional = false +python-versions = "<4.0,>=3.8.1" +files = [ + {file = "langchain_astradb-0.3.3-py3-none-any.whl", hash = "sha256:39deef1253947ef1bfaf3c27881ecdf07621d96c2cf37814aed9e506a9bee217"}, + {file = "langchain_astradb-0.3.3.tar.gz", hash = "sha256:f9a996ec4bef134896195430adeb7f264389c368a03d2ea91356837e8ddde091"}, +] + +[package.dependencies] +astrapy = ">=1.2,<2.0" +langchain-core = ">=0.1.31,<0.3" +numpy = ">=1,<2" + [[package]] name = "langchain-community" version = "0.0.38" @@ -3685,6 +3869,17 @@ h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] +[[package]] +name = "uuid6" +version = "2024.1.12" +description = "New time-based UUID formats which are suited for use as a database key" +optional = false +python-versions = ">=3.8" +files = [ + {file = "uuid6-2024.1.12-py3-none-any.whl", hash = "sha256:8150093c8d05a331bc0535bc5ef6cf57ac6eceb2404fd319bc10caee2e02c065"}, + {file = "uuid6-2024.1.12.tar.gz", hash = "sha256:ed0afb3a973057575f9883201baefe402787ca5e11e1d24e377190f0c43f1993"}, +] + [[package]] name = "validators" version = "0.28.3" @@ -3925,4 +4120,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.10,<3.13" -content-hash = "ede06679bbee0905cdc816b7c28b2c829af1a73ac846dbb7284be80ac506a1c7" +content-hash = 
"da0e35853bec0d8701394206b6d25986045ea38e195bfa5868028c1cc14038a2" diff --git a/pyproject.toml b/pyproject.toml index 8861b63..2e54d1a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,28 +5,31 @@ description = "A tool for evaluating RAG pipelines" authors = ["Eric Pinzur "] license = "Apache 2.0" readme = "README.md" -packages = [{include = "ragulate/"}] +packages = [{include = "ragulate/*"}] [tool.poetry.dependencies] python = ">=3.10,<3.13" epinzur-trulens-eval = ">=0.30.1b0" -kaleido = "^0.2.1" -llama-index-core = "^0.10.39.post1" -python-dotenv = "^1.0.1" -plotly = "^5.22.0" +kaleido = "0.2.1" inflection = "^0.5.1" -tqdm = "^4.66.4" +llama-index-core = "^0.10.31" +numpy = ">=1.23.5" +pandas = ">=2.2.2" +plotly = "^5.22.0" +python-dotenv = ">=1.0.0" +tqdm = ">=4.66.1" [tool.poetry.group.dev.dependencies] -langchain-core = "0.1.52" -langchain-community = "0.0.38" -langchain-openai = "0.1.3" black = "^24.4.2" isort = "^5.13.2" +langchain-astradb = "0.3.3" +langchain-community = "0.0.38" +langchain-core = "0.1.52" +langchain-openai = "0.1.3" [build-system] requires = ["poetry-core", "setuptools>=42", "wheel", "pip"] build-backend = "poetry.core.masonry.api" [tool.poetry.scripts] -ragulate = "ragulate.cli:main" \ No newline at end of file +ragulate = "ragulate.cli:main" diff --git a/ragulate/analysis_engine.py b/ragulate/analysis.py similarity index 62% rename from ragulate/analysis_engine.py rename to ragulate/analysis.py index cf051b0..f7c1b0c 100644 --- a/ragulate/analysis_engine.py +++ b/ragulate/analysis.py @@ -1,15 +1,15 @@ -from typing import List, Optional - -from .utils import get_tru - -from pandas import DataFrame -import pandas as pd +from typing import List import numpy as np +import pandas as pd import plotly.graph_objects as go +from pandas import DataFrame from plotly.io import write_image -class AnalysisEngine: +from .utils import get_tru + + +class Analysis: def get_all_data(self, recipes: List[str]) -> DataFrame: df_all = pd.DataFrame() @@ -24,8 +24,15 @@ def get_all_data(self, recipes: List[str]) -> DataFrame: df, metrics = tru.get_records_and_feedback([dataset]) all_metrics.extend(metrics) - columns_to_keep = metrics + ["record_id", "latency", "total_tokens", "total_cost"] - columns_to_drop = [col for col in df.columns if col not in columns_to_keep] + columns_to_keep = metrics + [ + "record_id", + "latency", + "total_tokens", + "total_cost", + ] + columns_to_drop = [ + col for col in df.columns if col not in columns_to_keep + ] df.drop(columns=columns_to_drop, inplace=True) df["recipe"] = recipe @@ -35,7 +42,6 @@ def get_all_data(self, recipes: List[str]) -> DataFrame: for metric in metrics: df.loc[df[metric] < 0, metric] = None - df_all = pd.concat([df_all, df], axis=0, ignore_index=True) tru.delete_singleton() @@ -45,12 +51,15 @@ def get_all_data(self, recipes: List[str]) -> DataFrame: return df_all, list(set(all_metrics)) def output_plots_by_dataset(self, df: DataFrame, metrics: List[str]): - recipes = sorted(df['recipe'].unique(), key=lambda x: x.lower()) - datasets = sorted(df['dataset'].unique(), key=lambda x: x.lower()) + recipes = sorted(df["recipe"].unique(), key=lambda x: x.lower()) + datasets = sorted(df["dataset"].unique(), key=lambda x: x.lower()) # generate an array of rainbow colors by fixing the saturation and lightness of the HSL # representation of color and marching around the hue. 
- c = ["hsl("+str(h)+",50%"+",50%)" for h in np.linspace(0, 360, len(recipes) + 1)] + c = [ + "hsl(" + str(h) + ",50%" + ",50%)" + for h in np.linspace(0, 360, len(recipes) + 1) + ] height = max((len(metrics) * len(recipes) * 20) + 150, 450) @@ -65,24 +74,32 @@ def output_plots_by_dataset(self, df: DataFrame, metrics: List[str]): x.extend(dx) y.extend([metric] * len(dx)) - fig.add_trace(go.Box( - y=y, - x=x, - name=recipe, - marker_color=c[test_index], - visible=True, - )) + fig.add_trace( + go.Box( + y=y, + x=x, + name=recipe, + marker_color=c[test_index], + visible=True, + ) + ) test_index += 1 - fig.update_traces(orientation="h", boxmean=True, jitter=1, ) + fig.update_traces( + orientation="h", + boxmean=True, + jitter=1, + ) fig.update_layout(boxmode="group", height=height, width=900) - fig.update_layout(legend=dict( - orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1)) + fig.update_layout( + legend=dict( + orientation="h", yanchor="bottom", y=1.02, xanchor="right", x=1 + ) + ) fig.update_layout(yaxis_title="metric", xaxis_title="score") write_image(fig, f"./{dataset}.png") - - def compare(self, recipes: List[str], datasets: Optional[List[str]] = []): + def compare(self, recipes: List[str]): df, metrics = self.get_all_data(recipes=recipes) self.output_plots_by_dataset(df=df, metrics=metrics) diff --git a/ragulate/cli.py b/ragulate/cli.py index 13fa09e..4c690bd 100644 --- a/ragulate/cli.py +++ b/ragulate/cli.py @@ -1,199 +1,27 @@ import argparse -from dotenv import load_dotenv -from .analysis_engine import AnalysisEngine -from .datasets import LLAMA_DATASETS_LFS_URL, download_llama_dataset -from .query_pipeline import QueryPipeline -from .ingest import ingest +from dotenv import load_dotenv +from . import cli_commands from .logging_config import logger -from typing import List, Optional - if load_dotenv(): logger.info("Parsed .env file successfully") - -def setup_download_llama_dataset(subparsers): - download_parser = subparsers.add_parser( - "download-llamadataset", help="Download a llama-dataset" - ) - download_parser.add_argument( - "dataset_name", - type=str, - help=( - "The name of the llama-dataset you want to download, " - "such as `PaulGrahamEssayDataset`." 
- ), - ) - download_parser.add_argument( - "-d", - "--download-dir", - type=str, - default="./data", - help="Custom dirpath to download the dataset into.", - ) - download_parser.add_argument( - "--llama-datasets-lfs-url", - type=str, - default=LLAMA_DATASETS_LFS_URL, - help="URL to llama datasets.", - ) - download_parser.set_defaults(func=lambda args: download_llama_dataset(**vars(args))) - - -def setup_ingest(subparsers): - ingest_parser = subparsers.add_parser("ingest", help="Run an ingest pipeline") - ingest_parser.add_argument( - "-n", - "--name", - type=str, - help="A unique name for the ingest pipeline", - required=True, - ) - ingest_parser.add_argument( - "-s", - "--script_path", - type=str, - help="The path to the python script that contains the ingest method", - required=True, - ) - ingest_parser.add_argument( - "-m", - "--method-name", - type=str, - help="The name of the method in the script to run ingest", - required=True, - ) - ingest_parser.add_argument( - "--var-name", - type=str, - help=( - "The name of a variable in the ingest script", - "This should be paired with a `--var-value` argument", - "and can be passed multiple times.", - ), - action="append", - ) - ingest_parser.add_argument( - "--var-value", - type=str, - help=( - "The value of a variable in the ingest script", - "This should be paired with a `--var-name` argument", - "and can be passed multiple times.", - ), - action="append", - ) - ingest_parser.add_argument( - "--dataset", - type=str, - help=("The name of a dataset to ingest", "This can be passed multiple times."), - action="append", - ) - ingest_parser.set_defaults(func=lambda args: ingest(**vars(args))) - - -def setup_query(subparsers): - query_parser = subparsers.add_parser("query", help="Run an query pipeline") - query_parser.add_argument( - "-n", - "--name", - type=str, - help="A unique name for the query pipeline", - required=True, - ) - query_parser.add_argument( - "-s", - "--script_path", - type=str, - help="The path to the python script that contains the query method", - required=True, - ) - query_parser.add_argument( - "-m", - "--method-name", - type=str, - help="The name of the method in the script to run query", - required=True, - ) - query_parser.add_argument( - "--var-name", - type=str, - help=( - "The name of a variable in the query script", - "This should be paired with a `--var-value` argument", - "and can be passed multiple times.", - ), - action="append", - ) - query_parser.add_argument( - "--var-value", - type=str, - help=( - "The value of a variable in the query script", - "This should be paired with a `--var-name` argument", - "and can be passed multiple times.", - ), - action="append", - ) - query_parser.add_argument( - "--dataset", - type=str, - help=("The name of a dataset to query", "This can be passed multiple times."), - action="append", - ) - query_parser.set_defaults(func=lambda args: query(**vars(args))) - - def query( - name:str, - script_path: str, - method_name: str, - var_name: List[str], - var_value: List[str], - dataset: List[str], - **kwargs, - ): - query_pipeline = QueryPipeline(name=name, datasets=dataset) - query_pipeline.query( - script_path=script_path, - method_name=method_name, - var_names=var_name, - var_values=var_value, - datasets=dataset - ) - -def setup_compare(subparsers): - compare_parser = subparsers.add_parser("compare", help="Compare results from 2 (or more) recipes") - compare_parser.add_argument( - "-r", - "--recipe", - type=str, - help="A recipe to compare. 
This can be passed multiple times.", - required=True, - action="append", - ) - compare_parser.set_defaults(func=lambda args: compare(**vars(args))) - -def compare( - recipe:List[str], - **kwargs, -): - analysis_engine = AnalysisEngine() - analysis_engine.compare(recipes = recipe) +else: + logger.info("Did not find .env file") def main() -> None: - parser = argparse.ArgumentParser(description="RAGu-late CLI tool.") # Subparsers for the main commands subparsers = parser.add_subparsers(title="commands", dest="command", required=True) - setup_download_llama_dataset(subparsers=subparsers) - setup_ingest(subparsers=subparsers) - setup_query(subparsers=subparsers) - setup_compare(subparsers=subparsers) + cli_commands.setup_download(subparsers=subparsers) + cli_commands.setup_ingest(subparsers=subparsers) + cli_commands.setup_query(subparsers=subparsers) + cli_commands.setup_compare(subparsers=subparsers) # Parse the command-line arguments args = parser.parse_args() diff --git a/ragulate/cli_commands/__init__.py b/ragulate/cli_commands/__init__.py new file mode 100644 index 0000000..0f11520 --- /dev/null +++ b/ragulate/cli_commands/__init__.py @@ -0,0 +1,11 @@ +from .compare import setup_compare +from .download import setup_download +from .ingest import setup_ingest +from .query import setup_query + +__all__ = [ + "setup_compare", + "setup_download", + "setup_ingest", + "setup_query", +] diff --git a/ragulate/cli_commands/compare.py b/ragulate/cli_commands/compare.py new file mode 100644 index 0000000..a79eded --- /dev/null +++ b/ragulate/cli_commands/compare.py @@ -0,0 +1,26 @@ +from typing import List + +from ..analysis import Analysis + + +def setup_compare(subparsers): + compare_parser = subparsers.add_parser( + "compare", help="Compare results from 2 (or more) recipes" + ) + compare_parser.add_argument( + "-r", + "--recipe", + type=str, + help="A recipe to compare. This can be passed multiple times.", + required=True, + action="append", + ) + compare_parser.set_defaults(func=lambda args: call_compare(**vars(args))) + + +def call_compare( + recipe: List[str], + **kwargs, +): + analysis = Analysis() + analysis.compare(recipes=recipe) diff --git a/ragulate/cli_commands/download.py b/ragulate/cli_commands/download.py new file mode 100644 index 0000000..044cbbe --- /dev/null +++ b/ragulate/cli_commands/download.py @@ -0,0 +1,28 @@ +from ragulate.datasets import LlamaDataset + + +def setup_download(subparsers): + download_parser = subparsers.add_parser("download", help="Download a dataset") + download_parser.add_argument( + "dataset_name", + type=str, + help=( + "The name of the dataset you want to download, " + "such as `PaulGrahamEssayDataset`." + ), + ) + download_parser.add_argument( + "-k", + "--kind", + type=str, + help="The kind of dataset to download. Currently only `llama` is supported", + required=True, + ) + download_parser.set_defaults(func=lambda args: call_download(**vars(args))) + + +def call_download(dataset_name: str, kind: str, **kwargs): + if not kind == "llama": + raise ("Currently only Llama Datasets are supported. 
Set param `-k llama`") + llama = LlamaDataset(dataset_name=dataset_name) + llama.download_dataset() diff --git a/ragulate/cli_commands/ingest.py b/ragulate/cli_commands/ingest.py new file mode 100644 index 0000000..3d5164d --- /dev/null +++ b/ragulate/cli_commands/ingest.py @@ -0,0 +1,77 @@ +from typing import List + +from ragulate.datasets import load_datasets +from ragulate.pipelines import IngestPipeline + + +def setup_ingest(subparsers): + ingest_parser = subparsers.add_parser("ingest", help="Run an ingest pipeline") + ingest_parser.add_argument( + "-n", + "--name", + type=str, + help="A unique name for the ingest pipeline", + required=True, + ) + ingest_parser.add_argument( + "-s", + "--script_path", + type=str, + help="The path to the python script that contains the ingest method", + required=True, + ) + ingest_parser.add_argument( + "-m", + "--method-name", + type=str, + help="The name of the method in the script to run ingest", + required=True, + ) + ingest_parser.add_argument( + "--var-name", + type=str, + help=( + "The name of a variable in the ingest script", + "This should be paired with a `--var-value` argument", + "and can be passed multiple times.", + ), + action="append", + ) + ingest_parser.add_argument( + "--var-value", + type=str, + help=( + "The value of a variable in the ingest script", + "This should be paired with a `--var-name` argument", + "and can be passed multiple times.", + ), + action="append", + ) + ingest_parser.add_argument( + "--dataset", + type=str, + help=("The name of a dataset to ingest", "This can be passed multiple times."), + action="append", + ) + ingest_parser.set_defaults(func=lambda args: call_ingest(**vars(args))) + + def call_ingest( + name: str, + script_path: str, + method_name: str, + var_name: List[str], + var_value: List[str], + dataset: List[str], + **kwargs, + ): + datasets = load_datasets(dataset_names=dataset) + + ingest_pipeline = IngestPipeline( + recipe_name=name, + script_path=script_path, + method_name=method_name, + var_names=var_name, + var_values=var_value, + datasets=datasets, + ) + ingest_pipeline.ingest() diff --git a/ragulate/cli_commands/query.py b/ragulate/cli_commands/query.py new file mode 100644 index 0000000..a2fdeac --- /dev/null +++ b/ragulate/cli_commands/query.py @@ -0,0 +1,77 @@ +from typing import List + +from ragulate.datasets import load_datasets +from ragulate.pipelines import QueryPipeline + + +def setup_query(subparsers): + query_parser = subparsers.add_parser("query", help="Run an query pipeline") + query_parser.add_argument( + "-n", + "--name", + type=str, + help="A unique name for the query pipeline", + required=True, + ) + query_parser.add_argument( + "-s", + "--script_path", + type=str, + help="The path to the python script that contains the query method", + required=True, + ) + query_parser.add_argument( + "-m", + "--method-name", + type=str, + help="The name of the method in the script to run query", + required=True, + ) + query_parser.add_argument( + "--var-name", + type=str, + help=( + "The name of a variable in the query script", + "This should be paired with a `--var-value` argument", + "and can be passed multiple times.", + ), + action="append", + ) + query_parser.add_argument( + "--var-value", + type=str, + help=( + "The value of a variable in the query script", + "This should be paired with a `--var-name` argument", + "and can be passed multiple times.", + ), + action="append", + ) + query_parser.add_argument( + "--dataset", + type=str, + help=("The name of a dataset to query", "This can be 
passed multiple times."), + action="append", + ) + query_parser.set_defaults(func=lambda args: call_query(**vars(args))) + + def call_query( + name: str, + script_path: str, + method_name: str, + var_name: List[str], + var_value: List[str], + dataset: List[str], + **kwargs, + ): + datasets = load_datasets(dataset_names=dataset) + + query_pipeline = QueryPipeline( + recipe_name=name, + script_path=script_path, + method_name=method_name, + var_names=var_name, + var_values=var_value, + datasets=datasets, + ) + query_pipeline.query() diff --git a/ragulate/datasets.py b/ragulate/datasets.py deleted file mode 100644 index 31dcb4d..0000000 --- a/ragulate/datasets.py +++ /dev/null @@ -1,75 +0,0 @@ -import json -import os -from pathlib import Path -from typing import Dict, List, Tuple - -import inflection -from llama_index.core.llama_dataset import download -from llama_index.core.llama_dataset.download import ( - LLAMA_DATASETS_LFS_URL, - LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL, -) - - -def get_llama_dataset_path(dataset_name: str, base_path: str) -> str: - folder = inflection.underscore(dataset_name) - folder = folder.removesuffix("_dataset") - return os.path.join(base_path, folder) - - -def download_llama_dataset( - dataset_name: str, - download_dir: str, - llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL, - llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL, - **kwargs, -) -> None: - - download_path = get_llama_dataset_path( - dataset_name=dataset_name, base_path=download_dir - ) - - if not dataset_name.endswith("Dataset"): - dataset_name = dataset_name + "Dataset" - - download.download_llama_dataset( - llama_dataset_class=dataset_name, - download_dir=download_path, - llama_datasets_lfs_url=llama_datasets_lfs_url, - llama_datasets_source_files_tree_url=llama_datasets_source_files_tree_url, - show_progress=True, - load_documents=False, - ) - - print(f"Successfully downloaded {dataset_name} to {download_dir}") - - -def get_source_file_paths(base_path: str, datasets: List[str]) -> List[str]: - file_paths = [] - - for dataset in datasets: - source_path = os.path.join( - get_llama_dataset_path(dataset_name=dataset, base_path=base_path), - "source_files", - ) - - file_paths.extend([f for f in Path(source_path).iterdir() if f.is_file()]) - - return file_paths - -def get_queries_and_golden_set(base_path: str, dataset: str) -> Tuple[List[str], List[Dict[str, str]]]: - json_path = os.path.join( - get_llama_dataset_path(dataset_name=dataset, base_path=base_path), - "rag_dataset.json", - ) - with open(json_path, "r") as f: - examples = json.load(f)["examples"] - queries = [e["query"] for e in examples] - golden_set = [ - { - "query": e["query"], - "response": e["reference_answer"], - } - for e in examples - ] - return queries, golden_set \ No newline at end of file diff --git a/ragulate/datasets/__init__.py b/ragulate/datasets/__init__.py new file mode 100644 index 0000000..1912467 --- /dev/null +++ b/ragulate/datasets/__init__.py @@ -0,0 +1,9 @@ +from .base_dataset import BaseDataset +from .llama_dataset import LlamaDataset +from .utils import load_datasets + +__all__ = [ + "BaseDataset", + "LlamaDataset", + "load_datasets", +] diff --git a/ragulate/datasets/base_dataset.py b/ragulate/datasets/base_dataset.py new file mode 100644 index 0000000..3c22dd6 --- /dev/null +++ b/ragulate/datasets/base_dataset.py @@ -0,0 +1,44 @@ +from abc import ABC, abstractmethod +from os import path +from pathlib import Path +from typing import Dict, List, Optional, Tuple + + +class 
BaseDataset(ABC): + + root_storage_path: str + name: str + + def __init__( + self, dataset_name: str, root_storage_path: Optional[str] = "datasets" + ): + self.name = dataset_name + self.root_storage_path = root_storage_path + + def storage_path(self) -> str: + """returns the path where dataset files should be stored""" + return path.join(self.root_storage_path, self.sub_storage_path()) + + def list_files_at_path(self, path: str) -> List[str]: + """lists all files at a path (excluding dot files)""" + return [ + f + for f in Path(path).iterdir() + if f.is_file() and not f.name.startswith(".") + ] + + @abstractmethod + def sub_storage_path(self) -> str: + """the sub-path to store the dataset in""" + + @abstractmethod + def download_dataset(self): + """downloads a dataset locally""" + + @abstractmethod + def get_source_file_paths(self) -> List[str]: + """gets a list of source file paths for for a dataset""" + + @abstractmethod + def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]: + """gets a list of queries and golden_truth answers for a dataset""" diff --git a/ragulate/datasets/llama_dataset.py b/ragulate/datasets/llama_dataset.py new file mode 100644 index 0000000..2b025fd --- /dev/null +++ b/ragulate/datasets/llama_dataset.py @@ -0,0 +1,78 @@ +import json +from os import path +from typing import Dict, List, Optional, Tuple + +import inflection +from llama_index.core.llama_dataset import download +from llama_index.core.llama_dataset.download import ( + LLAMA_DATASETS_LFS_URL, + LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL, +) + +from ..logging_config import logger +from .base_dataset import BaseDataset + + +class LlamaDataset(BaseDataset): + + _llama_datasets_lfs_url: str + _llama_datasets_source_files_tree_url: str + + def __init__( + self, dataset_name: str, root_storage_path: Optional[str] = "datasets" + ): + super().__init__(dataset_name=dataset_name, root_storage_path=root_storage_path) + self._llama_datasets_lfs_url: str = LLAMA_DATASETS_LFS_URL + self._llama_datasets_source_files_tree_url: str = LLAMA_DATASETS_SOURCE_FILES_GITHUB_TREE_URL + + def sub_storage_path(self) -> str: + return "llama" + + def _get_dataset_path(self) -> str: + folder = inflection.underscore(self.name) + folder = folder.removesuffix("_dataset") + return path.join(self.storage_path(), folder) + + def download_dataset(self) -> None: + """downloads a dataset locally""" + download_dir = self._get_dataset_path() + + # to conform with naming scheme at LlamaHub + llama_dataset_class = self.name + if not llama_dataset_class.endswith("Dataset"): + llama_dataset_class = llama_dataset_class + "Dataset" + + download.download_llama_dataset( + llama_dataset_class=llama_dataset_class, + download_dir=download_dir, + llama_datasets_lfs_url=self._llama_datasets_lfs_url, + llama_datasets_source_files_tree_url=self._llama_datasets_source_files_tree_url, + show_progress=True, + load_documents=False, + ) + + logger.info(f"Successfully downloaded {self.name} to {download_dir}") + + def get_source_file_paths(self) -> List[str]: + """gets a list of source file paths for for a dataset""" + source_path = path.join( + self._get_dataset_path(), "source_files" + ) + return self.list_files_at_path(path=source_path) + + def get_queries_and_golden_set(self) -> Tuple[List[str], List[Dict[str, str]]]: + """gets a list of queries and golden_truth answers for a dataset""" + json_path = path.join( + self._get_dataset_path(), "rag_dataset.json" + ) + with open(json_path, "r") as f: + examples = json.load(f)["examples"] + 
queries = [e["query"] for e in examples] + golden_set = [ + { + "query": e["query"], + "response": e["reference_answer"], + } + for e in examples + ] + return queries, golden_set diff --git a/ragulate/datasets/utils.py b/ragulate/datasets/utils.py new file mode 100644 index 0000000..24815d7 --- /dev/null +++ b/ragulate/datasets/utils.py @@ -0,0 +1,8 @@ +from typing import List + +from .base_dataset import BaseDataset +from .llama_dataset import LlamaDataset + + +def load_datasets(dataset_names: List[str]) -> List[BaseDataset]: + return [LlamaDataset(dataset_name=name) for name in dataset_names] diff --git a/ragulate/ingest.py b/ragulate/ingest.py deleted file mode 100644 index 7909def..0000000 --- a/ragulate/ingest.py +++ /dev/null @@ -1,33 +0,0 @@ -from .utils import load_module, convert_string -from .datasets import get_source_file_paths - -from typing import Any, Dict, List - -from tqdm import tqdm - -from .logging_config import logger - -def ingest( - name: str, - script_path: str, - method_name: str, - var_name: List[str], - var_value: List[str], - dataset: List[str], - **kwargs, -): - logger.info( - f"Starting ingest {name} on {script_path}/{method_name} with vars: {var_name} {var_value} on datasets: {dataset}" - ) - - ingest_module = load_module(script_path, "ingest_module") - ingest_method = getattr(ingest_module, method_name) - - params: Dict[str, Any] = {} - for i, name in enumerate(var_name): - params[name] = convert_string(var_value[i]) - - source_files = get_source_file_paths("data", datasets=dataset) - - for source_file in tqdm(source_files): - ingest_method(file_path=source_file, **params) \ No newline at end of file diff --git a/ragulate/pipelines/__init__.py b/ragulate/pipelines/__init__.py new file mode 100644 index 0000000..5edd959 --- /dev/null +++ b/ragulate/pipelines/__init__.py @@ -0,0 +1,9 @@ +from .base_pipeline import BasePipeline +from .ingest_pipeline import IngestPipeline +from .query_pipeline import QueryPipeline + +__all__ = [ + "BasePipeline", + "IngestPipeline", + "QueryPipeline", +] diff --git a/ragulate/pipelines/base_pipeline.py b/ragulate/pipelines/base_pipeline.py new file mode 100644 index 0000000..eaebff5 --- /dev/null +++ b/ragulate/pipelines/base_pipeline.py @@ -0,0 +1,63 @@ +import importlib.util +import re +from abc import ABC +from typing import Any, Dict, List + +from ragulate.datasets import BaseDataset + + +def convert_string(s): + s = s.strip() + if re.match(r"^\d+$", s): + return int(s) + elif re.match(r"^\d*\.\d+$", s): + return float(s) + else: + return s + + +# Function to dynamically load a module +def load_module(file_path, name): + spec = importlib.util.spec_from_file_location(name, file_path) + module = importlib.util.module_from_spec(spec) + spec.loader.exec_module(module) + return module + + +class BasePipeline(ABC): + recipe_name: str + script_path: str + method_name: str + var_names: List[str] + var_values: List[str] + datasets: List[BaseDataset] + + def __init__( + self, + recipe_name: str, + script_path: str, + method_name: str, + var_names: List[str], + var_values: List[str], + datasets: List[BaseDataset], + **kwargs, + ): + self.recipe_name = recipe_name + self.script_path = script_path + self.method_name = method_name + self.var_names = var_names + self.var_values = var_values + self.datasets = datasets + + def get_method(self, kind: str): + module = load_module(self.script_path, name=kind) + return getattr(module, self.method_name) + + def get_params(self) -> Dict[str, Any]: + params: Dict[str, Any] = {} + for i, name in 
enumerate(self.var_names): + params[name] = convert_string(self.var_values[i]) + return params + + def dataset_names(self) -> List[str]: + return [d.name for d in self.datasets] diff --git a/ragulate/metrics.py b/ragulate/pipelines/feedbacks.py similarity index 99% rename from ragulate/metrics.py rename to ragulate/pipelines/feedbacks.py index 66b7060..fa662cf 100644 --- a/ragulate/metrics.py +++ b/ragulate/pipelines/feedbacks.py @@ -8,7 +8,7 @@ from trulens_eval.utils.serial import Lens -class metrics: +class Feedbacks: _context: Lens _llm_provider: LLMProvider diff --git a/ragulate/pipelines/ingest_pipeline.py b/ragulate/pipelines/ingest_pipeline.py new file mode 100644 index 0000000..cbf5e1c --- /dev/null +++ b/ragulate/pipelines/ingest_pipeline.py @@ -0,0 +1,27 @@ +from typing import Any, Dict + +from tqdm import tqdm + +from ..logging_config import logger +from .base_pipeline import BasePipeline, convert_string, load_module + + +class IngestPipeline(BasePipeline): + def ingest(self): + logger.info( + f"Starting ingest {self.recipe_name} on {self.script_path}/{self.method_name} with vars: {self.var_names} {self.var_values} on datasets: {self.dataset_names()}" + ) + + ingest_module = load_module(self.script_path, "ingest_module") + ingest_method = getattr(ingest_module, self.method_name) + + params: Dict[str, Any] = {} + for i, name in enumerate(self.var_names): + params[name] = convert_string(self.var_values[i]) + + source_files = [] + for dataset in self.datasets: + source_files.extend(dataset.get_source_file_paths()) + + for source_file in tqdm(source_files): + ingest_method(file_path=source_file, **params) diff --git a/ragulate/query_pipeline.py b/ragulate/pipelines/query_pipeline.py similarity index 58% rename from ragulate/query_pipeline.py rename to ragulate/pipelines/query_pipeline.py index 73defd3..9fe65ea 100644 --- a/ragulate/query_pipeline.py +++ b/ragulate/pipelines/query_pipeline.py @@ -1,23 +1,21 @@ - -from typing import Any, Dict, List +import signal +import time +from typing import Dict, List from tqdm import tqdm - -from .datasets import get_queries_and_golden_set -from .metrics import metrics -from .logging_config import logger +from trulens_eval import Tru, TruChain from trulens_eval.feedback.provider import OpenAI -from trulens_eval import TruChain, Tru -from trulens_eval.schema.feedback import FeedbackResultStatus -import signal -import time +from trulens_eval.schema.feedback import FeedbackMode, FeedbackResultStatus -from .utils import load_module, convert_string, get_tru +from ragulate.datasets import BaseDataset -DEFERRED_FEEDBACK_MODE = "deferred" +from ..logging_config import logger +from ..utils import get_tru +from .base_pipeline import BasePipeline +from .feedbacks import Feedbacks -class QueryPipeline: +class QueryPipeline(BasePipeline): _sigint_received = False _tru: Tru @@ -31,17 +29,35 @@ class QueryPipeline: _finished_queries: int = 0 _evaluation_running = False - def __init__(self, name: str, datasets: List[str]): - self._name = name - self._tru = get_tru(recipe_name=name) + def __init__( + self, + recipe_name: str, + script_path: str, + method_name: str, + var_names: List[str], + var_values: List[str], + datasets: List[BaseDataset], + **kwargs, + ): + super().__init__( + recipe_name=recipe_name, + script_path=script_path, + method_name=method_name, + var_names=var_names, + var_values=var_values, + datasets=datasets, + ) + self._tru = get_tru(recipe_name=recipe_name) self._tru.reset_database() # Set up the signal handler for SIGINT (Ctrl-C) 
signal.signal(signal.SIGINT, self.signal_handler) for dataset in datasets: - self._queries[dataset], self._golden_sets[dataset] = get_queries_and_golden_set("data", dataset=dataset) - self._total_queries += len(self._queries[dataset]) + self._queries[dataset.name], self._golden_sets[dataset.name] = ( + dataset.get_queries_and_golden_set() + ) + self._total_queries += len(self._queries[dataset.name]) metric_count = 4 self._total_feedbacks = self._total_queries * metric_count @@ -50,23 +66,22 @@ def signal_handler(self, sig, frame): self._sigint_received = True self.stop_evaluation("sigint") - def start_evaluation(self): self._tru.start_evaluator(disable_tqdm=True) self._evaluation_running = True - def stop_evaluation(self, loc:str): + def stop_evaluation(self, loc: str): if self._evaluation_running: try: - print(f"Stopping evaluation from: {loc}") + logger.debug(f"Stopping evaluation from: {loc}") self._tru.stop_evaluator() self._evaluation_running = False except Exception as e: - print(f"issue stopping evaluator: {e}") + logger.error(f"issue stopping evaluator: {e}") finally: self._progress.close() - def update_progress(self, query_change:int = 0): + def update_progress(self, query_change: int = 0): self._finished_queries += query_change status = self._tru.db.get_feedback_count_by_status() @@ -88,54 +103,45 @@ def update_progress(self, query_change:int = 0): self._finished_feedbacks = done - def query( - self, - script_path: str, - method_name: str, - var_names: List[str], - var_values: List[str], - datasets: List[str], - **kwargs, - ): - - query_module = load_module(script_path, name="query_module") - query_method = getattr(query_module, method_name) - - params: Dict[str, Any] = {} - for i, name in enumerate(var_names): - params[name] = convert_string(var_values[i]) + def query(self): + query_method = self.get_method(kind="query") + params = self.get_params() pipeline = query_method(**params) llm_provider = OpenAI() - m = metrics(llm_provider=llm_provider, pipeline=pipeline) + feedbacks = Feedbacks(llm_provider=llm_provider, pipeline=pipeline) self.start_evaluation() time.sleep(0.1) - print( - f"Starting query {self._name} on {script_path}/{method_name} with vars: {var_names} {var_values} on datasets: {datasets}" + logger.info( + f"Starting query {self.recipe_name} on {self.script_path}/{self.method_name} with vars: {self.var_names} {self.var_values} on datasets: {self.dataset_names()}" + ) + logger.info( + "Progress postfix legend: (q)ueries completed; Evaluations (d)one, (r)unning, (w)aiting, (f)ailed, (s)kipped" ) - print("Progress postfix legend: (q)ueries completed; Evaluations (d)one, (r)unning, (w)aiting, (f)ailed, (s)kipped") self._progress = tqdm(total=(self._total_queries + self._total_feedbacks)) - for dataset in self._queries: + for dataset_name in self._queries: feedback_functions = [ - m.answer_correctness(golden_set=self._golden_sets[dataset]), - m.answer_relevance(), - m.context_relevance(), - m.groundedness(), + feedbacks.answer_correctness( + golden_set=self._golden_sets[dataset_name] + ), + feedbacks.answer_relevance(), + feedbacks.context_relevance(), + feedbacks.groundedness(), ] recorder = TruChain( pipeline, - app_id=dataset, + app_id=dataset_name, feedbacks=feedback_functions, - feedback_mode=DEFERRED_FEEDBACK_MODE, + feedback_mode=FeedbackMode.DEFERRED, ) - for query in self._queries[dataset]: + for query in self._queries[dataset_name]: if self._sigint_received: break try: diff --git a/ragulate/utils.py b/ragulate/utils.py index 9ba77bb..d6ed830 100644 --- 
a/ragulate/utils.py +++ b/ragulate/utils.py @@ -1,24 +1,9 @@ -import importlib.util -import re +from trulens_eval import Tru -from trulens_eval import TruChain, Tru - -def convert_string(s): - s = s.strip() - if re.match(r"^\d+$", s): - return int(s) - elif re.match(r"^\d*\.\d+$", s): - return float(s) - else: - return s - - -# Function to dynamically load a module -def load_module(file_path, name): - spec = importlib.util.spec_from_file_location(name, file_path) - module = importlib.util.module_from_spec(spec) - spec.loader.exec_module(module) - return module def get_tru(recipe_name: str) -> Tru: - return Tru(database_url=f"sqlite:///{recipe_name}.sqlite") #, name=name) \ No newline at end of file + Tru.RETRY_FAILED_SECONDS = 60 + Tru.RETRY_RUNNING_SECONDS = 30 + return Tru( + database_url=f"sqlite:///{recipe_name}.sqlite", database_redact_keys=True + ) # , name=name) diff --git a/scripts/get_project_info.py b/scripts/get_project_info.py deleted file mode 100644 index 91ee613..0000000 --- a/scripts/get_project_info.py +++ /dev/null @@ -1,14 +0,0 @@ -import toml - - -def get_package_info(): - with open("pyproject.toml", "r") as f: - pyproject = toml.load(f) - package_name = pyproject["tool"]["poetry"]["name"] - package_version = pyproject["tool"]["poetry"]["version"] - return package_name, package_version - - -if __name__ == "__main__": - name, version = get_package_info() - print(f"{name} {version}")
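
Note (editor's addition): the refactored `BasePipeline` / `IngestPipeline` / `QueryPipeline` in this patch load a user-supplied script (`--script_path`), look up the `--method-name` entry point, and call it with the `--var-name`/`--var-value` pairs. The ingest method is invoked once per dataset source file with `file_path=...`, and the query method must return a chain object that `TruChain` can wrap. Below is a minimal sketch of such a recipe script. The file name `experiment.py`, the Astra DB / OpenAI components, the environment variable names, and the `chunk_size` parameter are illustrative assumptions (the PR only adds these libraries as dev dependencies); they are not defined anywhere in this diff.

```python
# experiment.py (hypothetical) -- an example recipe script for ragulate.
#
# Assumed invocation (recipe name, dataset, and values are examples only):
#   ragulate ingest -n chunk_512 -s experiment.py -m ingest \
#       --var-name chunk_size --var-value 512 --dataset BraintrustCodaHelpDesk
#   ragulate query -n chunk_512 -s experiment.py -m query_pipeline \
#       --var-name chunk_size --var-value 512 --dataset BraintrustCodaHelpDesk
import os
from pathlib import Path

from langchain_astradb import AstraDBVectorStore
from langchain_core.output_parsers import StrOutputParser
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.runnables import RunnablePassthrough
from langchain_openai import ChatOpenAI, OpenAIEmbeddings


def _vector_store(chunk_size: int) -> AstraDBVectorStore:
    # One collection per chunk_size so different recipes don't overwrite each other.
    # The environment variable names are assumptions, not part of the diff.
    return AstraDBVectorStore(
        collection_name=f"ragulate_chunk_{chunk_size}",
        embedding=OpenAIEmbeddings(),
        api_endpoint=os.environ["ASTRA_DB_API_ENDPOINT"],
        token=os.environ["ASTRA_DB_APPLICATION_TOKEN"],
    )


def ingest(file_path: str, chunk_size: int = 512, **_) -> None:
    # Called once per dataset source file by IngestPipeline.ingest().
    # Naive fixed-width chunking keeps the sketch dependency-free; a real
    # script would likely use a proper text splitter instead.
    text = Path(file_path).read_text(errors="ignore")
    chunks = [text[i : i + chunk_size] for i in range(0, len(text), chunk_size)]
    _vector_store(chunk_size).add_texts(chunks)


def query_pipeline(chunk_size: int = 512, **_):
    # The returned runnable is wrapped by TruChain in QueryPipeline.query(),
    # which records answer correctness, answer relevance, context relevance,
    # and groundedness feedback for every dataset query.
    retriever = _vector_store(chunk_size).as_retriever()
    prompt = ChatPromptTemplate.from_template(
        "Answer the question using only this context:\n{context}\n\nQuestion: {question}"
    )
    return (
        {"context": retriever, "question": RunnablePassthrough()}
        | prompt
        | ChatOpenAI(model="gpt-3.5-turbo")
        | StrOutputParser()
    )
```

After two such recipes have been queried, `ragulate compare -r chunk_512 -r chunk_1024` (recipe names being whatever was passed to `-n`) reads each `<recipe_name>.sqlite` TruLens database created by `get_tru` and writes one `<dataset>.png` box plot per dataset via `Analysis.output_plots_by_dataset`.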