Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Enable schema evolution for merge write disposition with delta table format #1742

Merged
merged 12 commits into from
Aug 27, 2024
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
77 changes: 67 additions & 10 deletions dlt/common/libs/deltalake.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,15 @@
from dlt.common import logger
from dlt.common.libs.pyarrow import pyarrow as pa
from dlt.common.libs.pyarrow import cast_arrow_schema_types
from dlt.common.schema.typing import TWriteDisposition
from dlt.common.schema.typing import TWriteDisposition, TTableSchema
from dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop
from dlt.common.exceptions import MissingDependencyException
from dlt.common.storages import FilesystemConfiguration
from dlt.common.utils import assert_min_pkg_version
from dlt.destinations.impl.filesystem.filesystem import FilesystemClient

try:
import deltalake
from deltalake import write_deltalake, DeltaTable
from deltalake.writer import try_get_deltatable
except ModuleNotFoundError:
Expand Down Expand Up @@ -74,7 +76,7 @@ def write_delta_table(
partition_by: Optional[Union[List[str], str]] = None,
storage_options: Optional[Dict[str, str]] = None,
) -> None:
"""Writes in-memory Arrow table to on-disk Delta table.
"""Writes in-memory Arrow data to on-disk Delta table.

Thin wrapper around `deltalake.write_deltalake`.
"""
Expand All @@ -93,31 +95,73 @@ def write_delta_table(
)


def get_delta_tables(pipeline: Pipeline, *tables: str) -> Dict[str, DeltaTable]:
"""Returns Delta tables in `pipeline.default_schema` as `deltalake.DeltaTable` objects.
def merge_delta_table(
table: DeltaTable,
data: Union[pa.Table, pa.RecordBatchReader],
schema: TTableSchema,
) -> None:
"""Merges in-memory Arrow data into on-disk Delta table."""

strategy = schema["x-merge-strategy"] # type: ignore[typeddict-item]
if strategy == "upsert":
# `DeltaTable.merge` does not support automatic schema evolution
# https://github.com/delta-io/delta-rs/issues/2282
_evolve_delta_table_schema(table, data.schema)

if "parent" in schema:
unique_column = get_first_column_name_with_prop(schema, "unique")
predicate = f"target.{unique_column} = source.{unique_column}"
else:
primary_keys = get_columns_names_with_prop(schema, "primary_key")
predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys])

qry = (
table.merge(
source=ensure_delta_compatible_arrow_data(data),
predicate=predicate,
source_alias="source",
target_alias="target",
)
.when_matched_update_all()
.when_not_matched_insert_all()
)

qry.execute()
else:
ValueError(f'Merge strategy "{strategy}" not supported.')


def get_delta_tables(
pipeline: Pipeline, *tables: str, schema_name: str = None
) -> Dict[str, DeltaTable]:
"""Returns Delta tables in `pipeline.default_schema (default)` as `deltalake.DeltaTable` objects.

Returned object is a dictionary with table names as keys and `DeltaTable` objects as values.
Optionally filters dictionary by table names specified as `*tables*`.
Raises ValueError if table name specified as `*tables` is not found.
Raises ValueError if table name specified as `*tables` is not found. You may try to switch to other
schemas via `schema_name` argument.
"""
from dlt.common.schema.utils import get_table_format

with pipeline.destination_client() as client:
with pipeline.destination_client(schema_name=schema_name) as client:
assert isinstance(
client, FilesystemClient
), "The `get_delta_tables` function requires a `filesystem` destination."

