Skip to content

Commit

Permalink
Refactor LanceDB client code by adding schema_conversion and utils mo…
Browse files Browse the repository at this point in the history
…dules

Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Jun 4, 2024
1 parent 4827798 commit 53c0b0d
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 103 deletions.
133 changes: 30 additions & 103 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,17 +11,16 @@
Iterable,
Type,
Optional,
Sequence,
Dict,
)

import lancedb # type: ignore
import pyarrow as pa
from lancedb import DBConnection
from lancedb.common import DATA # type: ignore
from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore
from lancedb.pydantic import LanceModel, Vector # type: ignore
from lancedb.pydantic import LanceModel # type: ignore
from lancedb.query import LanceQueryBuilder # type: ignore
from lancedb.table import Table # type: ignore[import-untyped]
from numpy import ndarray
from pyarrow import Array, ChunkedArray
from pydantic import create_model
Expand All @@ -36,7 +35,7 @@
StateInfo,
TLoadJobState,
)
from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema
from dlt.common.schema import Schema, TTableSchema, TSchemaTables
from dlt.common.schema.typing import TColumnType
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages import FileStorage
Expand All @@ -47,13 +46,16 @@
TEmbeddingProvider,
)
from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT
from dlt.destinations.impl.lancedb.schema_conversion import (
TLanceModel,
create_template_schema,
make_fields,
)
from dlt.destinations.impl.lancedb.utils import list_unique_identifiers, generate_uuid
from dlt.destinations.job_impl import EmptyLoadJob
from dlt.destinations.type_mapping import TypeMapper


TLanceModel = Type[LanceModel]


class LanceDBTypeMapper(TypeMapper):
sct_to_unbound_dbt = {
"text": pa.string(),
Expand Down Expand Up @@ -166,15 +168,16 @@ def _make_qualified_table_name(self, table_name: str) -> str:
def get_table_schema(self, table_name: str) -> pa.Schema:
return cast(pa.Schema, self.db_client[table_name].schema)

def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None:
def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> Table:
"""Create a LanceDB Table from the provided LanceModel or PyArrow schema.
Args:
schema: The table schema to create.
table_name: The name of the table to create.
"""

self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func)
# TODO: Add embedding_functions configuration to empty table creation.
return self.db_client.create_table(table_name, schema=schema)

def delete_table(self, table_name: str) -> None:
"""Delete a LanceDB table.
Expand Down Expand Up @@ -377,9 +380,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
if len(load_records) > 0:
state["dlt_load_id"] = state.pop("_dlt_load_id")
return StateInfo(**state)
except Exception as e:
logger.warning(str(e))
return None
except Exception:
raise

def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
try:
Expand All @@ -392,9 +394,8 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo:
)
record = response.to_list()[0]
return StorageSchemaInfo(**record)
except Exception as e:
logger.warning(str(e))
return None
except Exception:
raise

def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
Expand Down Expand Up @@ -480,7 +481,7 @@ def __init__(
self.type_mapper = type_mapper
self.table_name = table_name
self.table_schema: TTableSchema = table_schema
self.unique_identifiers = self._list_unique_identifiers(table_schema)
self.unique_identifiers = list_unique_identifiers(table_schema)
self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
self.embedding_model_func = model_func
self.embedding_model_dimensions = client_config.embedding_model_dimensions
Expand All @@ -495,13 +496,13 @@ def __init__(

for record in records:
uuid_id = (
self._generate_uuid(record, self.unique_identifiers, self.table_name)
generate_uuid(record, self.unique_identifiers, self.table_name)
if self.unique_identifiers
else uuid.uuid4()
)
record.update({self.id_field_name: uuid_id})

template_model: TLanceModel = self._create_template_schema(
template_model: TLanceModel = create_template_schema(
self.id_field_name,
self.vector_field_name,
self.embedding_fields,
Expand All @@ -510,7 +511,15 @@ def __init__(
)

field_types: DictStrAny = {
k: v for d in self._make_fields(self.table_name) for k, v in d.items()
k: v
for d in make_fields(
self.table_name,
schema=self.schema,
type_mapper=self.type_mapper,
embedding_fields=self.embedding_fields,
embedding_model_func=self.embedding_model_func,
)
for k, v in d.items()
}

lance_model: TLanceModel = create_model(
Expand All @@ -520,9 +529,9 @@ def __init__(
**field_types,
)

self._upload_data(records, lance_model, table_name)
self.upload_data(records, lance_model, table_name)

def _upload_data(
def upload_data(
self, records: List[DictStrAny], lancedb_model: TLanceModel, table_name: str
) -> None:
"""Inserts records into a LanceDB table.
Expand All @@ -547,90 +556,8 @@ def _upload_data(
self.id_field_name
).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records)

@staticmethod
def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
"""Generates deterministic UUID - used for deduplication.
Args:
data (Dict[str, Any]): Arbitrary data to generate UUID for.
unique_identifiers (Sequence[str]): A list of unique identifiers.
table_name (str): LanceDB table name.
Returns:
str: A string representation of the generated UUID.
"""
data_id = "_".join(str(data[key]) for key in unique_identifiers)
return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id))

