Skip to content

Commit

Permalink
PoC for ingestion with CLI
Browse files Browse the repository at this point in the history
  • Loading branch information
nenb committed Aug 16, 2024
1 parent bd2962c commit 198f35d
Showing 1 changed file with 160 additions and 0 deletions.
160 changes: 160 additions & 0 deletions ragna/deploy/_cli/core.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import json
import subprocess
import sys
import time
Expand All @@ -8,10 +9,13 @@
import rich
import typer
import uvicorn
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn

import ragna
from ragna._utils import timeout_after
from ragna.core._utils import default_user
from ragna.deploy._api import app as api_app
from ragna.deploy._api import database, orm
from ragna.deploy._ui import app as ui_app

from .config import ConfigOption, check_config, init_config
Expand All @@ -23,6 +27,8 @@
add_completion=False,
pretty_exceptions_enable=False,
)
corpus_app = typer.Typer()
app.add_typer(corpus_app, name="corpus")


def version_callback(value: bool) -> None:
Expand Down Expand Up @@ -171,3 +177,157 @@ def wait_for_api() -> None:
if process is not None:
process.kill()
process.communicate()


@corpus_app.command(help="Ingest some documents into a given corpus.")
def ingest(
    documents: list[Path],
    corpus_name: Optional[str] = typer.Option(
        None, help="Name of the corpus to ingest the documents into."
    ),
    config: ConfigOption = "./ragna.toml",  # type: ignore[assignment]
    user: Optional[str] = typer.Option(
        None, help="User to link the documents to in the ragna database."
    ),
    verbose: bool = typer.Option(
        False, help="Print the documents that could not be ingested."
    ),
    ignore_log: bool = typer.Option(
        False, help="Ignore the log file and re-ingest all documents."
    ),
) -> None:
    """Ingest *documents* into every configured source storage.

    Documents are processed in batches of 10. Each successfully stored
    document is appended to a JSONL log file (``.ragna_ingestion_log.jsonl``
    in the current working directory) so that subsequent runs can skip it.
    Failed batches are rolled back and NOT logged, so they are retried on
    the next invocation.

    Raises:
        typer.BadParameter: if the configured document class has no
            ``from_path`` method, or the database is unreachable.
    """
    try:
        document_factory = getattr(config.document, "from_path")
    except AttributeError:
        raise typer.BadParameter(
            f"{config.document.__name__} does not support creating documents from a path. "
            "Please implement a `from_path` method."
        )

    try:
        make_session = database.get_sessionmaker(config.api.database_url)
    except Exception:
        raise typer.BadParameter(
            f"Could not connect to the database: {config.api.database_url}"
        )

    if user is None:
        user = default_user()
    with make_session() as session:
        user_id = database._get_user_id(session, user)

    # Log (JSONL) recording which files were previously added to a vector
    # database. Each entry has keys 'user', 'corpus_name', 'source_storage'
    # and 'document'.
    ingestion_log_file = Path.cwd() / ".ragna_ingestion_log.jsonl"
    ingestion_log: dict[str, set[str]] = {}
    if not ignore_log and ingestion_log_file.exists():
        with open(ingestion_log_file, "r") as stream:
            for line in stream:
                entry = json.loads(line)
                if entry["corpus_name"] == corpus_name and entry["user"] == user:
                    ingestion_log.setdefault(entry["source_storage"], set()).add(
                        entry["document"]
                    )

    BATCH_SIZE = 10
    # Ceiling division so a trailing partial batch is counted in the
    # progress total (len//BATCH_SIZE would drop it).
    number_of_batches = -(-len(documents) // BATCH_SIZE)

    with Progress(
        TextColumn("[progress.description]{task.description}"),
        BarColumn(),
        "[progress.percentage]{task.percentage:>3.1f}%",
        TimeRemainingColumn(),
    ) as progress:
        overall_task = progress.add_task(
            "[cyan]Adding document embeddings to source storages...",
            total=len(config.source_storages),
        )

        bad_documents_collection = {}
        for source_storage in config.source_storages:
            source_storage_task = progress.add_task(
                f"[green]Adding document embeddings to {source_storage.__name__}...",
                total=number_of_batches,
            )
            already_ingested = ingestion_log.get(source_storage.__name__, set())

            documents_not_ingested = []
            for batch_start in range(0, len(documents), BATCH_SIZE):
                batch = documents[batch_start : batch_start + BATCH_SIZE]

                # Skip a batch only if every document in it was already
                # ingested for this user/corpus/source storage.
                if {str(document) for document in batch} <= already_ingested:
                    progress.advance(source_storage_task)
                    continue

                document_instances = []
                orm_documents = []
                stored_paths = []  # paths that made it into this batch's store
                for document in batch:
                    try:
                        document_instance = document_factory(document)
                    except Exception:
                        documents_not_ingested.append(document)
                        continue
                    document_instances.append(document_instance)
                    stored_paths.append(document)
                    orm_documents.append(
                        orm.Document(
                            id=document_instance.id,
                            user_id=user_id,
                            name=document_instance.name,
                            metadata_=document_instance.metadata,
                        )
                    )

                if not orm_documents:
                    # Nothing usable in this batch; still count it as processed.
                    progress.advance(source_storage_task)
                    continue

                session = make_session()
                try:
                    session.add_all(orm_documents)
                    source_storage().store(corpus_name, document_instances)
                except Exception:
                    # Undo the partial database insert. The batch is not
                    # logged, so it will be retried on the next run.
                    session.rollback()
                    documents_not_ingested.extend(batch)
                else:
                    # Commit only after the source storage accepted the batch.
                    session.commit()
                    if not ignore_log:
                        # Record only the documents that were actually stored,
                        # so failed ones are re-attempted on re-run.
                        with open(ingestion_log_file, "a") as stream:
                            for document in stored_paths:
                                stream.write(
                                    json.dumps(
                                        {
                                            "user": user,
                                            "corpus_name": corpus_name,
                                            "source_storage": source_storage.__name__,
                                            "document": str(document),
                                        }
                                    )
                                    + "\n"
                                )
                finally:
                    session.close()

                progress.advance(source_storage_task)

            bad_documents_collection[source_storage.__name__] = set(
                documents_not_ingested
            )

            progress.update(source_storage_task, completed=number_of_batches)
            progress.advance(overall_task)

        progress.update(overall_task, completed=len(config.source_storages))

    if verbose:
        typer.echo(
            f"Failed to embed the following documents: {bad_documents_collection}"
        )

0 comments on commit 198f35d

Please sign in to comment.