schema_delta_tables = [
t["name"]
for t in pipeline.default_schema.tables.values()
if get_table_format(pipeline.default_schema.tables, t["name"]) == "delta"
for t in client.schema.tables.values()
if get_table_format(client.schema.tables, t["name"]) == "delta"
]
if len(tables) > 0:
invalid_tables = set(tables) - set(schema_delta_tables)
if len(invalid_tables) > 0:
available_schemas = ""
if len(pipeline.schema_names) > 1:
available_schemas = f" Available schemas are {pipeline.schema_names}"
raise ValueError(
"Schema does not contain Delta tables with these names: "
f"{', '.join(invalid_tables)}."
f"Schema {client.schema.name} does not contain Delta tables with these names: "
f"{', '.join(invalid_tables)}.{available_schemas}"
)
schema_delta_tables = [t for t in schema_delta_tables if t in tables]
table_dirs = client.get_table_dirs(schema_delta_tables, remote=True)
Expand Down Expand Up @@ -145,3 +189,16 @@ def _deltalake_storage_options(config: FilesystemConfiguration) -> Dict[str, str
+ ". dlt will use the values in `deltalake_storage_options`."
)
return {**creds, **extra_options}


def _evolve_delta_table_schema(delta_table: DeltaTable, arrow_schema: pa.Schema) -> None:
"""Evolves `delta_table` schema if different from `arrow_schema`.

Adds column(s) to `delta_table` present in `arrow_schema` but not in `delta_table`.
"""
new_fields = [
deltalake.Field.from_pyarrow(field)
for field in ensure_delta_compatible_arrow_schema(arrow_schema)
if field not in delta_table.to_pyarrow_dataset().schema
]
delta_table.alter.add_columns(new_fields)
3 changes: 3 additions & 0 deletions dlt/destinations/fs_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,12 @@
from abc import ABC, abstractmethod
from fsspec import AbstractFileSystem

from dlt.common.schema import Schema


class FSClientBase(ABC):
fs_client: AbstractFileSystem
schema: Schema

@property
@abstractmethod
Expand Down
166 changes: 79 additions & 87 deletions dlt/destinations/impl/filesystem/filesystem.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
import base64

from types import TracebackType
from typing import ClassVar, List, Type, Iterable, Iterator, Optional, Tuple, Sequence, cast
from typing import Dict, List, Type, Iterable, Iterator, Optional, Tuple, Sequence, cast
from fsspec import AbstractFileSystem
from contextlib import contextmanager

