Commit
fix: parquet read options config class (#8)
Co-authored-by: Stefan Verbruggen <[email protected]>
sverbruggen and Stefan Verbruggen authored Aug 26, 2024
1 parent 13acb56 commit 75b8f6a
Showing 4 changed files with 87 additions and 13 deletions.
@@ -58,11 +58,16 @@ def load_input(
         """Loads the input as a Polars DataFrame or LazyFrame."""
         metadata = context.metadata if context.metadata is not None else {}
         date_format = extract_date_format_from_partition_definition(context)
-        parquet_read_options = (
-            context.resource_config.get("parquet_read_options", None)
-            if context.resource_config is not None
-            else None
-        )
+
+        parquet_read_options = None
+        if context.resource_config is not None:
+            parquet_read_options = context.resource_config.get("parquet_read_options", None)
+            parquet_read_options = (
+                ds.ParquetReadOptions(**parquet_read_options)
+                if parquet_read_options is not None
+                else None
+            )
+
         dataset = _table_reader(
             table_slice,
             connection,
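For context on the fix: the old code passed the raw resource-config dict straight through, while the new code expands it into pyarrow's ds.ParquetReadOptions before handing it to _table_reader. A minimal sketch of the pyarrow side of that conversion, using an illustrative config dict and dataset path (the ParquetFileFormat wiring below is how pyarrow consumes the options, not necessarily how _table_reader is implemented internally):

import pyarrow.dataset as ds

# What the IO manager receives from Dagster resource config: a plain dict (illustrative value).
resource_config = {"parquet_read_options": {"coerce_int96_timestamp_unit": "us"}}

raw_options = resource_config.get("parquet_read_options", None)
# Expand the dict into pyarrow's ParquetReadOptions, as the new code does.
parquet_read_options = ds.ParquetReadOptions(**raw_options) if raw_options is not None else None

# pyarrow accepts the options via ParquetFileFormat when building a dataset.
file_format = ds.ParquetFileFormat(read_options=parquet_read_options)
dataset = ds.dataset("path/to/parquet_files", format=file_format)  # placeholder path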
14 changes: 9 additions & 5 deletions libraries/dagster-delta/dagster_delta/handler.py
@@ -297,11 +297,15 @@ def load_input(
         connection: TableConnection,
     ) -> T:
         """Loads the input as a pyarrow Table or RecordBatchReader."""
-        parquet_read_options = (
-            context.resource_config.get("parquet_read_options", None)
-            if context.resource_config is not None
-            else None
-        )
+        parquet_read_options = None
+        if context.resource_config is not None:
+            parquet_read_options = context.resource_config.get("parquet_read_options", None)
+            parquet_read_options = (
+                ds.ParquetReadOptions(**parquet_read_options)
+                if parquet_read_options is not None
+                else None
+            )
+
         dataset = _table_reader(table_slice, connection, parquet_read_options=parquet_read_options)
 
         if context.dagster_type.typing_type == ds.Dataset:
7 changes: 5 additions & 2 deletions libraries/dagster-delta/dagster_delta/io_manager.py
@@ -6,7 +6,6 @@
 from enum import Enum
 from typing import Any, Optional, TypedDict, Union, cast
 
-import pyarrow.dataset as ds
 from dagster import InputContext, OutputContext
 from dagster._config.pythonic_config import ConfigurableIOManagerFactory
 from dagster._core.definitions.time_window_partitions import TimeWindow
@@ -82,7 +81,7 @@ class _DeltaTableIOManagerResourceConfig(TypedDict):
     table_config: NotRequired[dict[str, str]]
     custom_metadata: NotRequired[dict[str, str]]
     writer_properties: NotRequired[dict[str, str]]
-    parquet_read_options: NotRequired[ds.ParquetReadOptions]
+    parquet_read_options: NotRequired[dict[str, Any]]
 
 
 class DeltaLakeIOManager(ConfigurableIOManagerFactory):
@@ -185,6 +184,10 @@ def my_table_a(my_table: pd.DataFrame):
         default=None,
         description="Writer properties passed to the rust engine writer.",
     )
+    parquet_read_options: Optional[dict[str, Any]] = Field(
+        default=None,
+        description="Parquet read options, to be passed to pyarrow.",
+    )
 
     @staticmethod
     @abstractmethod
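With parquet_read_options now typed as a plain dict on both the resource-config TypedDict and the DeltaLakeIOManager field (Dagster's config system cannot carry a pyarrow object directly), the option can be set when the resource is defined. A hedged sketch of such a setup; the asset name and root_uri are purely illustrative, and the imports mirror the ones used in the test file below:

import pyarrow as pa
from dagster import Definitions, asset

from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig
from dagster_delta.io_manager import WriteMode


@asset
def raw_events() -> pa.Table:  # hypothetical asset for illustration
    return pa.Table.from_pydict({"id": [1, 2, 3]})


defs = Definitions(
    assets=[raw_events],
    resources={
        "io_manager": DeltaLakePyarrowIOManager(
            root_uri="/tmp/delta",  # illustrative location
            storage_options=LocalConfig(),
            mode=WriteMode.overwrite,
            # Plain dict; expanded into ds.ParquetReadOptions when inputs are loaded.
            parquet_read_options={"coerce_int96_timestamp_unit": "us"},
        )
    },
)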
64 changes: 63 additions & 1 deletion libraries/dagster-delta/dagster_delta_tests/test_io_manager.py
@@ -1,11 +1,22 @@
+import os
 from datetime import datetime
 
+import pyarrow as pa
+import pyarrow.compute as pc
 import pytest
-from dagster import TimeWindow
+from dagster import (
+    Out,
+    TimeWindow,
+    graph,
+    op,
+)
 from dagster._core.storage.db_io_manager import TablePartitionDimension
+from deltalake import DeltaTable
 from deltalake.schema import Field, PrimitiveType, Schema
 
+from dagster_delta import DeltaLakePyarrowIOManager, LocalConfig
 from dagster_delta.handler import partition_dimensions_to_dnf
+from dagster_delta.io_manager import WriteMode
 
 TablePartitionDimension(
     partitions=TimeWindow(datetime(2020, 1, 2), datetime(2020, 2, 3)),
@@ -41,3 +52,54 @@ def test_partition_dimensions_to_dnf(test_schema) -> None:
     ]
     dnf = partition_dimensions_to_dnf(parts, test_schema, True)
     assert dnf == [("date_col", "=", "2020-01-02")]
+
+
+@op(out=Out(metadata={"schema": "a_df"}))
+def a_pa_table() -> pa.Table:
+    return pa.Table.from_pydict({"a": [1, 2, 3], "b": [4, 5, 6]})
+
+
+@op(out=Out(metadata={"schema": "add_one"}))
+def add_one(table: pa.Table) -> pa.Table:
+    # Add one to each column in the table
+    updated_columns = [pc.add_checked(table[column], pa.scalar(1)) for column in table.column_names]  # type: ignore
+    updated_table = pa.Table.from_arrays(updated_columns, names=table.column_names)
+    return updated_table
+
+
+@graph
+def add_one_to_dataset():
+    add_one(a_pa_table())
+
+
+@pytest.fixture()
+def io_manager_with_parquet_read_options(tmp_path) -> DeltaLakePyarrowIOManager:
+    return DeltaLakePyarrowIOManager(
+        root_uri=str(tmp_path),
+        storage_options=LocalConfig(),
+        mode=WriteMode.overwrite,
+        parquet_read_options={"coerce_int96_timestamp_unit": "us"},
+    )
+
+
+def test_deltalake_io_manager_with_parquet_read_options(
+    tmp_path,
+    io_manager_with_parquet_read_options,
+):
+    resource_defs = {"io_manager": io_manager_with_parquet_read_options}
+
+    job = add_one_to_dataset.to_job(resource_defs=resource_defs)
+
+    # run the job twice to ensure that tables get properly deleted
+    for _ in range(2):
+        res = job.execute_in_process()
+
+        assert res.success
+
+        dt = DeltaTable(os.path.join(tmp_path, "a_df/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == [1, 2, 3]
+
+        dt = DeltaTable(os.path.join(tmp_path, "add_one/result"))
+        out_df = dt.to_pyarrow_table()
+        assert out_df["a"].to_pylist() == [2, 3, 4]
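The option exercised by the new test, coerce_int96_timestamp_unit, tells pyarrow what precision to coerce legacy INT96 timestamps to on read (the representation used by older Spark/Hive parquet writers); by default they come back as nanoseconds. A small standalone sketch of that behaviour, independent of dagster-delta and using an arbitrary temporary path:

from datetime import datetime

import pyarrow as pa
import pyarrow.dataset as ds
import pyarrow.parquet as pq

table = pa.table({"ts": [datetime(2020, 1, 1), datetime(2020, 1, 2)]})
# Write timestamps in the deprecated INT96 representation.
pq.write_table(table, "/tmp/int96_example.parquet", use_deprecated_int96_timestamps=True)

# Without coercion, INT96 values are read back as timestamp[ns]; with the option they become timestamp[us].
fmt = ds.ParquetFileFormat(read_options=ds.ParquetReadOptions(coerce_int96_timestamp_unit="us"))
dataset = ds.dataset("/tmp/int96_example.parquet", format=fmt)
print(dataset.to_table().schema)  # ts: timestamp[us]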