Commit 719fcfb

Refactor LanceDB client and schema for better table creation and management

Signed-off-by: Marcel Coetzee <[email protected]>
Pipboyguy committed Jun 12, 2024
1 parent 175b6db commit 719fcfb
Showing 2 changed files with 87 additions and 25 deletions.
95 changes: 74 additions & 21 deletions dlt/destinations/impl/lancedb/lancedb_client.py
@@ -64,6 +64,7 @@
     make_arrow_table_schema,
     TArrowSchema,
     NULL_SCHEMA,
+    TArrowField,
 )
 from dlt.destinations.impl.lancedb.utils import (
     list_unique_identifiers,
@@ -75,7 +76,9 @@


 TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
-UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}
+UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {
+    v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()
+}


 class LanceDBTypeMapper(TypeMapper):
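`TIMESTAMP_PRECISION_TO_UNIT` maps dlt's numeric timestamp precision to Arrow time units, and the reformatted `UNIT_TO_TIMESTAMP_PRECISION` is its exact inverse. A minimal sketch of the round trip, assuming only `pyarrow` (not part of this commit):

```python
import pyarrow as pa

TIMESTAMP_PRECISION_TO_UNIT = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}

# Map dlt precision 6 to an Arrow microsecond timestamp and back.
dtype = pa.timestamp(TIMESTAMP_PRECISION_TO_UNIT[6], tz="UTC")
assert dtype.unit == "us"
assert UNIT_TO_TIMESTAMP_PRECISION[dtype.unit] == 6
```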
@@ -191,7 +194,9 @@ def upload_batch(
         tbl.add(records, mode="replace")
     elif write_disposition == "merge":
         if not id_field_name:
-            raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
+            raise ValueError(
+                "To perform a merge update, 'id_field_name' must be specified."
+            )
         tbl.merge_insert(
             id_field_name
         ).when_matched_update_all().when_not_matched_insert_all().execute(records)
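For reference, the merge branch above uses LanceDB's `merge_insert` upsert API. A hedged sketch of the same pattern against a local database (table name, schema, and records are illustrative, not from the commit):

```python
import lancedb

db = lancedb.connect("/tmp/lancedb-demo")
tbl = db.create_table("items", [{"id": "1", "value": "a"}], mode="overwrite")

records = [{"id": "1", "value": "a-updated"}, {"id": "2", "value": "b"}]
(
    tbl.merge_insert("id")  # match rows on the unique id column
    .when_matched_update_all()  # update rows whose id already exists
    .when_not_matched_insert_all()  # insert rows with unseen ids
    .execute(records)
)
assert tbl.count_rows() == 2
```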
@@ -262,14 +267,20 @@ def get_table_schema(self, table_name: str) -> TArrowSchema:
             schema,
         )

-    def create_table(self, table_name: str, schema: TArrowSchema) -> Table:
+    @lancedb_error
+    def create_table(
+        self, table_name: str, schema: TArrowSchema, mode: str = "create"
+    ) -> Table:
         """Create a LanceDB Table from the provided LanceModel or PyArrow schema.

         Args:
             schema: The table schema to create.
             table_name: The name of the table to create.
+            mode (str): The mode to use when creating the table. Can be either
+                "create" or "overwrite". By default, if the table already exists,
+                an exception is raised. If you want to overwrite the table, use
+                mode="overwrite".
         """
-        return self.db_client.create_table(table_name, schema=schema)
+        return self.db_client.create_table(table_name, schema=schema, mode=mode)

     def delete_table(self, table_name: str) -> None:
         """Delete a LanceDB table.
@@ -360,7 +371,9 @@ def update_stored_schema(
         applied_update: TSchemaTables = {}

         try:
-            schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)
+            schema_info = self.get_stored_schema_by_hash(
+                self.schema.stored_version_hash
+            )
         except DestinationUndefinedEntity:
             schema_info = None

@@ -397,30 +410,52 @@ def get_storage_table(self, table_name: str) -> Tuple[bool, TTableSchemaColumns]:
                 table_schema[field_type] = schema_c
         return True, table_schema

-    def add_table_field(self, table_name: str, field_schema: pa.DataType) -> None:
+    @lancedb_error
+    def add_table_field(self, table_name: str, field_schema: TArrowField) -> Table:
         """Add a field to the LanceDB table.

+        Since Arrow tables are immutable, this is done via a staging mechanism:
+        the data is loaded into an in-memory staging Arrow table, the schema is
+        evolved there, and the result is written over the old table.
+
         Args:
             table_name: The name of the table to create the field on.
             field_schema: The field to create.
         """
-        # TODO: Arrow tables are immutable.
-        # This is tricky without creating a new table.
-        # Perhaps by performing a merge this could work: tbl.merge
-        raise NotImplementedError
+        # Open the existing LanceDB table directly as a PyArrow Table.
+        arrow_table = self.db_client.open_table(table_name).to_arrow()
+
+        # Create an array of null values for the new column.
+        null_array = pa.nulls(len(arrow_table), type=field_schema.type)
+
+        # Create a staging table with the new column appended.
+        stage = arrow_table.append_column(field_schema, null_array)
+
+        return self.db_client.create_table(table_name, stage, mode="overwrite")

     def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
         for table_name in only_tables or self.schema.tables:
             exists, existing_columns = self.get_storage_table(table_name)
-            new_columns = self.schema.get_new_table_columns(table_name, existing_columns)
-            logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}")
+            new_columns = self.schema.get_new_table_columns(
+                table_name, existing_columns
+            )
+            embedding_fields: List[str] = get_columns_names_with_prop(
+                self.schema.get_table(table_name), VECTORIZE_HINT
+            )
+            logger.info(
+                f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}"
+            )
             if len(new_columns) > 0:
                 if exists:
                     for column in new_columns:
-                        field_schema = make_arrow_field_schema(
-                            column["name"], column, self.type_mapper
-                        )
-                        self.add_table_field(table_name, field_schema)
+                        field_schema: TArrowField = make_arrow_field_schema(
+                            column["name"], column, self.type_mapper, embedding_fields
+                        )
+                        fq_table_name = self.make_qualified_table_name(table_name)
+                        self.add_table_field(fq_table_name, field_schema)
                 else:
                     embedding_fields = get_columns_names_with_prop(
                         self.schema.get_table(table_name=table_name), VECTORIZE_HINT
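The staging mechanism in `add_table_field` is plain PyArrow. A standalone sketch of the same steps with an illustrative table, independent of dlt and LanceDB:

```python
import pyarrow as pa

arrow_table = pa.table({"id": ["1", "2"]})
new_field = pa.field("score", pa.float32(), nullable=True)

# Arrow tables are immutable: build a staged copy with the new column
# filled with nulls, which would then be written over the old table.
null_array = pa.nulls(len(arrow_table), type=new_field.type)
staged = arrow_table.append_column(new_field, null_array)

assert staged.column_names == ["id", "score"]
assert staged.column("score").null_count == 2
```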
@@ -450,7 +485,9 @@ def update_schema_in_storage(self) -> None:
                 "schema": json.dumps(self.schema.to_dict()),
             }
         ]
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )
         write_disposition = self.schema.get_table(self.schema.version_table_name).get(
             "write_disposition"
         )
@@ -464,8 +501,12 @@
     @lancedb_error
     def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         """Loads compressed state from destination storage by finding a load ID that was completed."""
-        fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name)
-        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
+        fq_state_table_name = self.make_qualified_table_name(
+            self.schema.state_table_name
+        )
+        fq_loads_table_name = self.make_qualified_table_name(
+            self.schema.loads_table_name
+        )

         state_records = (
             self.db_client.open_table(fq_state_table_name)
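The query above (truncated by the fold) follows LanceDB's filter-only search pattern. A hedged sketch of that pattern (table and column names are illustrative; the real predicate is hidden by the fold):

```python
import lancedb

db = lancedb.connect("/tmp/lancedb-demo")
tbl = db.open_table("pipeline_state")  # assumes the table already exists

records = (
    tbl.search()  # no query vector: a plain filtered scan
    .where("pipeline_name = 'my_pipeline'")
    .limit(10)
    .to_list()
)
```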
@@ -489,8 +530,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
         return None

     @lancedb_error
-    def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+    def get_stored_schema_by_hash(
+        self, schema_hash: str
+    ) -> Optional[StorageSchemaInfo]:
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )

         try:
             response = (
@@ -507,7 +552,9 @@ def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
     @lancedb_error
     def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
         """Retrieves newest schema from destination storage."""
-        fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
+        fq_version_table_name = self.make_qualified_table_name(
+            self.schema.version_table_name
+        )

         try:
             response = (
@@ -542,7 +589,9 @@ def complete_load(self, load_id: str) -> None:
                 "inserted_at": str(pendulum.now()),
             }
         ]
-        fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
+        fq_loads_table_name = self.make_qualified_table_name(
+            self.schema.loads_table_name
+        )
         write_disposition = self.schema.get_table(self.schema.loads_table_name).get(
             "write_disposition"
         )
@@ -556,7 +605,9 @@
     def restore_file_load(self, file_path: str) -> LoadJob:
         return EmptyLoadJob.from_file_path(file_path, "completed")

-    def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
+    def start_file_load(
+        self, table: TTableSchema, file_path: str, load_id: str
+    ) -> LoadJob:
         return LoadLanceDBJob(
             self.schema,
             table,
@@ -600,7 +651,9 @@ def __init__(
         self.table_name: str = table_schema["name"]
         self.fq_table_name: str = fq_table_name
         self.unique_identifiers: Sequence[str] = list_unique_identifiers(table_schema)
-        self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
+        self.embedding_fields: List[str] = get_columns_names_with_prop(
+            table_schema, VECTORIZE_HINT
+        )
         self.embedding_model_func: TextEmbeddingFunction = model_func
         self.embedding_model_dimensions: int = client_config.embedding_model_dimensions
         self.id_field_name: str = client_config.id_field_name
17 changes: 13 additions & 4 deletions dlt/destinations/impl/lancedb/schema.py
@@ -30,10 +30,17 @@ def make_arrow_field_schema(
     column_name: str,
     column: TColumnSchema,
     type_mapper: TypeMapper,
-    embedding_model_func: Optional[TextEmbeddingFunction] = None,
-) -> TArrowDataType:
-    raise NotImplementedError
+    embedding_fields: Optional[List[str]] = None,
+) -> TArrowField:
+    """Creates a PyArrow field from a dlt column schema."""
+    dtype = cast(TArrowDataType, type_mapper.to_db_type(column))
+
+    if embedding_fields and column_name in embedding_fields:
+        metadata = {"embedding_source": "true"}
+    else:
+        metadata = None
+
+    return pa.field(column_name, dtype, metadata=metadata)


 def make_arrow_table_schema(
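Arrow stores field metadata as a bytes-to-bytes mapping, so the `{"embedding_source": "true"}` tag written above round-trips as bytes. A quick check (the column name is illustrative):

```python
import pyarrow as pa

field = pa.field("description", pa.string(), metadata={"embedding_source": "true"})
assert field.metadata == {b"embedding_source": b"true"}
```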
@@ -54,7 +61,9 @@

     if embedding_fields:
         vec_size = embedding_model_dimensions or embedding_model_func.ndims()
-        arrow_schema.append(pa.field(vector_field_name, pa.list_(pa.float32(), vec_size)))
+        arrow_schema.append(
+            pa.field(vector_field_name, pa.list_(pa.float32(), vec_size))
+        )

     for column_name, column in schema.get_table_columns(table_name).items():
         dtype = cast(TArrowDataType, type_mapper.to_db_type(column))

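The appended vector column is a fixed-size list of float32 sized to the embedding model's dimensions. A sketch with a toy dimension (field name and size are illustrative, not the destination's defaults):

```python
import pyarrow as pa

vec_size = 3  # stands in for embedding_model_dimensions or ndims()
schema = pa.schema(
    [
        pa.field("id", pa.string()),
        pa.field("vector", pa.list_(pa.float32(), vec_size)),
    ]
)
# pa.list_ with a size argument yields fixed_size_list<item: float>[3]
assert schema.field("vector").type == pa.list_(pa.float32(), 3)
```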