Refactor LanceDB client and tests, enhance DB type mapping
Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Jun 4, 2024
1 parent 4c73541 commit bfcc8bb
Showing 4 changed files with 168 additions and 131 deletions.
182 changes: 156 additions & 26 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -12,17 +12,19 @@
Type,
Optional,
Sequence,
Dict,
)

import lancedb # type: ignore
import pyarrow as pa
from lancedb import DBConnection
from lancedb.common import DATA # type: ignore
from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore
from lancedb.pydantic import LanceModel # type: ignore
from lancedb.pydantic import LanceModel, Vector # type: ignore
from lancedb.query import LanceQueryBuilder # type: ignore
from numpy import ndarray
from pyarrow import Array, ChunkedArray
from pydantic import create_model

from dlt.common import json, pendulum, logger
from dlt.common.destination import DestinationCapabilitiesContext
@@ -34,7 +36,8 @@
StateInfo,
TLoadJobState,
)
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema
from dlt.common.schema.typing import TColumnType
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages import FileStorage
from dlt.common.typing import DictStrAny
@@ -44,13 +47,43 @@
TEmbeddingProvider,
)
from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT
from dlt.destinations.impl.lancedb.utils import infer_lancedb_model_from_data
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.destinations.type_mapping import TypeMapper


TLanceModel = Type[LanceModel]


class LanceDBTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"text": pa.string(),
"double": pa.float64(),
"bool": pa.bool_(),
"timestamp": pa.timestamp("us", "UTC"),
"bigint": pa.int64(),
"binary": pa.binary(),
"decimal": pa.decimal128(38, 18),
"date": pa.date32(),
"time": pa.time64("us"),
"complex": pa.string(),
"wei": pa.float64(),
}

sct_to_dbt = {}

dbt_to_sct = {
pa.string(): "text",
pa.float64(): "double",
pa.bool_(): "bool",
pa.timestamp("us", "UTC"): "timestamp",
pa.int64(): "bigint",
pa.binary(): "binary",
pa.decimal128(38, 18): "decimal",
pa.date32(): "date",
pa.time64("us"): "time",
}

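For orientation, a minimal sketch of the round-trip these lookup tables enable. pyarrow DataType instances compare and hash by value, which is what lets dbt_to_sct use them as dictionary keys:

import pyarrow as pa

# Subset of the mapping above; values are pyarrow DataType instances.
sct_to_unbound_dbt = {"text": pa.string(), "bigint": pa.int64(), "timestamp": pa.timestamp("us", "UTC")}
dbt_to_sct = {v: k for k, v in sct_to_unbound_dbt.items()}

arrow_type = sct_to_unbound_dbt["timestamp"]  # timestamp[us, tz=UTC]
assert dbt_to_sct[arrow_type] == "timestamp"  # value-based hashing makes the reverse lookup work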

class NullSchema(LanceModel):
pass

@@ -97,6 +130,7 @@ def __init__(self, schema: Schema, config: LanceDBClientConfiguration) -> None:
uri=self.config.credentials.uri, api_key=self.config.credentials.api_key
)
self.registry = EmbeddingFunctionRegistry.get_instance()
self.type_mapper = LanceDBTypeMapper(self.capabilities)

# LanceDB doesn't provide a standardized way to set API keys across providers.
# Some providers use environment variables, while others accept the API key as an argument.
Expand Down Expand Up @@ -132,15 +166,19 @@ def _make_qualified_table_name(self, table_name: str) -> str:
def get_table_schema(self, table_name: str) -> pa.Schema:
return cast(pa.Schema, self.db_client[table_name].schema)

def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None:
def _create_table(
self, table_name: str, schema: Union[pa.Schema, LanceModel]
) -> None:
"""Create a LanceDB Table from the provided LanceModel or PyArrow schema.
Args:
table_name: The name of the table to create.
schema: The table schema to create.
"""

self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func)
self.db_client.create_table(
table_name, schema=schema, embedding_functions=self.model_func
)

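A hedged sketch of what _create_table ultimately calls; the URI, table name, and columns below are illustrative, and embedding_functions is only passed when a model function is configured:

import lancedb
import pyarrow as pa

db = lancedb.connect("./.lancedb")  # local-directory URI; illustrative
schema = pa.schema([pa.field("id__", pa.string()), pa.field("title", pa.string())])
db.create_table("my_table", schema=schema)  # add embedding_functions=... to vectorize on ingest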
def delete_table(self, table_name: str) -> None:
"""Delete a LanceDB table.
@@ -205,7 +243,9 @@ def add_to_table(
Returns:
None
"""
self.db_client.open_table(table_name).add(data, mode, on_bad_vectors, fill_value)
self.db_client.open_table(table_name).add(
data, mode, on_bad_vectors, fill_value
)

def drop_storage(self) -> None:
"""Drop the dataset from the LanceDB instance.
@@ -248,7 +288,9 @@ def is_storage_initialized(self) -> bool:

def _create_sentinel_table(self) -> None:
"""Create an empty table to indicate that the storage is initialized."""
self._create_table(schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table)
self._create_table(
schema=cast(LanceModel, NullSchema), table_name=self.sentinel_table
)

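The sentinel pattern reduces to a membership check against the connection's table list, the same check _table_exists below performs. A sketch assuming the db connection above; the sentinel name is hypothetical:

SENTINEL_TABLE = "dltSentinelTable"  # hypothetical; the real name comes from the client

def storage_initialized(db) -> bool:
    # The sentinel is an empty table whose mere existence marks the dataset as initialized.
    return SENTINEL_TABLE in db.table_names()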
def _delete_sentinel_table(self) -> None:
"""Delete the sentinel table."""
@@ -285,7 +327,9 @@ def _update_schema_in_storage(self, schema: Schema) -> None:
"inserted_at": str(pendulum.now()),
"schema": json.dumps(schema.to_dict()),
}
version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
version_table_name = self._make_qualified_table_name(
self.schema.version_table_name
)
self._create_record(properties, VersionSchema, version_table_name)

def _create_record(
@@ -301,7 +345,9 @@ def _create_record(
try:
tbl = self.db_client.open_table(self._make_qualified_table_name(table_name))
except FileNotFoundError:
tbl = self.db_client.create_table(self._make_qualified_table_name(table_name))
tbl = self.db_client.create_table(
self._make_qualified_table_name(table_name)
)
except Exception:
raise

@@ -321,7 +367,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
"""Loads compressed state from destination storage by finding a load ID that was completed."""
while True:
try:
state_table_name = self._make_qualified_table_name(self.schema.state_table_name)
state_table_name = self._make_qualified_table_name(
self.schema.state_table_name
)
state_records = (
self.db_client.open_table(state_table_name)
.search()
@@ -333,7 +381,9 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
return None
for state in state_records:
load_id = state["_dlt_load_id"]
loads_table_name = self._make_qualified_table_name(self.schema.loads_table_name)
loads_table_name = self._make_qualified_table_name(
self.schema.loads_table_name
)
load_records = (
self.db_client.open_table(loads_table_name)
.search()
@@ -365,7 +415,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
try:
version_table_name = self._make_qualified_table_name(self.schema.version_table_name)
version_table_name = self._make_qualified_table_name(
self.schema.version_table_name
)
response = (
self.db_client[version_table_name]
.search()
@@ -402,11 +454,14 @@ def complete_load(self, load_id: str) -> None:
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
def start_file_load(
self, table: TTableSchema, file_path: str, load_id: str
) -> LoadJob:
return LoadLanceDBJob(
self.schema,
table,
file_path,
type_mapper=self.type_mapper,
db_client=self.db_client,
client_config=self.config,
table_name=self._make_qualified_table_name(table["name"]),
@@ -416,6 +471,11 @@ def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) ->
def _table_exists(self, table_name: str) -> bool:
return table_name in self.db_client.table_names()

def _from_db_type(
self, wt_t: str, precision: Optional[int], scale: Optional[int]
) -> TColumnType:
return self.type_mapper.from_db_type(wt_t, precision, scale)


class LoadLanceDBJob(LoadJob):
embedding_fields: List[str]
@@ -426,19 +486,24 @@ def __init__(
schema: Schema,
table_schema: TTableSchema,
local_path: str,
type_mapper: LanceDBTypeMapper,
db_client: DBConnection,
client_config: LanceDBClientConfiguration,
table_name: str,
model_func: TextEmbeddingFunction,
) -> None:
file_name = FileStorage.get_file_name_from_file_path(local_path)
super().__init__(file_name)
self.schema = schema
self.config = client_config
self.db_client = db_client
self.type_mapper = type_mapper
self.table_name = table_name
self.table_schema: TTableSchema = table_schema
self.unique_identifiers = self._list_unique_identifiers(table_schema)
self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
self.embedding_fields = get_columns_names_with_prop(
table_schema, VECTORIZE_HINT
)
self.embedding_model_func = model_func
self.embedding_model_dimensions = client_config.embedding_model_dimensions

@@ -458,16 +523,26 @@ def __init__(
)
record.update({self.id_field_name: uuid_id})

# TODO: Use `table_schema` to infer LanceDB schema instead.
inferred_lancedb_model: Type[LanceModel] = infer_lancedb_model_from_data(
data=records,
id_field_name=self.id_field_name,
vector_field_name=self.vector_field_name,
embedding_fields=self.embedding_fields,
embedding_model_func=self.embedding_model_func,
embedding_model_dimensions=self.embedding_model_dimensions,
template_model: TLanceModel = self._create_template_schema(
self.id_field_name,
self.vector_field_name,
self.embedding_fields,
self.embedding_model_func,
self.embedding_model_dimensions,
)

field_types: DictStrAny = {
k: v for d in self._make_fields(self.table_name) for k, v in d.items()
}

lance_model: TLanceModel = create_model(
self.table_name,
__base__=template_model,
__module__=__name__,
**field_types,
)
self._upload_data(records, inferred_lancedb_model, table_name)

self._upload_data(records, lance_model, table_name)

def _upload_data(
self, records: List[DictStrAny], lancedb_model: TLanceModel, table_name: str
@@ -487,15 +562,21 @@ def _upload_data(
except Exception:
raise

parsed_records: List[LanceModel] = [lancedb_model(**record) for record in records]
parsed_records: List[LanceModel] = [
lancedb_model(**record) for record in records
]

# Upsert using reserved ID as the key.
tbl.merge_insert(
self.id_field_name
).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records)
).when_matched_update_all().when_not_matched_insert_all().execute(
parsed_records
)

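The upsert above uses LanceDB's merge-insert builder. A standalone sketch of the same pattern; the table, key column, and record are illustrative:

tbl = db.open_table("my_table")
(
    tbl.merge_insert("id__")            # match incoming rows to existing ones on the ID column
    .when_matched_update_all()          # matched rows: overwrite every column
    .when_not_matched_insert_all()      # unmatched rows: insert as new
    .execute([{"id__": "abc", "title": "hello"}])
)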
@staticmethod
def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
def _generate_uuid(
data: DictStrAny, unique_identifiers: Sequence[str], table_name: str
) -> str:
"""Generates deterministic UUID - used for deduplication.
Args:
Expand Down Expand Up @@ -529,3 +610,52 @@ def state(self) -> TLoadJobState:

def exception(self) -> str:
raise NotImplementedError()

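_generate_uuid above derives a stable ID from a record's unique identifiers so that reloading the same record updates rather than duplicates it. One way to get that behavior (a sketch under assumed conventions, not necessarily the exact implementation) is uuid5 over the concatenated key values plus the table name:

import uuid
from typing import Any, Dict, Sequence

def generate_uuid(data: Dict[str, Any], unique_identifiers: Sequence[str], table_name: str) -> str:
    # Identical inputs always yield the same UUID, which is what makes the merge_insert upsert idempotent.
    data_id = "_".join(str(data[key]) for key in unique_identifiers)  # separator is an assumption
    return str(uuid.uuid5(uuid.NAMESPACE_DNS, data_id + table_name))  # namespace choice is an assumption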
@staticmethod
def _create_template_schema(
id_field_name: str,
vector_field_name: str,
embedding_fields: List[str],
embedding_model_func: TextEmbeddingFunction,
embedding_model_dimensions: Optional[int],
) -> Type[LanceModel]:
# Only create the vector field if one or more embedding fields are defined.
special_fields = {
id_field_name: (str, ...),
}
if embedding_fields:
special_fields[vector_field_name] = (
Vector(embedding_model_dimensions or embedding_model_func.ndims()),
...,
)
return create_model(
"TemplateSchema",
__base__=LanceModel,
__module__=__name__,
__validators__={},
**special_fields,
)
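__init__ above layers the per-column field types on top of this template via a second create_model call. A sketch of that composition; the field names and vector dimension are illustrative:

from lancedb.pydantic import LanceModel, Vector
from pydantic import create_model

TemplateSchema = create_model(
    "TemplateSchema",
    __base__=LanceModel,
    id__=(str, ...),              # reserved ID field
    vector__=(Vector(384), ...),  # only added when embedding fields exist; 384 dims is illustrative
)
MyTable = create_model(  # final model: template fields plus the data columns
    "MyTable",
    __base__=TemplateSchema,
    title=(str, ...),
)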

def _make_field_schema(self, column_name: str, column: TColumnSchema) -> DictStrAny:
return {
column_name: (
self.type_mapper.to_db_type(column),
(
self.embedding_model_func.SourceField()
if column_name in self.embedding_fields
else ...
),
)
}

def _make_fields(self, table_name: str) -> List[Dict[str, Any]]:
"""Creates a Pydantic properties schema from a table schema.
Args:
table_name: The table whose columns should be converted to a Pydantic model.
"""

return [
self._make_field_schema(column_name, column)
for column_name, column in self.schema.get_table_columns(table_name).items()
]
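_make_fields returns one single-key dict per column; the comprehension in __init__ flattens them into keyword arguments for create_model. In isolation:

fields = [{"title": (str, ...)}, {"price": (float, ...)}]
field_types = {k: v for d in fields for k, v in d.items()}
# -> {"title": (str, Ellipsis), "price": (float, Ellipsis)}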