add support for column schema in certain query cases
sh-rp committed Dec 6, 2024
1 parent 3886638 commit 289e289
Showing 3 changed files with 131 additions and 97 deletions.
2 changes: 1 addition & 1 deletion dlt/destinations/dataset/dataset.py
@@ -121,7 +121,7 @@ def table(self, table_name: str) -> SupportsReadableRelation:
from dlt.destinations.dataset.ibis_relation import ReadableIbisRelation

unbound_table = create_unbound_ibis_table(self.sql_client, self.schema, table_name)
return ReadableIbisRelation(readable_dataset=self, ibis_object=unbound_table) # type: ignore[abstract]
return ReadableIbisRelation(readable_dataset=self, ibis_object=unbound_table, columns_schema=self.schema.tables[table_name]["columns"]) # type: ignore[abstract]
except MissingDependencyException:
# if ibis is explicitly requested, reraise
if self._dataset_type == "ibis":
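A minimal usage sketch of what this change enables: the relation returned by dataset.table() now carries the dlt column schema of the underlying table. The pipeline setup and the dataset accessor spelling (here pipeline._dataset()) are assumptions and may differ between dlt versions.

import dlt

# assumes a pipeline that has already loaded a table named "items"
pipeline = dlt.pipeline("my_pipeline", destination="duckdb")
dataset = pipeline._dataset()  # dataset accessor; exact name may vary by dlt version

items = dataset.table("items")
# the relation now exposes the column schema of the underlying table
print(list(items.columns_schema.keys()))
# e.g. ['id', 'decimal', 'other_decimal', '_dlt_load_id', '_dlt_id']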
35 changes: 31 additions & 4 deletions dlt/destinations/dataset/ibis_relation.py
@@ -47,17 +47,18 @@
}


# TODO: provide ibis expression typing for the readable relation
class ReadableIbisRelation(BaseReadableDBAPIRelation):
def __init__(
self,
*,
readable_dataset: ReadableDBAPIDataset,
ibis_object: Any = None,
columns_schema: TTableSchemaColumns = None,
) -> None:
"""Create a lazy evaluated relation to for the dataset of a destination"""
super().__init__(readable_dataset=readable_dataset)
self._ibis_object = ibis_object
self._columns_schema = columns_schema

@property
def query(self) -> Any:
@@ -89,7 +90,7 @@ def columns_schema(self, new_value: TTableSchemaColumns) -> None:
def compute_columns_schema(self) -> TTableSchemaColumns:
"""provide schema columns for the cursor, may be filtered by selected columns"""
# TODO: provide column lineage tracing with sqlglot lineage
return None
return self._columns_schema

def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any) -> Any:
"""Proxy method calls to the underlying ibis expression, allowing to wrap the resulting expression in a new relation"""
@@ -119,8 +120,18 @@ def _proxy_expression_method(self, method_name: str, *args: Any, **kwargs: Any)
# Call it with provided args
result = method(*args, **kwargs)

# calculate the columns schema for the result; some operations are known not to change the schema
# and select merely reduces the number of columns
columns_schema = None
if method_name == "select":
columns_schema = self._get_filtered_columns_schema(args)
elif method_name in ["filter", "limit", "order_by", "head"]:
columns_schema = self._columns_schema

# If result is an ibis expression, wrap it in a new relation else return raw result
return self.__class__(readable_dataset=self._dataset, ibis_object=result)
return self.__class__(
readable_dataset=self._dataset, ibis_object=result, columns_schema=columns_schema
)

def __getattr__(self, name: str) -> Any:
"""Wrap all callable attributes of the expression"""
@@ -136,15 +147,31 @@ def __getattr__(self, name: str) -> Any:
raise AttributeError(f"'{self.__class__.__name__}' object has no attribute '{name}'")

if not callable(attr):
# NOTE: we don't need to forward the columns schema for non-callable attributes; these are usually columns
return self.__class__(readable_dataset=self._dataset, ibis_object=attr)

return partial(self._proxy_expression_method, name)

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "ReadableIbisRelation":
# casefold column names
columns = [columns] if isinstance(columns, str) else columns
columns = [self.sql_client.capabilities.casefold_identifier(col) for col in columns]
expr = self._ibis_object[columns]
return self.__class__(readable_dataset=self._dataset, ibis_object=expr)
return self.__class__(
readable_dataset=self._dataset,
ibis_object=expr,
columns_schema=self._get_filtered_columns_schema(columns),
)

def _get_filtered_columns_schema(self, columns: Sequence[str]) -> TTableSchemaColumns:
if not self._columns_schema:
return None
try:
return {col: self._columns_schema[col] for col in columns}
except KeyError:
# NOTE: select statements can contain new columns not present in the original schema
# here we just break the column schema inheritance chain
return None

# forward ibis methods defined on interface
def limit(self, limit: int, **kwargs: Any) -> "ReadableIbisRelation":
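The propagation rules implemented above can be summarized in a short sketch (relation setup assumed as in the previous example): select narrows the schema to the selected columns, filter/limit/order_by/head pass it through unchanged, and any other method such as mutate or aggregate drops it, since the resulting columns are no longer known.

items = dataset.table("items")  # carries the full column schema

selected = items.select("id", "decimal")
print(list(selected.columns_schema.keys()))  # ['id', 'decimal'] - schema narrowed to the selection

filtered = items.filter(items.id < 10)
print(list(filtered.columns_schema.keys()))  # all original columns - schema preserved

mutated = items.mutate(double_id=items.id * 2)
print(mutated.columns_schema)  # None - the schema inheritance chain is broken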
191 changes: 99 additions & 92 deletions tests/load/test_read_interfaces.py
@@ -1,4 +1,4 @@
from typing import Any, cast
from typing import Any, cast, Tuple, List
import re
import pytest
import dlt
@@ -493,142 +493,149 @@ def test_ibis_expression_relation(populated_pipeline: Pipeline) -> None:
return

# we check a bunch of expressions without executing them to see that they produce correct sql
def sql_from_expr(expr: Any) -> str:
# we also return the keys of the discovered schema columns
def sql_from_expr(expr: Any) -> Tuple[str, List[str]]:
query = str(expr.query).replace(populated_pipeline.dataset_name, "dataset")
return re.sub(r"\s+", " ", query)
columns = list(expr.columns_schema.keys()) if expr.columns_schema else None
return re.sub(r"\s+", " ", query), columns

# test all functions discussed here: https://ibis-project.org/tutorials/ibis-for-sql-users
ALL_COLUMNS = ["id", "decimal", "other_decimal", "_dlt_load_id", "_dlt_id"]

# selecting two columns
assert (
sql_from_expr(items_table.select("id", "decimal"))
== 'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"'
assert sql_from_expr(items_table.select("id", "decimal")) == (
'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"',
["id", "decimal"],
)

# selecting all columns
assert sql_from_expr(items_table) == ('SELECT * FROM "dataset"."items"', ALL_COLUMNS)

# selecting two other columns via item getter
assert sql_from_expr(items_table["id", "decimal"]) == (
'SELECT "t0"."id", "t0"."decimal" FROM "dataset"."items" AS "t0"',
["id", "decimal"],
)

# adding a new column
new_col = (items_table.id * 2).name("new_col")
assert (
sql_from_expr(items_table.select("id", "decimal", new_col))
== 'SELECT "t0"."id", "t0"."decimal", "t0"."id" * 2 AS "new_col" FROM "dataset"."items" AS'
' "t0"'
assert sql_from_expr(items_table.select("id", "decimal", new_col)) == (
(
'SELECT "t0"."id", "t0"."decimal", "t0"."id" * 2 AS "new_col" FROM'
' "dataset"."items" AS "t0"'
),
None,
)

# mutating table (add a new column computed from existing columns)
assert (
sql_from_expr(items_table.mutate(double_id=items_table.id * 2).select("id", "double_id"))
== 'SELECT "t0"."id", "t0"."id" * 2 AS "double_id" FROM "dataset"."items" AS "t0"'
assert sql_from_expr(
items_table.mutate(double_id=items_table.id * 2).select("id", "double_id")
) == (
'SELECT "t0"."id", "t0"."id" * 2 AS "double_id" FROM "dataset"."items" AS "t0"',
None,
)

# mutating table: add a new static column
assert (
sql_from_expr(
items_table.mutate(new_col=ibis.literal("static_value")).select("id", "new_col")
)
== 'SELECT "t0"."id", \'static_value\' AS "new_col" FROM "dataset"."items" AS "t0"'
)

# check filtering
assert (
sql_from_expr(items_table.filter(items_table.id < 10))
== 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10'
assert sql_from_expr(
items_table.mutate(new_col=ibis.literal("static_value")).select("id", "new_col")
) == ('SELECT "t0"."id", \'static_value\' AS "new_col" FROM "dataset"."items" AS "t0"', None)

# check filtering (preserves all columns)
assert sql_from_expr(items_table.filter(items_table.id < 10)) == (
'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10',
ALL_COLUMNS,
)

# filtering and selecting a single column
assert (
sql_from_expr(items_table.filter(items_table.id < 10).select("id"))
== 'SELECT "t0"."id" FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10'
assert sql_from_expr(items_table.filter(items_table.id < 10).select("id")) == (
'SELECT "t0"."id" FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10',
["id"],
)

# check filter and
assert (
sql_from_expr(items_table.filter(items_table.id < 10).filter(items_table.id > 5))
== 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10 AND "t0"."id" > 5'
# check filter "and" condition
assert sql_from_expr(items_table.filter(items_table.id < 10).filter(items_table.id > 5)) == (
'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."id" < 10 AND "t0"."id" > 5',
ALL_COLUMNS,
)

# check filter or
assert (
sql_from_expr(items_table.filter((items_table.id < 10) | (items_table.id > 5)))
== 'SELECT * FROM "dataset"."items" AS "t0" WHERE ( "t0"."id" < 10 ) OR ( "t0"."id" > 5 )'
# check filter "or" condition
assert sql_from_expr(items_table.filter((items_table.id < 10) | (items_table.id > 5))) == (
'SELECT * FROM "dataset"."items" AS "t0" WHERE ( "t0"."id" < 10 ) OR ( "t0"."id" > 5 )',
ALL_COLUMNS,
)

# check group by and aggregate
assert (
sql_from_expr(
items_table.group_by("id")
.having(items_table.count() >= 1000)
.aggregate(sum_id=items_table.id.sum())
)
== 'SELECT "t1"."id", "t1"."sum_id" FROM ( SELECT "t0"."id", SUM("t0"."id") AS "sum_id",'
' COUNT(*) AS "CountStar(items)" FROM "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1"'
' WHERE "t1"."CountStar(items)" >= 1000'
assert sql_from_expr(
items_table.group_by("id")
.having(items_table.count() >= 1000)
.aggregate(sum_id=items_table.id.sum())
) == (
(
'SELECT "t1"."id", "t1"."sum_id" FROM ( SELECT "t0"."id", SUM("t0"."id") AS "sum_id",'
' COUNT(*) AS "CountStar(items)" FROM "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1"'
' WHERE "t1"."CountStar(items)" >= 1000'
),
None,
)

# sorting and ordering
assert (
sql_from_expr(items_table.order_by("id", "decimal").limit(10))
== 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC, "t0"."decimal" ASC'
" LIMIT 10"
assert sql_from_expr(items_table.order_by("id", "decimal").limit(10)) == (
(
'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC, "t0"."decimal" ASC'
" LIMIT 10"
),
ALL_COLUMNS,
)

# sort desc and asc
assert (
sql_from_expr(items_table.order_by(ibis.desc("id"), ibis.asc("decimal")).limit(10))
== 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" DESC, "t0"."decimal" ASC'
" LIMIT 10"
assert sql_from_expr(items_table.order_by(ibis.desc("id"), ibis.asc("decimal")).limit(10)) == (
(
'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" DESC, "t0"."decimal" ASC'
" LIMIT 10"
),
ALL_COLUMNS,
)

# offset and limit
assert (
sql_from_expr(items_table.order_by("id").limit(10, offset=5))
== 'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC LIMIT 10 OFFSET 5'
assert sql_from_expr(items_table.order_by("id").limit(10, offset=5)) == (
'SELECT * FROM "dataset"."items" AS "t0" ORDER BY "t0"."id" ASC LIMIT 10 OFFSET 5',
ALL_COLUMNS,
)

# join
assert (
sql_from_expr(
items_table.join(double_items_table, items_table.id == double_items_table.id)[
["id", "double_id"]
]
)
== 'SELECT "t2"."id", "t3"."double_id" FROM "dataset"."items" AS "t2" INNER JOIN'
' "dataset"."double_items" AS "t3" ON "t2"."id" = "t3"."id"'
assert sql_from_expr(
items_table.join(double_items_table, items_table.id == double_items_table.id)[
["id", "double_id"]
]
) == (
(
'SELECT "t2"."id", "t3"."double_id" FROM "dataset"."items" AS "t2" INNER JOIN'
' "dataset"."double_items" AS "t3" ON "t2"."id" = "t3"."id"'
),
None,
)

# subqueries
assert (
sql_from_expr(items_table.filter(items_table.decimal.isin(double_items_table.di_decimal)))
== 'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."decimal" IN ( SELECT'
' "t1"."di_decimal" FROM "dataset"."double_items" AS "t1" )'
assert sql_from_expr(
items_table.filter(items_table.decimal.isin(double_items_table.di_decimal))
) == (
(
'SELECT * FROM "dataset"."items" AS "t0" WHERE "t0"."decimal" IN ( SELECT'
' "t1"."di_decimal" FROM "dataset"."double_items" AS "t1" )'
),
ALL_COLUMNS,
)

# topk
assert (
sql_from_expr(items_table.decimal.topk(10))
== 'SELECT * FROM ( SELECT "t0"."decimal", COUNT(*) AS "CountStar(items)" FROM'
' "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1" ORDER BY "t1"."CountStar(items)" DESC'
" LIMIT 10"
assert sql_from_expr(items_table.decimal.topk(10)) == (
(
'SELECT * FROM ( SELECT "t0"."decimal", COUNT(*) AS "CountStar(items)" FROM'
' "dataset"."items" AS "t0" GROUP BY 1 ) AS "t1" ORDER BY "t1"."CountStar(items)" DESC'
" LIMIT 10"
),
None,
)

# NOTE: here we test that dlt column type resolution still works
# re-enable this when lineage is implemented
# expected_decimal_precision = 10
# expected_decimal_precision_2 = 12
# expected_decimal_precision_di = 7
# if populated_pipeline.destination.destination_type == "dlt.destinations.bigquery":
# # bigquery does not allow precision configuration..
# expected_decimal_precision = 38
# expected_decimal_precision_2 = 38
# expected_decimal_precision_di = 38

# joined_table = items_table.join(double_items_table, items_table.id == double_items_table.id)[
# ["decimal", "other_decimal", "di_decimal"]
# ].rename(decimal_renamed="di_decimal").limit(20)
# table = joined_table.arrow()
# print(joined_table.compute_columns_schema(force=True))
# assert table.schema.field("decimal").type.precision == expected_decimal_precision
# assert table.schema.field("other_decimal").type.precision == expected_decimal_precision_2
# assert table.schema.field("di_decimal").type.precision == expected_decimal_precision_di


@pytest.mark.no_load
@pytest.mark.essential