-
Notifications
You must be signed in to change notification settings - Fork 190
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Refactor LanceDB client code by adding schema_conversion and utils mo…
…dules Signed-off-by: Marcel Coetzee <[email protected]>
- Loading branch information
Showing
3 changed files
with
159 additions
and
103 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,93 @@ | ||
"""Utilities for creating Pydantic model schemas from table schemas.""" | ||
|
||
from typing import ( | ||
List, | ||
Any, | ||
cast, | ||
Type, | ||
Optional, | ||
Dict, | ||
) | ||
|
||
from lancedb.embeddings import TextEmbeddingFunction # type: ignore | ||
from lancedb.pydantic import LanceModel, Vector # type: ignore | ||
from pydantic import create_model | ||
|
||
from dlt.common.schema import Schema, TColumnSchema | ||
from dlt.common.typing import DictStrAny | ||
from dlt.destinations.type_mapping import TypeMapper | ||
|
||
|
||
TLanceModel = Type[LanceModel] | ||
|
||
|
||
def make_field_schema( | ||
column_name: str, | ||
column: TColumnSchema, | ||
type_mapper: TypeMapper, | ||
embedding_model_func: TextEmbeddingFunction, | ||
embedding_fields: List[str], | ||
) -> DictStrAny: | ||
return { | ||
column_name: ( | ||
type_mapper.to_db_type(column), | ||
(embedding_model_func.SourceField() if column_name in embedding_fields else ...), | ||
) | ||
} | ||
|
||
|
||
def make_fields( | ||
table_name: str, | ||
schema: Schema, | ||
type_mapper: TypeMapper, | ||
embedding_model_func: TextEmbeddingFunction, | ||
embedding_fields: List[str], | ||
) -> List[Dict[str, Any]]: | ||
"""Creates a Pydantic properties schema from a table schema. | ||
Args: | ||
embedding_fields (List[str]): | ||
embedding_model_func (TextEmbeddingFunction): | ||
type_mapper (TypeMapper): | ||
schema (Schema): Schema to use. | ||
table_name: The table name for which columns should be converted to a pydantic model. | ||
""" | ||
|
||
return [ | ||
make_field_schema( | ||
column_name, | ||
column, | ||
type_mapper=type_mapper, | ||
embedding_model_func=embedding_model_func, | ||
embedding_fields=embedding_fields, | ||
) | ||
for column_name, column in schema.get_table_columns(table_name).items() | ||
] | ||
|
||
|
||
def create_template_schema( | ||
id_field_name: str, | ||
vector_field_name: str, | ||
embedding_fields: List[str], | ||
embedding_model_func: TextEmbeddingFunction, | ||
embedding_model_dimensions: Optional[int], | ||
) -> Type[LanceModel]: | ||
# Only create vector Field if there is one or more embedding fields defined. | ||
special_fields = { | ||
id_field_name: (str, ...), | ||
} | ||
if embedding_fields: | ||
special_fields[vector_field_name] = ( | ||
Vector(embedding_model_dimensions or embedding_model_func.ndims()), | ||
..., | ||
) | ||
return cast( | ||
TLanceModel, | ||
create_model( # type: ignore[call-overload] | ||
"TemplateSchema", | ||
__base__=LanceModel, | ||
__module__=__name__, | ||
__validators__={}, | ||
**special_fields, | ||
), | ||
) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,36 @@ | ||
import uuid | ||
from typing import Sequence | ||
|
||
from dlt.common.schema import TTableSchema | ||
from dlt.common.schema.utils import get_columns_names_with_prop | ||
from dlt.common.typing import DictStrAny | ||
|
||
|
||
def generate_uuid(data: DictStrAny, unique_identifiers: Sequence[str], table_name: str) -> str: | ||
"""Generates deterministic UUID - used for deduplication. | ||
Args: | ||
data (Dict[str, Any]): Arbitrary data to generate UUID for. | ||
unique_identifiers (Sequence[str]): A list of unique identifiers. | ||
table_name (str): LanceDB table name. | ||
Returns: | ||
str: A string representation of the generated UUID. | ||
""" | ||
data_id = "_".join(str(data[key]) for key in unique_identifiers) | ||
return str(uuid.uuid5(uuid.NAMESPACE_DNS, table_name + data_id)) | ||
|
||
|
||
def list_unique_identifiers(table_schema: TTableSchema) -> Sequence[str]: | ||
"""Returns a list of unique identifiers for a table. | ||
Args: | ||
table_schema (TTableSchema): a dlt table schema. | ||
Returns: | ||
Sequence[str]: A list of unique column identifiers. | ||
""" | ||
if table_schema.get("write_disposition") == "merge": | ||
if primary_keys := get_columns_names_with_prop(table_schema, "primary_key"): | ||
return primary_keys | ||
return get_columns_names_with_prop(table_schema, "unique") |