Expand All @@ -13,7 +13,7 @@
from dlt.common.storages.fsspec_filesystem import glob_files
from dlt.common.typing import DictStrAny
from dlt.common.schema import Schema, TSchemaTables, TTableSchema
from dlt.common.schema.utils import get_first_column_name_with_prop, get_columns_names_with_prop
from dlt.common.schema.utils import get_columns_names_with_prop
from dlt.common.storages import FileStorage, fsspec_from_config
from dlt.common.storages.load_package import (
LoadJobInfo,
Expand Down Expand Up @@ -56,36 +56,36 @@ def __init__(
self._job_client: FilesystemClient = None

def run(self) -> None:
# pick local filesystem pathlib or posix for buckets
self.is_local_filesystem = self._job_client.config.protocol == "file"
self.pathlib = os.path if self.is_local_filesystem else posixpath

self.destination_file_name = path_utils.create_path(
self._job_client.config.layout,
self._file_name,
self._job_client.schema.name,
self._load_id,
current_datetime=self._job_client.config.current_datetime,
load_package_timestamp=dlt.current.load_package()["state"]["created_at"],
extra_placeholders=self._job_client.config.extra_placeholders,
)
self.__is_local_filesystem = self._job_client.config.protocol == "file"
# We would like to avoid failing for local filesystem where
# deeply nested directory will not exist before writing a file.
# It `auto_mkdir` is disabled by default in fsspec so we made some
# trade offs between different options and decided on this.
# remote_path = f"{client.config.protocol}://{posixpath.join(dataset_path, destination_file_name)}"
remote_path = self.make_remote_path()
if self.is_local_filesystem:
self._job_client.fs_client.makedirs(self.pathlib.dirname(remote_path), exist_ok=True)
if self.__is_local_filesystem:
# use os.path for local file name
self._job_client.fs_client.makedirs(os.path.dirname(remote_path), exist_ok=True)
self._job_client.fs_client.put_file(self._file_path, remote_path)

def make_remote_path(self) -> str:
"""Returns path on the remote filesystem to which copy the file, without scheme. For local filesystem a native path is used"""
destination_file_name = path_utils.create_path(
self._job_client.config.layout,
self._file_name,
self._job_client.schema.name,
self._load_id,
current_datetime=self._job_client.config.current_datetime,
load_package_timestamp=dlt.current.load_package()["state"]["created_at"],
extra_placeholders=self._job_client.config.extra_placeholders,
)
# pick local filesystem pathlib or posix for buckets
pathlib = os.path if self.__is_local_filesystem else posixpath
# path.join does not normalize separators and available
# normalization functions are very invasive and may string the trailing separator
return self.pathlib.join( # type: ignore[no-any-return]
return pathlib.join( # type: ignore[no-any-return]
self._job_client.dataset_path,
path_utils.normalize_path_sep(self.pathlib, self.destination_file_name),
path_utils.normalize_path_sep(pathlib, destination_file_name),
)

def make_remote_uri(self) -> str:
Expand All @@ -98,89 +98,81 @@ def metrics(self) -> Optional[LoadJobMetrics]:

class DeltaLoadFilesystemJob(FilesystemLoadJob):
def __init__(self, file_path: str) -> None:
super().__init__(
file_path=file_path,
)

def run(self) -> None:
# pick local filesystem pathlib or posix for buckets
# TODO: since we pass _job_client via run_managed and not set_env_vars it is hard
# to write a handler with those two line below only in FilesystemLoadJob
self.is_local_filesystem = self._job_client.config.protocol == "file"
self.pathlib = os.path if self.is_local_filesystem else posixpath
self.destination_file_name = self._job_client.make_remote_uri(
self._job_client.get_table_dir(self.load_table_name)
)
super().__init__(file_path=file_path)

# create Arrow dataset from Parquet files
from dlt.common.libs.pyarrow import pyarrow as pa
from dlt.common.libs.deltalake import (
DeltaTable,
write_delta_table,
ensure_delta_compatible_arrow_schema,
_deltalake_storage_options,
try_get_deltatable,
)

# create Arrow dataset from Parquet files
file_paths = ReferenceFollowupJobRequest.resolve_references(self._file_path)
arrow_ds = pa.dataset.dataset(file_paths)
self.file_paths = ReferenceFollowupJobRequest.resolve_references(self._file_path)
self.arrow_ds = pa.dataset.dataset(self.file_paths)

# create Delta table object
def make_remote_path(self) -> str:
# remote path is table dir - delta will create its file structure inside it
return self._job_client.get_table_dir(self.load_table_name)

storage_options = _deltalake_storage_options(self._job_client.config)
dt = try_get_deltatable(self.destination_file_name, storage_options=storage_options)
def run(self) -> None:
logger.info(f"Will copy file(s) {self.file_paths} to delta table {self.make_remote_uri()}")

# get partition columns
part_cols = get_columns_names_with_prop(self._load_table, "partition")
from dlt.common.libs.deltalake import write_delta_table, merge_delta_table

# explicitly check if there is data
# (https://github.com/delta-io/delta-rs/issues/2686)
if arrow_ds.head(1).num_rows == 0:
if dt is None:
# create new empty Delta table with schema from Arrow table
DeltaTable.create(
table_uri=self.destination_file_name,
schema=ensure_delta_compatible_arrow_schema(arrow_ds.schema),
mode="overwrite",
partition_by=part_cols,
storage_options=storage_options,
)
if self.arrow_ds.head(1).num_rows == 0:
self._create_or_evolve_delta_table()
return

arrow_rbr = arrow_ds.scanner().to_reader() # RecordBatchReader

if self._load_table["write_disposition"] == "merge" and dt is not None:
assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item]

if self._load_table["x-merge-strategy"] == "upsert": # type: ignore[typeddict-item]
if "parent" in self._load_table:
unique_column = get_first_column_name_with_prop(self._load_table, "unique")
predicate = f"target.{unique_column} = source.{unique_column}"
else:
primary_keys = get_columns_names_with_prop(self._load_table, "primary_key")
predicate = " AND ".join([f"target.{c} = source.{c}" for c in primary_keys])

qry = (
dt.merge(
source=arrow_rbr,
predicate=predicate,
source_alias="source",
target_alias="target",
)
.when_matched_update_all()
.when_not_matched_insert_all()
with self.arrow_ds.scanner().to_reader() as arrow_rbr: # RecordBatchReader
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Curious why you inserted a with context here. Is it because arrow_rbr gets exhausted and is effectively useless after the context?

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it has a close method... so it has internal unmanaged resources that we should free ASAP. otherwise garbage collector does it way later

if self._load_table["write_disposition"] == "merge" and self._delta_table is not None:
assert self._load_table["x-merge-strategy"] in self._job_client.capabilities.supported_merge_strategies # type: ignore[typeddict-item]
merge_delta_table(
table=self._delta_table,
data=arrow_rbr,
schema=self._load_table,
)
else:
write_delta_table(
table_or_uri=(
self.make_remote_uri() if self._delta_table is None else self._delta_table
),
data=arrow_rbr,
write_disposition=self._load_table["write_disposition"],
partition_by=self._partition_columns,
storage_options=self._storage_options,
)

qry.execute()
@property
def _storage_options(self) -> Dict[str, str]:
from dlt.common.libs.deltalake import _deltalake_storage_options

return _deltalake_storage_options(self._job_client.config)

else:
write_delta_table(
table_or_uri=self.destination_file_name if dt is None else dt,
data=arrow_rbr,
write_disposition=self._load_table["write_disposition"],
partition_by=part_cols,
storage_options=storage_options,
@property
def _delta_table(self) -> Optional["DeltaTable"]: # type: ignore[name-defined] # noqa: F821
from dlt.common.libs.deltalake import try_get_deltatable

return try_get_deltatable(self.make_remote_uri(), storage_options=self._storage_options)

@property
def _partition_columns(self) -> List[str]:
return get_columns_names_with_prop(self._load_table, "partition")

def _create_or_evolve_delta_table(self) -> None:
from dlt.common.libs.deltalake import (
DeltaTable,
ensure_delta_compatible_arrow_schema,
_evolve_delta_table_schema,
)

if self._delta_table is None:
DeltaTable.create(
table_uri=self.make_remote_uri(),
schema=ensure_delta_compatible_arrow_schema(self.arrow_ds.schema),
mode="overwrite",
partition_by=self._partition_columns,
storage_options=self._storage_options,
)
else:
_evolve_delta_table_schema(self._delta_table, self.arrow_ds.schema)


class FilesystemLoadJobWithFollowup(HasFollowupJobs, FilesystemLoadJob):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
from typing import List, Dict, Any, Generator
import dlt


# Define a dlt resource with write disposition to 'merge'
@dlt.resource(name="parent_with_children", write_disposition={"disposition": "merge"})
def data_source() -> Generator[List[Dict[str, Any]], None, None]:
Expand All @@ -44,13 +45,15 @@ def data_source() -> Generator[List[Dict[str, Any]], None, None]:

yield data


# Function to add parent_id to each child record within a parent record
def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
parent_id_key = "parent_id"
for child in record["children"]:
child[parent_id_key] = record[parent_id_key]
return record


if __name__ == "__main__":
# Create and configure the dlt pipeline
pipeline = dlt.pipeline(
Expand All @@ -60,10 +63,6 @@ def add_parent_id(record: Dict[str, Any]) -> Dict[str, Any]:
)

# Run the pipeline
load_info = pipeline.run(
data_source()
.add_map(add_parent_id),
primary_key="parent_id"
)
load_info = pipeline.run(data_source().add_map(add_parent_id), primary_key="parent_id")
# Output the load information after pipeline execution
print(load_info)
Loading