move to relations based interface
sh-rp committed Aug 6, 2024
1 parent ac415b9 commit c92a527
Showing 3 changed files with 59 additions and 39 deletions.
22 changes: 11 additions & 11 deletions composable_pipeline_1.py
@@ -54,8 +54,8 @@ def o():
 # run and print result
 print("RUNNING WAREHOUSE INGESTION")
 print(duck_pipeline.run([c(), o()]))
-print(duck_pipeline.dataset.df(table="customers"))
-print(duck_pipeline.dataset.df(table="orders"))
+print(duck_pipeline.dataset.customers.df())
+print(duck_pipeline.dataset.orders.df())
 print("===========================")

 #
@@ -67,22 +67,22 @@ def o():

 print("RUNNING LOCAL SNAPSHOT EXTRACTION")
 lake_pipeline.run(
-    duck_pipeline.dataset.iter_df(table="customers"),
+    duck_pipeline.dataset.customers.iter_df(),
     loader_file_format="jsonl",
     table_name="customers",
     write_disposition="replace",
 )
 lake_pipeline.run(
-    duck_pipeline.dataset.iter_df(
-        sql="SELECT * FROM orders WHERE orders.order_day = 'tuesday'"
-    ),
+    duck_pipeline.dataset.sql(
+        "SELECT * FROM orders WHERE orders.order_day = 'tuesday'"
+    ).iter_df(),
     loader_file_format="jsonl",
     table_name="orders",
     write_disposition="replace",
 )

-print(lake_pipeline.dataset.df(table="customers"))
-print(lake_pipeline.dataset.df(table="orders"))
+print(lake_pipeline.dataset.customers.df())
+print(lake_pipeline.dataset.orders.df())
 print("===========================")

 #
@@ -95,15 +95,15 @@ def o():
 )

 denom_pipeline.run(
-    lake_pipeline.dataset.iter_df(
+    lake_pipeline.dataset.sql(
         sql=(
             "SELECT orders.*, customers.name FROM orders LEFT JOIN customers ON"
             " orders.customer_id = customers.id"
         ),
         prepare_tables=["customers", "orders"],
-    ),
+    ).iter_df(),
     loader_file_format="jsonl",
     table_name="customers",
     write_disposition="replace",
 )
-print(denom_pipeline.dataset.df(table="customers"))
+print(denom_pipeline.dataset.customers.df())
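
For orientation, a minimal sketch of the relation-based access pattern the example above moves to, assuming a pipeline that exposes .dataset as in this commit and has already loaded "customers" and "orders" (pipeline name and destination are illustrative):

import dlt

pipeline = dlt.pipeline(pipeline_name="warehouse", destination="duckdb")

# attribute access binds a lazy relation to a table; nothing runs until .df()/.iter_df()
print(pipeline.dataset.customers.df())

# item access does the same and also covers table names that are not valid identifiers
for chunk in pipeline.dataset["orders"].iter_df(batch_size=1000):
    ...  # process one dataframe batch at a time

# sql() binds a relation to an arbitrary query instead of a single table
tuesday = pipeline.dataset.sql(
    "SELECT * FROM orders WHERE orders.order_day = 'tuesday'"
).df()
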
64 changes: 42 additions & 22 deletions dlt/dataset.py
@@ -8,20 +8,23 @@
 from dlt.common.typing import DataFrame, ArrowTable


-class Dataset:
-    """Access to dataframes and arrowtables in the destination dataset"""
-
-    def __init__(self, pipeline: Any) -> None:
+class Relation:
+    def __init__(
+        self, *, pipeline: Any, table: str = None, sql: str = None, prepare_tables: List[str] = None
+    ) -> None:
+        """Create a lazily evaluated relation for the dataset of a pipeline"""
         from dlt.pipeline import Pipeline

         self.pipeline: Pipeline = cast(Pipeline, pipeline)
+        self.prepare_tables = prepare_tables
+        self.sql = sql
+        self.table = table

     @contextmanager
     def _client(self) -> Generator[SupportsDataAccess, None, None]:
         from dlt.destinations.job_client_impl import SqlJobClientBase
         from dlt.destinations.fs_client import FSClientBase

         """Get SupportsDataAccess destination object"""
         client = self.pipeline.destination_client()

         if isinstance(client, SqlJobClientBase):
@@ -33,62 +33,79 @@ def _client(self) -> Generator[SupportsDataAccess, None, None]:
             yield client
             return

-        raise Exception("Destination does not support data access.")
+        raise Exception(
+            f"Destination {client.config.destination_type} does not support data access via"
+            " dataset."
+        )

     def df(
         self,
         *,
-        table: str,
         batch_size: int = 1000,
-        sql: str = None,
-        prepare_tables: List[str] = None
     ) -> DataFrame:
         """Get first batch of table as dataframe"""
         return next(
-            self.iter_df(sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables)
+            self.iter_df(
+                batch_size=batch_size,
+            )
         )

     def arrow(
         self,
         *,
-        table: str,
         batch_size: int = 1000,
-        sql: str = None,
-        prepare_tables: List[str] = None
     ) -> ArrowTable:
         """Get first batch of table as arrow table"""
         return next(
             self.iter_arrow(
-                sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables
+                batch_size=batch_size,
             )
         )

     def iter_df(
         self,
         *,
-        table: str,
         batch_size: int = 1000,
-        sql: str = None,
-        prepare_tables: List[str] = None
     ) -> Generator[DataFrame, None, None]:
         """iterates over the whole table in dataframes of the given batch_size, batch_size of -1 will return the full table in the first batch"""
+        # if no table is given, take the bound table
         with self._client() as data_access:
             yield from data_access.iter_df(
-                sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables
+                sql=self.sql,
+                table=self.table,
+                batch_size=batch_size,
+                prepare_tables=self.prepare_tables,
             )

     def iter_arrow(
         self,
         *,
-        table: str,
         batch_size: int = 1000,
-        sql: str = None,
-        prepare_tables: List[str] = None
     ) -> Generator[ArrowTable, None, None]:
         """iterates over the whole table in arrow tables of the given batch_size, batch_size of -1 will return the full table in the first batch"""
+        # if no table is given, take the bound table
         with self._client() as data_access:
             yield from data_access.iter_arrow(
-                sql=sql, table=table, batch_size=batch_size, prepare_tables=prepare_tables
+                sql=self.sql,
+                table=self.table,
+                batch_size=batch_size,
+                prepare_tables=self.prepare_tables,
             )


+class Dataset:
+    """Access to dataframes and arrowtables in the destination dataset"""
+
+    def __init__(self, pipeline: Any) -> None:
+        from dlt.pipeline import Pipeline
+
+        self.pipeline: Pipeline = cast(Pipeline, pipeline)
+
+    def sql(self, sql: str, prepare_tables: List[str] = None) -> Relation:
+        return Relation(pipeline=self.pipeline, sql=sql, prepare_tables=prepare_tables)
+
+    def __getitem__(self, table: str) -> Relation:
+        return Relation(pipeline=self.pipeline, table=table)
+
+    def __getattr__(self, table: str) -> Relation:
+        return Relation(pipeline=self.pipeline, table=table)
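
The refactor splits the old Dataset in two: Relation only stores its binding (a table name, or a SQL string plus optional prepare_tables) and opens a destination client no earlier than df()/arrow()/iter_df()/iter_arrow(), while Dataset becomes a thin factory that hands out relations via sql(), __getitem__ and __getattr__. A small sketch of the resulting behaviour, assuming a pipeline object as in the example file above (the "items" table is illustrative):

dataset = pipeline.dataset

rel_a = dataset.items        # __getattr__ -> Relation(table="items")
rel_b = dataset["items"]     # __getitem__ -> the equivalent relation
query = dataset.sql(
    "SELECT id FROM items WHERE id < 10",
    # prepare_tables is forwarded to the destination client; the example file passes it
    # when querying a filesystem destination
    prepare_tables=["items"],
)

# no destination access has happened yet; work is deferred until materialization
first_batch = rel_a.df(batch_size=5)     # first dataframe batch only (next() over iter_df)
everything = query.arrow(batch_size=-1)  # per the docstrings, -1 returns the full result in one batch
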
12 changes: 6 additions & 6 deletions tests/load/test_read_interfaces.py
@@ -40,13 +40,13 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) -
     )

     # get one df
-    df = pipeline.dataset.df(table="items", batch_size=5)
+    df = pipeline.dataset["items"].df(batch_size=5)
     assert len(df.index) == 5
     assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"}

     # iterate all dataframes
     frames = []
-    for df in pipeline.dataset.iter_df(table="items", batch_size=70):
+    for df in pipeline.dataset["items"].iter_df(batch_size=70):
         frames.append(df)

     # check frame amount and items counts
@@ -58,7 +58,7 @@ def test_read_interfaces_sql(destination_config: DestinationTestConfiguration) -
     assert set(ids) == set(range(300))

     # basic check of arrow table
-    table = pipeline.dataset.arrow(table="items", batch_size=5)
+    table = pipeline.dataset.items.arrow(batch_size=5)
     assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"}
     assert table.num_rows == 5

@@ -93,13 +93,13 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura
     pipeline.run(s, loader_file_format=destination_config.file_format)

     # get one df
-    df = pipeline.dataset.df(table="items", batch_size=5)
+    df = pipeline.dataset["items"].df(batch_size=5)
     assert len(df.index) == 5
     assert set(df.columns.values) == {"id", "_dlt_load_id", "_dlt_id"}

     # iterate all dataframes
     frames = []
-    for df in pipeline.dataset.iter_df(table="items", batch_size=70):
+    for df in pipeline.dataset.items.iter_df(batch_size=70):
         frames.append(df)

     # check frame amount and items counts
@@ -111,6 +111,6 @@ def test_read_interfaces_filesystem(destination_config: DestinationTestConfigura
     assert set(ids) == set(range(300))

     # basic check of arrow table
-    table = pipeline.dataset.arrow(table="items", batch_size=5)
+    table = pipeline.dataset["items"].arrow(batch_size=5)
     assert set(table.column_names) == {"id", "_dlt_load_id", "_dlt_id"}
     assert table.num_rows == 5
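
One practical note on the two spellings the updated tests exercise: attribute access goes through __getattr__, so it only reaches tables whose names are valid Python identifiers and do not collide with Dataset's own members (pipeline, sql), while the item form works for any name. A hypothetical example:

# a table that happened to be called "sql" would collide with Dataset.sql(),
# so the item form is the safe way to bind it
rel = pipeline.dataset["sql"]

# attribute access is fine for ordinary identifiers
rel = pipeline.dataset.items
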
