first version of ibis dataset base transformations
sh-rp committed Dec 19, 2024
1 parent 5a1cb69 commit cfcb5a1
Showing 5 changed files with 148 additions and 0 deletions.
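
For orientation: the commit adds a `@dlt.transformation` decorator that attaches materialization settings to a function mapping a dataset to an ibis-style relation, plus a `Pipeline.transform()` entry point that executes such functions against the destination. A minimal usage sketch (hypothetical pipeline and table names; duckdb as destination, mirroring the test at the bottom of this commit):

import dlt
from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation

@dlt.transformation(table_name="doubled_items")
def doubled(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
    items = dataset.items  # ibis-backed readable relation
    return items.mutate(double_id=items.id * 2)

p = dlt.pipeline("demo_pipeline", destination="duckdb", dataset_name="demo_data")
# ... p.run(...) some rows into the "items" table first ...
p.transform(doubled)  # materialized as table "doubled_items" with the default "replace" disposition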
4 changes: 4 additions & 0 deletions dlt/__init__.py
@@ -43,6 +43,8 @@
from dlt.pipeline import progress
from dlt import destinations

from dlt.destinations.transformations import transformation, transformation_group

pipeline = _pipeline
current = _current
mark = _mark
@@ -79,6 +81,8 @@
"TCredentials",
"sources",
"destinations",
"transformation",
"transformation_group",
]

# verify that no injection context was created
2 changes: 2 additions & 0 deletions dlt/common/destination/reference.py
@@ -15,6 +15,7 @@
    Union,
    List,
    ContextManager,
    runtime_checkable,
    Dict,
    Any,
    TypeVar,
@@ -483,6 +484,7 @@ def create_followup_jobs(self, final_state: TLoadJobState) -> List[FollowupJobRe
        return []


@runtime_checkable
class SupportsReadableRelation(Protocol):
"""A readable relation retrieved from a destination that supports it"""

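`runtime_checkable` is what allows the new `run_transformations` (below) to validate a transformation's return value with a plain `isinstance` check against the protocol. A toy sketch of the mechanics (hypothetical classes, not dlt's real ones); note that such checks only verify that the required members exist, not their signatures:

from typing import Protocol, runtime_checkable

@runtime_checkable
class SupportsQuery(Protocol):  # stand-in for SupportsReadableRelation
    query: str

class FakeRelation:
    query = "SELECT 1"

assert isinstance(FakeRelation(), SupportsQuery)  # member present: passes
assert not isinstance(object(), SupportsQuery)    # member missing: fails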
94 changes: 94 additions & 0 deletions dlt/destinations/transformations/__init__.py
@@ -0,0 +1,94 @@
from typing import Callable, Literal, Union, Any, Generator, List, TYPE_CHECKING, Iterable

from dataclasses import dataclass
from functools import wraps

from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation


TTransformationMaterialization = Literal["table", "view"]
TTransformationWriteDisposition = Literal["replace", "append"]

TTransformationFunc = Callable[[SupportsReadableDataset], SupportsReadableRelation]

TTransformationGroupFunc = Callable[[], List[TTransformationFunc]]


def transformation(
    table_name: str,
    materialization: TTransformationMaterialization = "table",
    write_disposition: TTransformationWriteDisposition = "replace",
) -> Callable[[TTransformationFunc], TTransformationFunc]:
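    """Marks a function as a dataset transformation and records how its result should be materialized."""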
    def decorator(func: TTransformationFunc) -> TTransformationFunc:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> SupportsReadableRelation:
            return func(*args, **kwargs)

        # save the decorator arguments on the wrapper
        wrapper.__transformation_args__ = {  # type: ignore
            "table_name": table_name,
            "materialization": materialization,
            "write_disposition": write_disposition,
        }

        return wrapper

    return decorator


def transformation_group(
    name: str,
) -> Callable[[TTransformationGroupFunc], TTransformationGroupFunc]:
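    """Marks a function as a named group of transformations (the name is recorded but not yet consumed)."""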
    def decorator(func: TTransformationGroupFunc) -> TTransformationGroupFunc:
        @wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> List[TTransformationFunc]:
            return func(*args, **kwargs)

        # set the attribute on the returned wrapper: setting it on func would be lost,
        # since @wraps copies func.__dict__ before this line runs
        wrapper.__transformation_group_args__ = {  # type: ignore
            "name": name,
        }
        return wrapper

    return decorator


def run_transformations(
    dataset: SupportsReadableDataset,
    transformations: Union[TTransformationFunc, List[TTransformationFunc]],
) -> None:
    if not isinstance(transformations, Iterable):
        transformations = [transformations]

    # TODO: fix typing
    with dataset.sql_client as client:  # type: ignore
        for transformation in transformations:
            # get transformation settings
            table_name = transformation.__transformation_args__["table_name"]  # type: ignore
            materialization = transformation.__transformation_args__["materialization"]  # type: ignore
            write_disposition = transformation.__transformation_args__["write_disposition"]  # type: ignore
            table_name = client.make_qualified_table_name(table_name)

            # get relation from transformation
            relation = transformation(dataset)
            if not isinstance(relation, SupportsReadableRelation):
                raise ValueError(
                    f"Transformation {transformation.__name__} did not return a ReadableRelation"
                )

            # materialize result
            select_clause = relation.query

            if write_disposition == "replace":
                client.execute(
                    f"CREATE OR REPLACE {materialization} {table_name} AS {select_clause}"
                )
            elif write_disposition == "append" and materialization == "table":
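                # try to append to an existing table; fall back to creating it on the first run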
                try:
                    client.execute(f"INSERT INTO {table_name} {select_clause}")
                except Exception:
                    client.execute(f"CREATE TABLE {table_name} AS {select_clause}")
            else:
                raise ValueError(
                    f"Write disposition {write_disposition} is not supported for "
                    f"materialization {materialization}"
                )
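
Nothing in this commit consumes `__transformation_group_args__` yet, but the types above suggest the intended shape: a group function returns the list of transformations to run together. A hedged sketch of how a group could feed `Pipeline.transform` (reusing the transformations from the test at the end of this commit):

@dlt.transformation_group(name="reporting")
def reporting_group():
    return [simple_transformation, aggregate_transformation]

p.transform(reporting_group())  # expands the group and runs its members in order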
9 changes: 9 additions & 0 deletions dlt/pipeline/pipeline.py
@@ -152,6 +152,8 @@
from dlt.common.storages.load_package import TLoadPackageState
from dlt.pipeline.helpers import refresh_source

from dlt.destinations.transformations import TTransformationFunc


def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
def decorator(f: TFun) -> TFun:
@@ -1770,3 +1772,10 @@ def dataset(
            schema=schema,
            dataset_type=dataset_type,
        )

    def transform(
        self, transformations: Union[TTransformationFunc, List[TTransformationFunc]]
    ) -> None:
        from dlt.destinations.transformations import run_transformations

        run_transformations(self.dataset(), transformations)
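
Since `transform` forwards directly to `run_transformations`, it accepts either a single decorated function or a list of them, e.g. (names from the test below):

p.transform(simple_transformation)
p.transform([simple_transformation, aggregate_transformation])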
39 changes: 39 additions & 0 deletions tests/load/test_transformation.py
@@ -0,0 +1,39 @@
import dlt

from dlt.common.destination.reference import SupportsReadableDataset, SupportsReadableRelation

from functools import reduce


def test_simple_transformation() -> None:
    # load some stuff into items table

    @dlt.resource(table_name="items")
    def items_resource():
        for i in range(10):
            yield {"id": i, "value": i * 2}

    p = dlt.pipeline("test_pipeline", destination="duckdb", dataset_name="test_dataset")
    p.run(items_resource)

    print(p.dataset().items.df())

    @dlt.transformation(table_name="quadrupled_items")
    def simple_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
        items_table = dataset.items
        return items_table.mutate(quadruple_id=items_table.id * 4).select("id", "quadruple_id")

    @dlt.transformation(table_name="aggregated_items")
    def aggregate_transformation(dataset: SupportsReadableDataset) -> SupportsReadableRelation:
        items_table = dataset.items
        return items_table.aggregate(sum_id=items_table.id.sum(), value_sum=items_table.value.sum())

    # we run two transformations
    p.transform([simple_transformation, aggregate_transformation])

    # check table with quadrupled ids
    assert list(p.dataset().quadrupled_items.df()["quadruple_id"]) == [i * 4 for i in range(10)]

    # check aggregated table for both fields
    assert p.dataset().aggregated_items.fetchone()[0] == reduce(lambda a, b: a + b, range(10))
    assert p.dataset().aggregated_items.fetchone()[1] == (reduce(lambda a, b: a + b, range(10)) * 2)
