# Reconstructed from patch 198f35d1cb8d7e58bf658cebcf250cddc1ba7cf9
# ("PoC for ingestion with CLI", Nick Byrne, 2024-08-16), which adds a
# `ragna corpus ingest` command to ragna/deploy/_cli/core.py.
#
# New imports required by this addition (merged into the file's existing
# import block in the original patch):
import json
from pathlib import Path
from typing import Optional

import typer
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

from ragna.core._utils import default_user
from ragna.deploy._api import database, orm

from .config import ConfigOption

# Sub-application so corpus-related commands live under `ragna corpus …`.
# NOTE(review): `app` is the existing top-level typer.Typer defined earlier
# in core.py (outside this patch).
corpus_app = typer.Typer()
app.add_typer(corpus_app, name="corpus")


@corpus_app.command(help="Ingest some documents into a given corpus.")
def ingest(
    documents: list[Path],
    corpus_name: Optional[str] = typer.Option(
        None, help="Name of the corpus to ingest the documents into."
    ),
    config: ConfigOption = "./ragna.toml",  # type: ignore[assignment]
    user: Optional[str] = typer.Option(
        None, help="User to link the documents to in the ragna database."
    ),
    verbose: bool = typer.Option(
        False, help="Print the documents that could not be ingested."
    ),
    ignore_log: bool = typer.Option(
        False, help="Ignore the log file and re-ingest all documents."
    ),
) -> None:
    """Ingest *documents* into every configured source storage.

    Documents are processed in batches of 10. Successfully ingested documents
    are appended to a JSONL log (``.ragna_ingestion_log.jsonl`` in the current
    working directory) so that subsequent runs can skip batches that were
    already fully ingested for the same user/corpus, unless ``--ignore-log``
    is given.

    Raises:
        typer.BadParameter: if the configured document class has no
            ``from_path`` constructor, or the database is unreachable.
    """
    try:
        # Plain attribute access (ruff B009); AttributeError is caught below.
        document_factory = config.document.from_path
    except AttributeError:
        raise typer.BadParameter(
            f"{config.document.__name__} does not support creating documents from a path. "
            "Please implement a `from_path` method."
        )

    try:
        make_session = database.get_sessionmaker(config.api.database_url)
    except Exception:
        raise typer.BadParameter(
            f"Could not connect to the database: {config.api.database_url}"
        )

    if user is None:
        user = default_user()
    with make_session() as session:
        user_id = database._get_user_id(session, user)

    # JSONL log recording which files were previously added to each source
    # storage. Each entry has keys 'user', 'corpus_name', 'source_storage'
    # and 'document'.
    ingestion_log_file = Path.cwd() / ".ragna_ingestion_log.jsonl"
    ingestion_log: dict[str, set[str]] = {}
    if not ignore_log and ingestion_log_file.exists():
        with open(ingestion_log_file) as stream:
            for line in stream:
                entry = json.loads(line)
                if entry["corpus_name"] == corpus_name and entry["user"] == user:
                    ingestion_log.setdefault(entry["source_storage"], set()).add(
                        entry["document"]
                    )

    BATCH_SIZE = 10
    # Round UP so a final partial batch is counted in the progress total
    # (floor division made the bar finish early for non-multiples of 10).
    number_of_batches = -(-len(documents) // BATCH_SIZE)

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%",
        TimeRemainingColumn(),
    ) as progress:
        overall_task = progress.add_task(
            "[cyan]Adding document embeddings to source storages...",
            total=len(config.source_storages),
        )

        bad_documents_collection: dict[str, set[Path]] = {}
        for source_storage in config.source_storages:
            source_storage_task = progress.add_task(
                f"[green]Adding document embeddings to {source_storage.__name__}...",
                total=number_of_batches,
            )

            already_ingested = ingestion_log.get(source_storage.__name__, set())
            documents_not_ingested: list[Path] = []
            for start in range(0, len(documents), BATCH_SIZE):
                batch = documents[start : start + BATCH_SIZE]

                # Skip a batch only when every document in it was previously
                # ingested for this user/corpus/source storage.
                if {str(document) for document in batch} <= already_ingested:
                    progress.advance(source_storage_task)
                    continue

                document_instances = []
                orm_documents = []
                ingested_paths: list[Path] = []
                for document in batch:
                    try:
                        doc_instance = document_factory(document)
                    except Exception:
                        documents_not_ingested.append(document)
                        continue
                    document_instances.append(doc_instance)
                    ingested_paths.append(document)
                    orm_documents.append(
                        orm.Document(
                            id=doc_instance.id,
                            user_id=user_id,
                            name=doc_instance.name,
                            metadata_=doc_instance.metadata,
                        )
                    )

                if orm_documents:
                    session = make_session()
                    try:
                        session.add_all(orm_documents)
                        source_storage().store(corpus_name, document_instances)
                        # Commit ONLY on success; the original committed in a
                        # `finally`, i.e. even after a rollback.
                        session.commit()
                    except Exception:
                        session.rollback()
                        # The whole batch failed to store: none of it was
                        # ingested, so none of it may be logged as such.
                        documents_not_ingested.extend(ingested_paths)
                        ingested_paths = []
                    finally:
                        session.close()

                # Record only documents that actually made it into storage;
                # the original logged every batch member, including failures,
                # which would make re-runs skip them silently.
                if not ignore_log and ingested_paths:
                    with open(ingestion_log_file, "a") as stream:
                        for document in ingested_paths:
                            stream.write(
                                json.dumps(
                                    {
                                        "user": user,
                                        "corpus_name": corpus_name,
                                        "source_storage": source_storage.__name__,
                                        "document": str(document),
                                    }
                                )
                                + "\n"
                            )

                # Advance unconditionally — the original skipped the tick
                # when a batch produced no valid documents.
                progress.advance(source_storage_task)

            bad_documents_collection[source_storage.__name__] = set(
                documents_not_ingested
            )
            progress.update(source_storage_task, completed=number_of_batches)
            progress.advance(overall_task)

        progress.update(overall_task, completed=len(config.source_storages))

    if verbose:
        typer.echo(
            f"Failed to embed the following documents: {bad_documents_collection}"
        )