Skip to content

Commit

Permalink
Refactor lancedb_client.py for improved code readability
Browse files Browse the repository at this point in the history
Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Jun 15, 2024
1 parent 2751699 commit 348825f
Showing 1 changed file with 20 additions and 60 deletions.
80 changes: 20 additions & 60 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,9 +73,7 @@


TIMESTAMP_PRECISION_TO_UNIT: Dict[int, str] = {0: "s", 3: "ms", 6: "us", 9: "ns"}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {
v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()
}
UNIT_TO_TIMESTAMP_PRECISION: Dict[str, int] = {v: k for k, v in TIMESTAMP_PRECISION_TO_UNIT.items()}


class LanceDBTypeMapper(TypeMapper):
Expand Down Expand Up @@ -181,9 +179,7 @@ def upload_batch(
tbl.add(records, mode="overwrite")
elif write_disposition == "merge":
if not id_field_name:
raise ValueError(
"To perform a merge update, 'id_field_name' must be specified."
)
raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
tbl.merge_insert(
id_field_name
).when_matched_update_all().when_not_matched_insert_all().execute(records)
Expand Down Expand Up @@ -255,9 +251,7 @@ def get_table_schema(self, table_name: str) -> TArrowSchema:
)

@lancedb_error
def create_table(
self, table_name: str, schema: TArrowSchema, mode: str = "create"
) -> Table:
def create_table(self, table_name: str, schema: TArrowSchema, mode: str = "create") -> Table:
"""Create a LanceDB Table from the provided LanceModel or PyArrow schema.
Args:
Expand Down Expand Up @@ -360,9 +354,7 @@ def update_stored_schema(
applied_update: TSchemaTables = {}

try:
schema_info = self.get_stored_schema_by_hash(
self.schema.stored_version_hash
)
schema_info = self.get_stored_schema_by_hash(self.schema.stored_version_hash)
except DestinationUndefinedEntity:
schema_info = None

Expand Down Expand Up @@ -411,41 +403,31 @@ def add_table_fields(

# Check if any of the new fields already exist in the table.
existing_fields = set(arrow_table.schema.names)
new_fields = [
field for field in field_schemas if field.name not in existing_fields
]
new_fields = [field for field in field_schemas if field.name not in existing_fields]

if not new_fields:
# All fields already present, skip.
return None

null_arrays = [
pa.nulls(len(arrow_table), type=field.type) for field in new_fields
]
null_arrays = [pa.nulls(len(arrow_table), type=field.type) for field in new_fields]

for field, null_array in zip(new_fields, null_arrays):
arrow_table = arrow_table.append_column(field, null_array)

try:
return self.db_client.create_table(
table_name, arrow_table, mode="overwrite"
)
return self.db_client.create_table(table_name, arrow_table, mode="overwrite")
except OSError:
# Error occurred while creating the table, skip.
return None

def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
for table_name in only_tables or self.schema.tables:
exists, existing_columns = self.get_storage_table(table_name)
new_columns = self.schema.get_new_table_columns(
table_name, existing_columns
)
new_columns = self.schema.get_new_table_columns(table_name, existing_columns)
embedding_fields: List[str] = get_columns_names_with_prop(
self.schema.get_table(table_name), VECTORIZE_HINT
)
logger.info(
f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}"
)
logger.info(f"Found {len(new_columns)} updates for {table_name} in {self.schema.name}")
if len(new_columns) > 0:
if exists:
field_schemas: List[TArrowField] = [
Expand All @@ -464,9 +446,7 @@ def _execute_schema_update(self, only_tables: Iterable[str]) -> None:
vector_field_name = self.vector_field_name
id_field_name = self.id_field_name
embedding_model_func = self.model_func
embedding_model_dimensions = (
self.config.embedding_model_dimensions
)
embedding_model_dimensions = self.config.embedding_model_dimensions
else:
embedding_fields = None
vector_field_name = None
Expand Down Expand Up @@ -501,9 +481,7 @@ def update_schema_in_storage(self) -> None:
"schema": json.dumps(self.schema.to_dict()),
}
]
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)
write_disposition = self.schema.get_table(self.schema.version_table_name).get(
"write_disposition"
)
Expand All @@ -517,12 +495,8 @@ def update_schema_in_storage(self) -> None:
@lancedb_error
def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
"""Loads compressed state from destination storage by finding a load ID that was completed."""
fq_state_table_name = self.make_qualified_table_name(
self.schema.state_table_name
)
fq_loads_table_name = self.make_qualified_table_name(
self.schema.loads_table_name
)
fq_state_table_name = self.make_qualified_table_name(self.schema.state_table_name)
fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)

state_records = (
self.db_client.open_table(fq_state_table_name)
Expand All @@ -543,18 +517,12 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
.to_list()
):
state["dlt_load_id"] = state.pop("_dlt_load_id")
return StateInfo(
**{k: v for k, v in state.items() if k in StateInfo._fields}
)
return StateInfo(**{k: v for k, v in state.items() if k in StateInfo._fields})
return None

@lancedb_error
def get_stored_schema_by_hash(
self, schema_hash: str
) -> Optional[StorageSchemaInfo]:
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)
def get_stored_schema_by_hash(self, schema_hash: str) -> Optional[StorageSchemaInfo]:
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

try:
response = (
Expand All @@ -571,9 +539,7 @@ def get_stored_schema_by_hash(
@lancedb_error
def get_stored_schema(self) -> Optional[StorageSchemaInfo]:
"""Retrieves newest schema from destination storage."""
fq_version_table_name = self.make_qualified_table_name(
self.schema.version_table_name
)
fq_version_table_name = self.make_qualified_table_name(self.schema.version_table_name)

try:
response = (
Expand Down Expand Up @@ -609,9 +575,7 @@ def complete_load(self, load_id: str) -> None:
"schema_version_hash": None, # Payload schema must match the target schema.
}
]
fq_loads_table_name = self.make_qualified_table_name(
self.schema.loads_table_name
)
fq_loads_table_name = self.make_qualified_table_name(self.schema.loads_table_name)
write_disposition = self.schema.get_table(self.schema.loads_table_name).get(
"write_disposition"
)
Expand All @@ -625,9 +589,7 @@ def complete_load(self, load_id: str) -> None:
def restore_file_load(self, file_path: str) -> LoadJob:
return EmptyLoadJob.from_file_path(file_path, "completed")

def start_file_load(
self, table: TTableSchema, file_path: str, load_id: str
) -> LoadJob:
def start_file_load(self, table: TTableSchema, file_path: str, load_id: str) -> LoadJob:
return LoadLanceDBJob(
self.schema,
table,
Expand Down Expand Up @@ -666,9 +628,7 @@ def __init__(
self.table_name: str = table_schema["name"]
self.fq_table_name: str = fq_table_name
self.unique_identifiers: Sequence[str] = list_unique_identifiers(table_schema)
self.embedding_fields: List[str] = get_columns_names_with_prop(
table_schema, VECTORIZE_HINT
)
self.embedding_fields: List[str] = get_columns_names_with_prop(table_schema, VECTORIZE_HINT)
self.embedding_model_func: TextEmbeddingFunction = model_func
self.embedding_model_dimensions: int = client_config.embedding_model_dimensions
self.id_field_name: str = client_config.id_field_name
Expand Down

0 comments on commit 348825f

Please sign in to comment.