@staticmethod
def _list_unique_identifiers(table_schema: TTableSchema) -> Sequence[str]:
"""Returns a list of unique identifiers for a table.
Args:
table_schema (TTableSchema): a dlt table schema.
Returns:
Sequence[str]: A list of unique column identifiers.
"""
if table_schema.get("write_disposition") == "merge":
if primary_keys := get_columns_names_with_prop(table_schema, "primary_key"):
return primary_keys
return get_columns_names_with_prop(table_schema, "unique")

def state(self) -> TLoadJobState:
return "completed"

def exception(self) -> str:
raise NotImplementedError()

@staticmethod
def _create_template_schema(
id_field_name: str,
vector_field_name: str,
embedding_fields: List[str],
embedding_model_func: TextEmbeddingFunction,
embedding_model_dimensions: Optional[int],
) -> Type[LanceModel]:
# Only create vector Field if there is one or more embedding fields defined.
special_fields = {
id_field_name: (str, ...),
}
if embedding_fields:
special_fields[vector_field_name] = (
Vector(embedding_model_dimensions or embedding_model_func.ndims()),
...,
)
return cast(
TLanceModel,
create_model( # type: ignore[call-overload]
"TemplateSchema",
__base__=LanceModel,
__module__=__name__,
__validators__={},
**special_fields,
),
)

def _make_field_schema(self, column_name: str, column: TColumnSchema) -> DictStrAny:
return {
column_name: (
self.type_mapper.to_db_type(column),
(
self.embedding_model_func.SourceField()
if column_name in self.embedding_fields
else ...
),
)
}

def _make_fields(self, table_name: str) -> List[Dict[str, Any]]:
"""Creates a Pydantic properties schema from a table schema.
Args:
table_name: The table name for which columns should be converted to a pydantic model.
"""

return [
self._make_field_schema(column_name, column)
for column_name, column in self.schema.get_table_columns(table_name).items()
]
93 changes: 93 additions & 0 deletions dlt/destinations/impl/lancedb/schema_conversion.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,93 @@
"""Utilities for creating Pydantic model schemas from table schemas."""

from typing import (
List,
Any,
cast,
Type,
Optional,
Dict,
)

from lancedb.embeddings import TextEmbeddingFunction # type: ignore
from lancedb.pydantic import LanceModel, Vector # type: ignore
from pydantic import create_model

from dlt.common.schema import Schema, TColumnSchema
from dlt.common.typing import DictStrAny
from dlt.destinations.type_mapping import TypeMapper


TLanceModel = Type[LanceModel]


def make_field_schema(
column_name: str,
column: TColumnSchema,
type_mapper: TypeMapper,
embedding_model_func: TextEmbeddingFunction,
embedding_fields: List[str],
) -> DictStrAny:
return {
column_name: (
type_mapper.to_db_type(column),
(embedding_model_func.SourceField() if column_name in embedding_fields else ...),
)
}


def make_fields(
table_name: str,
schema: Schema,
type_mapper: TypeMapper,
embedding_model_func: TextEmbeddingFunction,
embedding_fields: List[str],
) -> List[Dict[str, Any]]:
"""Creates a Pydantic properties schema from a table schema.
Args:
embedding_fields (List[str]):
embedding_model_func (TextEmbeddingFunction):
type_mapper (TypeMapper):
schema (Schema): Schema to use.
table_name: The table name for which columns should be converted to a pydantic model.
"""

return [
make_field_schema(
column_name,
column,
type_mapper=type_mapper,
embedding_model_func=embedding_model_func,
embedding_fields=embedding_fields,
)
for column_name, column in schema.get_table_columns(table_name).items()
]


def create_template_schema(
id_field_name: str,
vector_field_name: str,
embedding_fields: List[str],
embedding_model_func: TextEmbeddingFunction,
embedding_model_dimensions: Optional[int],
) -> Type[LanceModel]:
# Only create vector Field if there is one or more embedding fields defined.
special_fields = {
id_field_name: (str, ...),
}
if embedding_fields:
special_fields[vector_field_name] = (
Vector(embedding_model_dimensions or embedding_model_func.ndims()),
...,
)
return cast(
TLanceModel,
create_model( # type: ignore[call-overload]
"TemplateSchema",
__base__=LanceModel,
__module__=__name__,
__validators__={},
**special_fields,
),
)
36 changes: 36 additions & 0 deletions dlt/destinations/impl/lancedb/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import uuid
from typing import Sequence

from dlt.common.schema import TTableSchema
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.typing import DictStrAny


def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str:
"""Generates deterministic UUID - used for deduplication.
Args:
data (Dict[str, Any]): Arbitrary data to generate UUID for.
unique_identifiers (Sequence[str]): A list of unique identifiers.
table_name (str): LanceDB table name.
Returns:
str: A string representation of the generated UUID.
"""
data_id = "_".join(str(data[key]) for key in unique_identifiers)
return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id))


def list_unique_identifiers(table_schema: TTableSchema) -> Sequence[str]:
"""Returns a list of unique identifiers for a table.
Args:
table_schema (TTableSchema): a dlt table schema.
Returns:
Sequence[str]: A list of unique column identifiers.
"""
if table_schema.get("write_disposition") == "merge":
if primary_keys := get_columns_names_with_prop(table_schema, "primary_key"):
return primary_keys
return get_columns_names_with_prop(table_schema, "unique")

0 comments on commit 53c0b0d

Please sign in to comment.