Skip to content

Commit

Permalink
Update to new load interface
Browse files (browse the repository at this point in the history)
Signed-off-by: Marcel Coetzee <[email protected]>
  • Loading branch information
Pipboyguy committed Aug 7, 2024
1 parent 8e74815 commit c8f7468
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 26 deletions.
45 changes: 20 additions & 25 deletions dlt/destinations/impl/lancedb/lancedb_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,6 @@
RunnableLoadJob,
StorageSchemaInfo,
StateInfo,
TLoadJobState,
NewLoadJob,
FollowupJob,
LoadJob,
)
Expand Down Expand Up @@ -79,8 +77,7 @@
set_non_standard_providers_environment_variables,
generate_arrow_uuid_column,
)
from dlt.destinations.job_impl import FinalizedLoadJobWithFollowupJobs
from dlt.destinations.job_impl import EmptyLoadJob, NewLoadJobImpl
from dlt.destinations.job_impl import FollowupJobImpl
from dlt.destinations.type_mapping import TypeMapper

if TYPE_CHECKING:
Expand Down Expand Up @@ -153,7 +150,7 @@ def from_db_type(
)
if isinstance(db_type, pa.Decimal128Type):
precision, scale = db_type.precision, db_type.scale
if (precision, scale)==self.capabilities.wei_precision:
if (precision, scale) == self.capabilities.wei_precision:
return cast(TColumnType, dict(data_type="wei"))
return dict(data_type="decimal", precision=precision, scale=scale)
return super().from_db_type(cast(str, db_type), precision, scale)
Expand Down Expand Up @@ -193,9 +190,9 @@ def write_to_db(
try:
if write_disposition in ("append", "skip"):
tbl.add(records)
elif write_disposition=="replace":
elif write_disposition == "replace":
tbl.add(records, mode="overwrite")
elif write_disposition=="merge":
elif write_disposition == "merge":
if not id_field_name:
raise ValueError("To perform a merge update, 'id_field_name' must be specified.")
tbl.merge_insert(
Expand Down Expand Up @@ -244,7 +241,7 @@ def __init__(
self.config.credentials.embedding_model_provider_api_key,
)
# Use the monkey-patched implementation if openai was chosen.
if embedding_model_provider=="openai":
if embedding_model_provider == "openai":
from dlt.destinations.impl.lancedb.models import PatchedOpenAIEmbeddings

self.model_func = PatchedOpenAIEmbeddings(
Expand Down Expand Up @@ -339,7 +336,7 @@ def _get_table_names(self) -> List[str]:
else:
table_names = self.db_client.table_names()

return [table_name for table_name in table_names if table_name!=self.sentinel_table]
return [table_name for table_name in table_names if table_name != self.sentinel_table]

@lancedb_error
def drop_storage(self) -> None:
Expand Down Expand Up @@ -578,7 +575,7 @@ def get_stored_state(self, pipeline_name: str) -> Optional[StateInfo]:
loads_table, keys=p_dlt_load_id, right_keys=p_load_id, join_type="inner"
).sort_by([(p_dlt_load_id, "descending")])

if joined_table.num_rows==0:
if joined_table.num_rows == 0:
return None

state = joined_table.take([0]).to_pylist()[0]
Expand Down Expand Up @@ -711,7 +708,7 @@ def create_table_chain_completed_followup_jobs(
self,
table_chain: Sequence[TTableSchema],
completed_table_chain_jobs: Optional[Sequence[LoadJobInfo]] = None,
) -> List[NewLoadJob]:
) -> List[FollowupJob]:
assert completed_table_chain_jobs is not None
jobs = super().create_table_chain_completed_followup_jobs(
table_chain, completed_table_chain_jobs
Expand All @@ -722,7 +719,7 @@ def create_table_chain_completed_followup_jobs(
continue

# Only tables with merge disposition are dispatched for orphan removal jobs.
if table.get("write_disposition")=="merge":
if table.get("write_disposition") == "merge":
parent_table = table.get("parent")
jobs.append(
LanceDBRemoveOrphansJob(
Expand All @@ -742,7 +739,7 @@ def table_exists(self, table_name: str) -> bool:
return table_name in self.db_client.table_names()


class LanceDBLoadJob(RunnableLoadJob, FollowupJob):
class LanceDBLoadJob(RunnableLoadJob):
arrow_schema: TArrowSchema

def __init__(
Expand All @@ -765,7 +762,9 @@ def run(self) -> None:
self._embedding_model_dimensions: int = self._job_client.config.embedding_model_dimensions
self._id_field_name: str = self._job_client.config.id_field_name

unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(self._load_table)
unique_identifiers: Sequence[str] = get_unique_identifiers_from_table_schema(
self._load_table
)
write_disposition: TWriteDisposition = cast(
TWriteDisposition, self._load_table.get("write_disposition", "append")
)
Expand All @@ -776,9 +775,9 @@ def run(self) -> None:
if self._load_table not in self._schema.dlt_tables():
arrow_table = generate_arrow_uuid_column(
arrow_table,
unique_identifiers=self.unique_identifiers,
table_name=self.fq_table_name,
id_field_name=self.id_field_name,
unique_identifiers=unique_identifiers,
table_name=self._fq_table_name,
id_field_name=self._id_field_name,
)

write_to_db(
Expand All @@ -790,7 +789,7 @@ def run(self) -> None:
)


class LanceDBRemoveOrphansJob(NewLoadJobImpl):
class LanceDBRemoveOrphansJob(FollowupJobImpl):
def __init__(
self,
db_client: DBConnection,
Expand All @@ -817,7 +816,6 @@ def __init__(

super().__init__(
file_name=job_id,
status="running",
)

self._save_text_file("")
Expand All @@ -827,7 +825,7 @@ def __init__(
def execute(self) -> None:
orphaned_ids: Set[str]

if self.write_disposition!="merge":
if self.write_disposition != "merge":
raise DestinationTerminalException(
f"Unsupported write disposition {self.write_disposition} for LanceDB Destination"
" Orphan Removal Job - failed AND WILL **NOT** BE RETRIED."
Expand Down Expand Up @@ -866,7 +864,7 @@ def execute(self) -> None:
if orphaned_ids := child_ids - parent_ids:
if len(orphaned_ids) > 1:
child_table.delete(f"_dlt_parent_id IN {tuple(orphaned_ids)}")
elif len(orphaned_ids)==1:
elif len(orphaned_ids) == 1:
child_table.delete(f"_dlt_parent_id = '{orphaned_ids.pop()}'")

else:
Expand Down Expand Up @@ -898,13 +896,10 @@ def execute(self) -> None:

if len(orphaned_ids) > 1:
child_table.delete(f"_dlt_id IN {tuple(orphaned_ids)}")
elif len(orphaned_ids)==1:
elif len(orphaned_ids) == 1:
child_table.delete(f"_dlt_id = '{orphaned_ids.pop()}'")

except ArrowInvalid as e:
raise DestinationTerminalException(
"Python and Arrow datatype mismatch - batch failed AND WILL **NOT** BE RETRIED."
) from e

def state(self) -> TLoadJobState:
return "completed"
2 changes: 1 addition & 1 deletion dlt/destinations/impl/lancedb/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

# TODO: Rewrite `generate_arrow_uuid_column` with vectorized operations (batched + memory-mapped) once pyarrow 17.0.0 is available.
def generate_arrow_uuid_column(
table: pa.Table, unique_identifiers: List[str], id_field_name: str, table_name: str
table: pa.Table, unique_identifiers: Sequence[str], id_field_name: str, table_name: str
) -> pa.Table:
"""Generates deterministic UUID - used for deduplication, returning a new arrow
table with added UUID column.
Expand Down

0 comments on commit c8f7468

Please sign in to comment.