Commit

Add ollama to supported embedding providers and test orphaned record removal with embeddings

Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Aug 1, 2024
1 parent aac7647 commit e33b7cf
Showing 2 changed files with 91 additions and 5 deletions.
1 change: 1 addition & 0 deletions dlt/destinations/impl/lancedb/configuration.py
@@ -59,6 +59,7 @@ class LanceDBClientOptions(BaseConfiguration):
     "sentence-transformers",
     "huggingface",
     "colbert",
+    "ollama",
 ]


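The single addition above widens the list of accepted embedding provider names so that a LanceDB destination can hand embedding generation to a local Ollama server. The sketch below shows how that might be wired up; it is illustrative only, and the environment-variable configuration keys and the model name are assumptions rather than something this commit defines.

import os

import dlt
from dlt.destinations.impl.lancedb.lancedb_adapter import lancedb_adapter

# Assumed configuration keys for selecting the newly allowed provider and a local model;
# the exact option names live in the LanceDB destination configuration and may differ.
os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL_PROVIDER"] = "ollama"
os.environ["DESTINATION__LANCEDB__EMBEDDING_MODEL"] = "nomic-embed-text"  # hypothetical model name


@dlt.resource(table_name="document")
def documents():
    yield {"doc_id": 1, "chunk": "Some text to embed locally."}


# Mark the "chunk" column for embedding, mirroring the adapter usage in the test below.
lancedb_adapter(documents, embed=["chunk"])

pipeline = dlt.pipeline(destination="lancedb", dataset_name="ollama_embedding_demo")
pipeline.run(documents())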
95 changes: 90 additions & 5 deletions tests/load/lancedb/test_remove_orphaned_records.py
@@ -1,14 +1,21 @@
-from typing import Iterator, List, Generator
+from typing import Iterator, List, Generator, Any
 
+import numpy as np
 import pandas as pd
 import pytest
 from lancedb.table import Table  # type: ignore
 from pandas import DataFrame
 from pandas.testing import assert_frame_equal
+from pyarrow import Table
 
 import dlt
+from dlt.common.typing import DictStrAny
 from dlt.common.utils import uniq_id
-from dlt.destinations.impl.lancedb.lancedb_adapter import DOCUMENT_ID_HINT
+from dlt.destinations.impl.lancedb.lancedb_adapter import (
+    DOCUMENT_ID_HINT,
+    lancedb_adapter,
+)
+from tests.load.lancedb.utils import chunk_document
 from tests.load.utils import (
     drop_active_pipeline_data,
 )
@@ -119,10 +126,12 @@ def identity_resource(
             .reset_index(drop=True)
         )
 
-        expected_child_data = expected_child_data.sort_values(by="bar").reset_index(drop=True)
-        expected_grandchild_data = expected_grandchild_data.sort_values(by="baz").reset_index(
+        expected_child_data = expected_child_data.sort_values(by="bar").reset_index(
             drop=True
         )
+        expected_grandchild_data = expected_grandchild_data.sort_values(
+            by="baz"
+        ).reset_index(drop=True)
 
         assert_frame_equal(actual_child_df[["bar"]], expected_child_data)
         assert_frame_equal(actual_grandchild_df[["baz"]], expected_grandchild_data)
@@ -184,7 +193,83 @@ def identity_resource(
         tbl = client.db_client.open_table(root_table_name)  # type: ignore[attr-defined]
 
         actual_root_df: DataFrame = (
-            tbl.to_pandas().sort_values(by=["doc_id", "chunk_hash"]).reset_index(drop=True)
+            tbl.to_pandas()
+            .sort_values(by=["doc_id", "chunk_hash"])
+            .reset_index(drop=True)
         )[["doc_id", "chunk_hash"]]
 
         assert_frame_equal(actual_root_df, expected_root_table_df)
+
+
+def test_lancedb_root_table_remove_orphaned_records_with_real_embeddings() -> None:
+    @dlt.resource(  # type: ignore
+        write_disposition="merge",
+        table_name="document",
+        columns={"doc_id": {DOCUMENT_ID_HINT: True}},
+    )
+    def documents(docs: List[DictStrAny]) -> Generator[DictStrAny, None, None]:
+        for doc in docs:
+            doc_id = doc["doc_id"]
+            for chunk in chunk_document(doc["text"]):
+                yield {"doc_id": doc_id, "doc_text": doc["text"], "chunk": chunk}
+
+    @dlt.source()
+    def documents_source(
+        docs: List[DictStrAny],
+    ) -> Any:
+        return documents(docs)
+
+    lancedb_adapter(
+        documents,
+        embed=["chunk"],
+    )
+
+    pipeline = dlt.pipeline(
+        pipeline_name="test_lancedb_remove_orphaned_records_with_embeddings",
+        destination="lancedb",
+        dataset_name=f"test_lancedb_remove_orphaned_records_{uniq_id()}",
+        dev_mode=True,
+    )
+
+    initial_docs = [
+        {
+            "text": (
+                "This is the first document. It contains some text that will be chunked and"
+                " embedded. (I don't want to be seen in updated run's embedding chunk texts btw)"
+            ),
+            "doc_id": 1,
+        },
+        {
+            "text": "Here's another document. It's a bit different from the first one.",
+            "doc_id": 2,
+        },
+    ]
+
+    info = pipeline.run(documents_source(initial_docs))
+    assert_load_info(info)
+
+    updated_docs = [
+        {
+            "text": "This is the first document, but it has been updated with new content.",
+            "doc_id": 1,
+        },
+        {
+            "text": "This is a completely new document that wasn't in the initial set.",
+            "doc_id": 3,
+        },
+    ]
+
+    info = pipeline.run(documents_source(updated_docs))
+    assert_load_info(info)
+
+    with pipeline.destination_client() as client:
+        embeddings_table_name = client.make_qualified_table_name("document")  # type: ignore[attr-defined]
+        tbl: Table = client.db_client.open_table(embeddings_table_name)  # type: ignore[attr-defined]
+        df = tbl.to_pandas()
+
+        # Check that (non-empty) embeddings are present and that orphaned embeddings have been discarded.
+        assert len(df) == 21
+        assert "vector__" in df.columns
+        for _, vector in enumerate(df["vector__"]):
+            assert isinstance(vector, np.ndarray)
+            assert vector.size > 0
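The new test relies on chunk_document from tests/load/lancedb/utils, which is not shown in this diff, so the expected row count of 21 is a property of that helper's chunking. The intent the assertions capture is that, after the second run, the document table holds only chunks from the latest version of each document (doc_id 1's original chunks should be gone, doc_id 3's added, with doc_id 2's untouched rows still present), and every surviving row carries a non-empty vector__ embedding. A rough, hypothetical stand-in for such a chunker (not the actual helper) could look like this:

from typing import List


def chunk_document(text: str, chunk_size: int = 5) -> List[str]:
    # Hypothetical word-window chunker standing in for tests/load/lancedb/utils.chunk_document;
    # every chunk it yields becomes one embedded row in the LanceDB "document" table.
    words = text.split()
    return [" ".join(words[i : i + chunk_size]) for i in range(0, len(words), chunk_size)]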
