feat: support new PyIceberg table IO properties and custom IOConfig in write_iceberg
kevinzwang committed Dec 21, 2024
1 parent 1c0f780 commit 69b0649
Showing 3 changed files with 56 additions and 80 deletions.
14 changes: 12 additions & 2 deletions daft/dataframe/dataframe.py
@@ -686,7 +686,9 @@ def write_csv(
)

@DataframePublicAPI
def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") -> "DataFrame":
def write_iceberg(
self, table: "pyiceberg.table.Table", mode: str = "append", io_config: Optional[IOConfig] = None
) -> "DataFrame":
"""Writes the DataFrame to an `Iceberg <https://iceberg.apache.org/docs/nightly/>`__ table, returning a new DataFrame with the operations that occurred.
Can be run in either `append` or `overwrite` mode, which will either append the rows in the DataFrame or delete the existing rows and then append the DataFrame rows, respectively.
Expand All @@ -697,6 +699,7 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
Args:
table (pyiceberg.table.Table): Destination `PyIceberg Table <https://py.iceberg.apache.org/reference/pyiceberg/table/#pyiceberg.table.Table>`__ to write dataframe to.
mode (str, optional): Operation mode of the write. `append` or `overwrite` Iceberg Table. Defaults to "append".
io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
Returns:
DataFrame: The operations that occurred with this write.
Expand All @@ -705,6 +708,8 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
import pyiceberg
from packaging.version import parse

from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config

if len(table.spec().fields) > 0 and parse(pyiceberg.__version__) < parse("0.7.0"):
raise ValueError("pyiceberg>=0.7.0 is required to write to a partitioned table")

Expand All @@ -719,12 +724,17 @@ def write_iceberg(self, table: "pyiceberg.table.Table", mode: str = "append") ->
if mode not in ["append", "overwrite"]:
raise ValueError(f"Only support `append` or `overwrite` mode. {mode} is unsupported")

io_config = (
_convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
)
io_config = get_context().daft_planning_config.default_io_config if io_config is None else io_config

operations = []
path = []
rows = []
size = []

builder = self._builder.write_iceberg(table)
builder = self._builder.write_iceberg(table, io_config)
write_df = DataFrame(builder)
write_df.collect()

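With this change, a custom IOConfig can be passed straight through DataFrame.write_iceberg and takes precedence over settings derived from the table's FileIO properties; if neither is provided, the planner's default_io_config is used. A minimal usage sketch follows; the catalog name, table identifier, column names, and region are hypothetical placeholders, and a configured PyIceberg catalog is assumed.

```python
import daft
from daft.io import IOConfig, S3Config
from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table names; assumes a PyIceberg catalog is configured.
catalog = load_catalog("default")
table = catalog.load_table("demo_ns.events")

df = daft.from_pydict({"id": [1, 2, 3], "value": ["a", "b", "c"]})

# An explicit io_config overrides anything derived from table.io.properties.
io_config = IOConfig(s3=S3Config(region_name="us-west-2"))
result = df.write_iceberg(table, mode="append", io_config=io_config)
result.show()  # the returned DataFrame describes the write operations
```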
118 changes: 43 additions & 75 deletions daft/io/_iceberg.py
@@ -9,82 +9,52 @@
from daft.logical.builder import LogicalPlanBuilder

if TYPE_CHECKING:
from pyiceberg.table import Table as PyIcebergTable
# from pyiceberg.table import Table as PyIcebergTable
import pyiceberg


def _convert_iceberg_file_io_properties_to_io_config(props: Dict[str, Any]) -> Optional["IOConfig"]:
import pyiceberg
from packaging.version import parse
from pyiceberg.io import (
S3_ACCESS_KEY_ID,
S3_ENDPOINT,
S3_REGION,
S3_SECRET_ACCESS_KEY,
S3_SESSION_TOKEN,
)

"""Property keys defined here: https://github.com/apache/iceberg-python/blob/main/pyiceberg/io/__init__.py."""
from daft.io import AzureConfig, GCSConfig, IOConfig, S3Config

s3_mapping = {
S3_REGION: "region_name",
S3_ENDPOINT: "endpoint_url",
S3_ACCESS_KEY_ID: "key_id",
S3_SECRET_ACCESS_KEY: "access_key",
S3_SESSION_TOKEN: "session_token",
}
s3_args = dict() # type: ignore
for pyiceberg_key, daft_key in s3_mapping.items():
value = props.get(pyiceberg_key, None)
if value is not None:
s3_args[daft_key] = value

if len(s3_args) > 0:
s3_config = S3Config(**s3_args)
else:
s3_config = None

gcs_config = None
azure_config = None
if parse(pyiceberg.__version__) >= parse("0.5.0"):
from pyiceberg.io import GCS_PROJECT_ID, GCS_TOKEN

gcs_mapping = {GCS_PROJECT_ID: "project_id", GCS_TOKEN: "token"}
gcs_args = dict() # type: ignore
for pyiceberg_key, daft_key in gcs_mapping.items():
value = props.get(pyiceberg_key, None)
if value is not None:
gcs_args[daft_key] = value

if len(gcs_args) > 0:
gcs_config = GCSConfig(**gcs_args)

azure_mapping = {
"adlfs.account-name": "storage_account",
"adlfs.account-key": "access_key",
"adlfs.sas-token": "sas_token",
"adlfs.tenant-id": "tenant_id",
"adlfs.client-id": "client_id",
"adlfs.client-secret": "client_secret",
}

azure_args = dict() # type: ignore
for pyiceberg_key, daft_key in azure_mapping.items():
value = props.get(pyiceberg_key, None)
if value is not None:
azure_args[daft_key] = value

if len(azure_args) > 0:
azure_config = AzureConfig(**azure_args)

if any([s3_config, gcs_config, azure_config]):
return IOConfig(s3=s3_config, gcs=gcs_config, azure=azure_config)
else:
any_props_set = False

def get_first_property_value(*property_names: str) -> Optional[Any]:
for property_name in property_names:
if property_value := props.get(property_name):
nonlocal any_props_set
any_props_set = True
return property_value

return None

io_config = IOConfig(
s3=S3Config(
endpoint_url=props.get("s3.endpoint"),
region_name=get_first_property_value("s3.region", "client.region"),
key_id=get_first_property_value("s3.access-key-id", "client.access-key-id"),
access_key=get_first_property_value("s3.secret-access-key", "client.secret-access-key"),
session_token=get_first_property_value("s3.session-token", "client.session-token"),
),
azure=AzureConfig(
storage_account=get_first_property_value("adls.account-name", "adlfs.account-name"),
access_key=get_first_property_value("adls.account-key", "adlfs.account-key"),
sas_token=get_first_property_value("adls.sas-token", "adlfs.sas-token"),
tenant_id=get_first_property_value("adls.tenant-id", "adlfs.tenant-id"),
client_id=get_first_property_value("adls.client-id", "adlfs.client-id"),
client_secret=get_first_property_value("adls.client-secret", "adlfs.client-secret"),
),
gcs=GCSConfig(
project_id=props.get("gcs.project-id"),
token=props.get("gcs.oauth2.token"),
),
)

return io_config if any_props_set else None

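The rewritten helper now reads the newer PyIceberg property keys (s3.*, client.*, adls.*, gcs.*) alongside the legacy adlfs.* keys, and returns None when the table sets none of them. A rough illustration of the mapping, calling the private helper directly purely for demonstration, with placeholder values:

```python
from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config

# Placeholder property values using the newer "s3." / "gcs." key names.
props = {
    "s3.region": "us-west-2",
    "s3.access-key-id": "PLACEHOLDER_KEY_ID",
    "s3.secret-access-key": "PLACEHOLDER_SECRET",
    "gcs.project-id": "my-project",
}

io_config = _convert_iceberg_file_io_properties_to_io_config(props)
print(io_config)  # IOConfig with the S3 and GCS fields above populated

# With no recognized properties the helper returns None, so callers fall back
# to the planner's default IOConfig.
assert _convert_iceberg_file_io_properties_to_io_config({}) is None
```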

@PublicAPI
def read_iceberg(
pyiceberg_table: "PyIcebergTable",
table: "pyiceberg.table.Table",
snapshot_id: Optional[int] = None,
io_config: Optional["IOConfig"] = None,
) -> DataFrame:
Expand All @@ -93,8 +63,8 @@ def read_iceberg(
Example:
>>> import pyiceberg
>>>
>>> pyiceberg_table = pyiceberg.Table(...)
>>> df = daft.read_iceberg(pyiceberg_table)
>>> table = pyiceberg.Table(...)
>>> df = daft.read_iceberg(table)
>>>
>>> # Filters on this dataframe can now be pushed into
>>> # the read operation from Iceberg
Expand All @@ -106,26 +76,24 @@ def read_iceberg(
official project for Python.
Args:
pyiceberg_table: Iceberg table created using the PyIceberg library
snapshot_id: Snapshot ID of the table to query
io_config: A custom IOConfig to use when accessing Iceberg object storage data. Defaults to None.
table (pyiceberg.table.Table): `PyIceberg Table <https://py.iceberg.apache.org/reference/pyiceberg/table/#pyiceberg.table.Table>`__ created using the PyIceberg library
snapshot_id (int, optional): Snapshot ID of the table to query
io_config (IOConfig, optional): A custom IOConfig to use when accessing Iceberg object storage data. If provided, configurations set in `table` are ignored.
Returns:
DataFrame: a DataFrame with the schema converted from the specified Iceberg table
"""
from daft.iceberg.iceberg_scan import IcebergScanOperator

io_config = (
_convert_iceberg_file_io_properties_to_io_config(pyiceberg_table.io.properties)
if io_config is None
else io_config
_convert_iceberg_file_io_properties_to_io_config(table.io.properties) if io_config is None else io_config
)
io_config = context.get_context().daft_planning_config.default_io_config if io_config is None else io_config

multithreaded_io = context.get_context().get_or_create_runner().name != "ray"
storage_config = StorageConfig(multithreaded_io, io_config)

iceberg_operator = IcebergScanOperator(pyiceberg_table, snapshot_id=snapshot_id, storage_config=storage_config)
iceberg_operator = IcebergScanOperator(table, snapshot_id=snapshot_id, storage_config=storage_config)

handle = ScanOperatorHandle.from_python_scan_operator(iceberg_operator)
builder = LogicalPlanBuilder.from_tabular_scan(scan_operator=handle)
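
read_iceberg resolves storage configuration in the same order: an explicit io_config wins over the table's FileIO properties, and the planner's default_io_config is the final fallback. A short sketch, reusing the hypothetical catalog and table from the write example above and assuming the table has an id column:

```python
import daft
from daft.io import IOConfig, S3Config
from pyiceberg.catalog import load_catalog

# Hypothetical catalog/table names, as in the write example above.
table = load_catalog("default").load_table("demo_ns.events")

# Explicit io_config takes precedence over table.io.properties.
df = daft.read_iceberg(table, io_config=IOConfig(s3=S3Config(region_name="us-east-1")))

# Filters and projections can still be pushed down into the Iceberg scan.
df = df.where(df["id"] > 10).select("id")
df.collect()
```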
4 changes: 1 addition & 3 deletions daft/logical/builder.py
@@ -297,9 +297,8 @@ def write_tabular(
builder = self._builder.table_write(str(root_dir), file_format, part_cols_pyexprs, compression, io_config)
return LogicalPlanBuilder(builder)

def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
def write_iceberg(self, table: IcebergTable, io_config: IOConfig) -> LogicalPlanBuilder:
from daft.iceberg.iceberg_write import get_missing_columns, partition_field_to_expr
from daft.io._iceberg import _convert_iceberg_file_io_properties_to_io_config

name = ".".join(table.name())
location = f"{table.location()}/data"
Expand All @@ -314,7 +313,6 @@ def write_iceberg(self, table: IcebergTable) -> LogicalPlanBuilder:
partition_cols = [partition_field_to_expr(field, schema)._expr for field in partition_spec.fields]
props = table.properties
columns = [col.name for col in schema.columns]
io_config = _convert_iceberg_file_io_properties_to_io_config(table.io.properties)
builder = builder.iceberg_write(
name, location, partition_spec.spec_id, partition_cols, schema, props, columns, io_config
)
