fix: improve (re)insertion speed (#80)
lsorber authored Jan 6, 2025
1 parent 8052840 commit 620e556
Showing 2 changed files with 24 additions and 31 deletions.
12 changes: 6 additions & 6 deletions src/raglite/_database.py
@@ -50,8 +50,8 @@ class Document(SQLModel, table=True):
     metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
 
     # Add relationships so we can access document.chunks and document.evals.
-    chunks: list["Chunk"] = Relationship(back_populates="document")
-    evals: list["Eval"] = Relationship(back_populates="document")
+    chunks: list["Chunk"] = Relationship(back_populates="document", cascade_delete=True)
+    evals: list["Eval"] = Relationship(back_populates="document", cascade_delete=True)
 
     @staticmethod
     def from_path(doc_path: Path, **kwargs: Any) -> "Document":
@@ -76,15 +76,15 @@ class Chunk(SQLModel, table=True):
 
     # Table columns.
     id: ChunkId = Field(..., primary_key=True)
-    document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
+    document_id: DocumentId = Field(..., foreign_key="document.id", index=True, ondelete="CASCADE")
     index: int = Field(..., index=True)
     headings: str
     body: str
     metadata_: dict[str, Any] = Field(default_factory=dict, sa_column=Column("metadata", JSON))
 
     # Add relationships so we can access chunk.document and chunk.embeddings.
     document: Document = Relationship(back_populates="chunks")
-    embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk")
+    embeddings: list["ChunkEmbedding"] = Relationship(back_populates="chunk", cascade_delete=True)
 
     @staticmethod
     def from_body(
@@ -230,7 +230,7 @@ class ChunkEmbedding(SQLModel, table=True):
 
     # Table columns.
     id: int = Field(..., primary_key=True)
-    chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True)
+    chunk_id: ChunkId = Field(..., foreign_key="chunk.id", index=True, ondelete="CASCADE")
     embedding: FloatVector = Field(..., sa_column=Column(Embedding(dim=-1)))
 
     # Add relationship so we can access embedding.chunk.
@@ -285,7 +285,7 @@ class Eval(SQLModel, table=True):
 
     # Table columns.
     id: EvalId = Field(..., primary_key=True)
-    document_id: DocumentId = Field(..., foreign_key="document.id", index=True)
+    document_id: DocumentId = Field(..., foreign_key="document.id", index=True, ondelete="CASCADE")
     chunk_ids: list[ChunkId] = Field(default_factory=list, sa_column=Column(JSON))
     question: str
     contexts: list[str] = Field(default_factory=list, sa_column=Column(JSON))
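
Taken together, the cascade rules above mean that deleting a Document now removes its Chunks, ChunkEmbeddings, and Evals automatically: cascade_delete covers deletes issued through the ORM, while ondelete="CASCADE" on the foreign keys covers deletes issued directly in SQL. A minimal sketch of the resulting behavior (the import paths and the document id are illustrative assumptions, not part of this diff):

from sqlmodel import Session

from raglite import RAGLiteConfig
from raglite._database import Document, create_database_engine

engine = create_database_engine(RAGLiteConfig())
with Session(engine) as session:
    document = session.get(Document, "some-document-id")  # Hypothetical id.
    if document is not None:
        # The cascade rules remove document.chunks, their embeddings, and
        # document.evals along with the document itself.
        session.delete(document)
        session.commit()

Declaring the cascade at both the ORM and the database level keeps the two in agreement, whichever path performs the delete.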
43 changes: 18 additions & 25 deletions src/raglite/_insert.py
@@ -69,11 +69,16 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None:
     """Insert a document into the database and update the index."""
     # Use the default config if not provided.
     config = config or RAGLiteConfig()
-    db_backend = make_url(config.db_url).get_backend_name()
     # Preprocess the document into chunks and chunk embeddings.
-    with tqdm(total=5, unit="step", dynamic_ncols=True) as pbar:
+    with tqdm(total=6, unit="step", dynamic_ncols=True) as pbar:
         pbar.set_description("Initializing database")
         engine = create_database_engine(config)
+        document_record = Document.from_path(doc_path)
+        with Session(engine) as session:  # Exit early if the document is already in the database.
+            if session.get(Document, document_record.id) is not None:
+                pbar.update(6)
+                pbar.close()
+                return
         pbar.update(1)
         pbar.set_description("Converting to Markdown")
         doc = document_to_markdown(doc_path)
@@ -92,32 +97,20 @@ def insert_document(doc_path: Path, *, config: RAGLiteConfig | None = None) -> None:
             max_size=config.chunk_max_size,
         )
         pbar.update(1)
-    # Create and store the chunk records.
-    with Session(engine) as session:
-        # Add the document to the document table.
-        document_record = Document.from_path(doc_path)
-        if session.get(Document, document_record.id) is None:
+        pbar.set_description("Updating database")
+        with Session(engine) as session:
             session.add(document_record)
+            for chunk_record, chunk_embedding_record_list in zip(
+                *_create_chunk_records(document_record.id, chunks, chunk_embeddings, config),
+                strict=True,
+            ):
+                session.add(chunk_record)
+                session.add_all(chunk_embedding_record_list)
             session.commit()
-        # Create the chunk records to insert into the chunk table.
-        chunk_records, chunk_embedding_records = _create_chunk_records(
-            document_record.id, chunks, chunk_embeddings, config
-        )
-        # Store the chunk and chunk embedding records.
-        for chunk_record, chunk_embedding_record_list in tqdm(
-            zip(chunk_records, chunk_embedding_records, strict=True),
-            desc="Inserting chunks",
-            total=len(chunk_records),
-            unit="chunk",
-            dynamic_ncols=True,
-        ):
-            if session.get(Chunk, chunk_record.id) is not None:
-                continue
-            session.add(chunk_record)
-            session.add_all(chunk_embedding_record_list)
-            session.commit()
+        pbar.update(1)
+        pbar.close()
     # Manually update the vector search chunk index for SQLite.
-    if db_backend == "sqlite":
+    if make_url(config.db_url).get_backend_name() == "sqlite":
         from pynndescent import NNDescent
 
         with Session(engine) as session:
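
Why this makes (re)insertion faster: reinserting an already-known document is now a single primary-key lookup that exits before any Markdown conversion or embedding work, and new chunks are staged with add/add_all and committed once, where the old loop ran a session.get() existence check and a commit per chunk. A sketch of the pattern on a hypothetical Item table (illustration only, not raglite's API):

from sqlmodel import Field, Session, SQLModel, create_engine

class Item(SQLModel, table=True):  # Hypothetical table for illustration.
    id: int = Field(primary_key=True)
    body: str

engine = create_engine("sqlite:///:memory:")
SQLModel.metadata.create_all(engine)

def insert_batch(items: list[Item]) -> None:
    with Session(engine) as session:
        # One existence check for the whole batch, mirroring the early exit
        # above (the old code checked and committed once per row).
        if items and session.get(Item, items[0].id) is not None:
            return
        session.add_all(items)  # Stage every row in the session...
        session.commit()  # ...and write them all in a single transaction.

insert_batch([Item(id=i, body=f"chunk {i}") for i in range(1000)])

Committing once amortizes the per-transaction overhead (fsync and round trips) over the whole batch, which is where most of the speedup comes from.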
