From 53c0b0d145ce2c5eac74eff58bae8e91599bcfbd Mon Sep 17 00:00:00 2001 From: Marcel Coetzee Date: Tue, 4 Jun 2024 23:05:48 +0200 Subject: [PATCH] Refactor LanceDB client code by adding schema_conversion and utils modules Signed-off-by: Marcel Coetzee --- .../impl/lancedb/lancedb_client.py | 133 ++++-------------- .../impl/lancedb/schema_conversion.py | 93 ++++++++++++ dlt/destinations/impl/lancedb/utils.py | 36 +++++ 3 files changed, 159 insertions(+), 103 deletions(-) create mode 100644 dlt/destinations/impl/lancedb/schema_conversion.py create mode 100644 dlt/destinations/impl/lancedb/utils.py diff --git a/dlt/destinations/impl/lancedb/lancedb_client.py b/dlt/destinations/impl/lancedb/lancedb_client.py index 12ced96b36..700b7be28a 100644 --- a/dlt/destinations/impl/lancedb/lancedb_client.py +++ b/dlt/destinations/impl/lancedb/lancedb_client.py @@ -11,8 +11,6 @@ Iterable, Type, Optional, - Sequence, - Dict, ) import lancedb # type: ignore @@ -20,8 +18,9 @@ from lancedb import DBConnection from lancedb.common import DATA # type: ignore from lancedb.embeddings import EmbeddingFunctionRegistry, TextEmbeddingFunction # type: ignore -from lancedb.pydantic import LanceModel, Vector # type: ignore +from lancedb.pydantic import LanceModel # type: ignore from lancedb.query import LanceQueryBuilder # type: ignore +from lancedb.table import Table # type: ignore[import-untyped] from numpy import ndarray from pyarrow import Array, ChunkedArray from pydantic import create_model @@ -36,7 +35,7 @@ StateInfo, TLoadJobState, ) -from dlt.common.schema import Schema, TTableSchema, TSchemaTables, TColumnSchema +from dlt.common.schema import Schema, TTableSchema, TSchemaTables from dlt.common.schema.typing import TColumnType from dlt.common.schema.utils import get_columns_names_with_prop from dlt.common.storages import FileStorage @@ -47,13 +46,16 @@ TEmbeddingProvider, ) from dlt.destinations.impl.lancedb.lancedb_adapter import VECTORIZE_HINT +from dlt.destinations.impl.lancedb.schema_conversion import ( + TLanceModel, + create_template_schema, + make_fields, +) +from dlt.destinations.impl.lancedb.utils import list_unique_identifiers, generate_uuid from dlt.destinations.job_impl import EmptyLoadJob from dlt.destinations.type_mapping import TypeMapper -TLanceModel = Type[LanceModel] - - class LanceDBTypeMapper(TypeMapper): sct_to_unbound_dbt = { "text": pa.string(), @@ -166,7 +168,7 @@ def _make_qualified_table_name(self, table_name: str) -> str: def get_table_schema(self, table_name: str) -> pa.Schema: return cast(pa.Schema, self.db_client[table_name].schema) - def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> None: + def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) -> Table: """Create a LanceDB Table from the provided LanceModel or PyArrow schema. Args: @@ -174,7 +176,8 @@ def _create_table(self, table_name: str, schema: Union[pa.Schema, LanceModel]) - table_name: The name of the table to create. """ - self.db_client.create_table(table_name, schema=schema, embedding_functions=self.model_func) + # TODO: Add embedding_functions configuration to empty table creation. + return self.db_client.create_table(table_name, schema=schema) def delete_table(self, table_name: str) -> None: """Delete a LanceDB table. @@ -377,9 +380,8 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]: if len(load_records) > 0: state["dlt_load_id"] = state.pop("_dlt_load_id") return StateInfo(**state) - except Exception as e: - logger.warning(str(e)) - return None + except Exception: + raise def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo: try: @@ -392,9 +394,8 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> StorageSchemaInfo: ) record = response.to_list()[0] return StorageSchemaInfo(**record) - except Exception as e: - logger.warning(str(e)) - return None + except Exception: + raise def get_stored_schema(self) -> Optional[StorageSchemaInfo]: """Retrieves newest schema from destination storage.""" @@ -480,7 +481,7 @@ def __init__( self.type_mapper = type_mapper self.table_name = table_name self.table_schema: TTableSchema = table_schema - self.unique_identifiers = self._list_unique_identifiers(table_schema) + self.unique_identifiers = list_unique_identifiers(table_schema) self.embedding_fields = get_columns_names_with_prop(table_schema, VECTORIZE_HINT) self.embedding_model_func = model_func self.embedding_model_dimensions = client_config.embedding_model_dimensions @@ -495,13 +496,13 @@ def __init__( for record in records: uuid_id = ( - self._generate_uuid(record, self.unique_identifiers, self.table_name) + generate_uuid(record, self.unique_identifiers, self.table_name) if self.unique_identifiers else uuid.uuid4() ) record.update({self.id_field_name: uuid_id}) - template_model: TLanceModel = self._create_template_schema( + template_model: TLanceModel = create_template_schema( self.id_field_name, self.vector_field_name, self.embedding_fields, @@ -510,7 +511,15 @@ def __init__( ) field_types: DictStrAny = { - k: v for d in self._make_fields(self.table_name) for k, v in d.items() + k: v + for d in make_fields( + self.table_name, + schema=self.schema, + type_mapper=self.type_mapper, + embedding_fields=self.embedding_fields, + embedding_model_func=self.embedding_model_func, + ) + for k, v in d.items() } lance_model: TLanceModel = create_model( @@ -520,9 +529,9 @@ def __init__( **field_types, ) - self._upload_data(records, lance_model, table_name) + self.upload_data(records, lance_model, table_name) - def _upload_data( + def upload_data( self, records: List[DictStrAny], lancedb_model: TLanceModel, table_name: str ) -> None: """Inserts records into a LanceDB table. @@ -547,90 +556,8 @@ def _upload_data( self.id_field_name ).when_matched_update_all().when_not_matched_insert_all().execute(parsed_records) - @staticmethod - def _generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: - """Generates deterministic UUID - used for deduplication. - - Args: - data (Dict[str, Any]): Arbitrary data to generate UUID for. - unique_identifiers (Sequence[str]): A list of unique identifiers. - table_name (str): LanceDB table name. - - Returns: - str: A string representation of the generated UUID. - """ - data_id = "_".join(str(data[key]) for key in unique_identifiers) - return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) - - @staticmethod - def _list_unique_identifiers(table_schema: TTableSchema) -> Sequence[str]: - """Returns a list of unique identifiers for a table. - - Args: - table_schema (TTableSchema): a dlt table schema. - - Returns: - Sequence[str]: A list of unique column identifiers. - """ - if table_schema.get("write_disposition") == "merge": - if primary_keys := get_columns_names_with_prop(table_schema, "primary_key"): - return primary_keys - return get_columns_names_with_prop(table_schema, "unique") - def state(self) -> TLoadJobState: return "completed" def exception(self) -> str: raise NotImplementedError() - - @staticmethod - def _create_template_schema( - id_field_name: str, - vector_field_name: str, - embedding_fields: List[str], - embedding_model_func: TextEmbeddingFunction, - embedding_model_dimensions: Optional[int], - ) -> Type[LanceModel]: - # Only create vector Field if there is one or more embedding fields defined. - special_fields = { - id_field_name: (str, ...), - } - if embedding_fields: - special_fields[vector_field_name] = ( - Vector(embedding_model_dimensions or embedding_model_func.ndims()), - ..., - ) - return cast( - TLanceModel, - create_model( # type: ignore[call-overload] - "TemplateSchema", - __base__=LanceModel, - __module__=__name__, - __validators__={}, - **special_fields, - ), - ) - - def _make_field_schema(self, column_name: str, column: TColumnSchema) -> DictStrAny: - return { - column_name: ( - self.type_mapper.to_db_type(column), - ( - self.embedding_model_func.SourceField() - if column_name in self.embedding_fields - else ... - ), - ) - } - - def _make_fields(self, table_name: str) -> List[Dict[str, Any]]: - """Creates a Pydantic properties schema from a table schema. - - Args: - table_name: The table name for which columns should be converted to a pydantic model. - """ - - return [ - self._make_field_schema(column_name, column) - for column_name, column in self.schema.get_table_columns(table_name).items() - ] diff --git a/dlt/destinations/impl/lancedb/schema_conversion.py b/dlt/destinations/impl/lancedb/schema_conversion.py new file mode 100644 index 0000000000..dc14afa44d --- /dev/null +++ b/dlt/destinations/impl/lancedb/schema_conversion.py @@ -0,0 +1,93 @@ +"""Utilities for creating Pydantic model schemas from table schemas.""" + +from typing import ( + List, + Any, + cast, + Type, + Optional, + Dict, +) + +from lancedb.embeddings import TextEmbeddingFunction # type: ignore +from lancedb.pydantic import LanceModel, Vector # type: ignore +from pydantic import create_model + +from dlt.common.schema import Schema, TColumnSchema +from dlt.common.typing import DictStrAny +from dlt.destinations.type_mapping import TypeMapper + + +TLanceModel = Type[LanceModel] + + +def make_field_schema( + column_name: str, + column: TColumnSchema, + type_mapper: TypeMapper, + embedding_model_func: TextEmbeddingFunction, + embedding_fields: List[str], +) -> DictStrAny: + return { + column_name: ( + type_mapper.to_db_type(column), + (embedding_model_func.SourceField() if column_name in embedding_fields else ...), + ) + } + + +def make_fields( + table_name: str, + schema: Schema, + type_mapper: TypeMapper, + embedding_model_func: TextEmbeddingFunction, + embedding_fields: List[str], +) -> List[Dict[str, Any]]: + """Creates a Pydantic properties schema from a table schema. + + Args: + embedding_fields (List[str]): + embedding_model_func (TextEmbeddingFunction): + type_mapper (TypeMapper): + schema (Schema): Schema to use. + table_name: The table name for which columns should be converted to a pydantic model. + """ + + return [ + make_field_schema( + column_name, + column, + type_mapper=type_mapper, + embedding_model_func=embedding_model_func, + embedding_fields=embedding_fields, + ) + for column_name, column in schema.get_table_columns(table_name).items() + ] + + +def create_template_schema( + id_field_name: str, + vector_field_name: str, + embedding_fields: List[str], + embedding_model_func: TextEmbeddingFunction, + embedding_model_dimensions: Optional[int], +) -> Type[LanceModel]: + # Only create vector Field if there is one or more embedding fields defined. + special_fields = { + id_field_name: (str, ...), + } + if embedding_fields: + special_fields[vector_field_name] = ( + Vector(embedding_model_dimensions or embedding_model_func.ndims()), + ..., + ) + return cast( + TLanceModel, + create_model( # type: ignore[call-overload] + "TemplateSchema", + __base__=LanceModel, + __module__=__name__, + __validators__={}, + **special_fields, + ), + ) diff --git a/dlt/destinations/impl/lancedb/utils.py b/dlt/destinations/impl/lancedb/utils.py new file mode 100644 index 0000000000..741f4ee746 --- /dev/null +++ b/dlt/destinations/impl/lancedb/utils.py @@ -0,0 +1,36 @@ +import uuid +from typing import Sequence + +from dlt.common.schema import TTableSchema +from dlt.common.schema.utils import get_columns_names_with_prop +from dlt.common.typing import DictStrAny + + +def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: + """Generates deterministic UUID - used for deduplication. + + Args: + data (Dict[str, Any]): Arbitrary data to generate UUID for. + unique_identifiers (Sequence[str]): A list of unique identifiers. + table_name (str): LanceDB table name. + + Returns: + str: A string representation of the generated UUID. + """ + data_id = "_".join(str(data[key]) for key in unique_identifiers) + return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) + + +def list_unique_identifiers(table_schema: TTableSchema) -> Sequence[str]: + """Returns a list of unique identifiers for a table. + + Args: + table_schema (TTableSchema): a dlt table schema. + + Returns: + Sequence[str]: A list of unique column identifiers. + """ + if table_schema.get("write_disposition") == "merge": + if primary_keys := get_columns_names_with_prop(table_schema, "primary_key"): + return primary_keys + return get_columns_names_with_prop(table_schema, "unique")