Commit
make pipeline dataset factory public
sh-rp committed Dec 9, 2024
1 parent eda6ad3 commit fae7a2b
Showing 14 changed files with 44 additions and 45 deletions.
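This commit renames the pipeline's dataset factory from the private `_dataset()` to the public `dataset()` and drops the `_dataset` entry from the export list in `dlt/__init__.py`. A minimal before/after sketch of a call site touched by this diff (pipeline and table names are illustrative, not taken from the commit):

```py
import dlt

# hypothetical pipeline, used only to illustrate the rename
pipeline = dlt.pipeline(pipeline_name="my_pipeline", destination="duckdb")
pipeline.run([{"id": 1}], table_name="items")

# before this commit (private accessor):
# dataset = pipeline._dataset()

# after this commit (public accessor):
dataset = pipeline.dataset()
print(dataset.items.fetchall())  # read a table back as a readable relation
```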
1 change: 0 additions & 1 deletion dlt/__init__.py
@@ -79,7 +79,6 @@
"TCredentials",
"sources",
"destinations",
"_dataset",
]

# verify that no injection context was created
2 changes: 1 addition & 1 deletion dlt/pipeline/pipeline.py
@@ -1750,7 +1750,7 @@ def __getstate__(self) -> Any:
# pickle only the SupportsPipeline protocol fields
return {"pipeline_name": self.pipeline_name}

def _dataset(
def dataset(
self, schema: Union[Schema, str, None] = None, dataset_type: TDatasetType = "dbapi"
) -> SupportsReadableDataset:
"""Access helper to dataset"""
4 changes: 2 additions & 2 deletions docs/website/docs/general-usage/dataset-access/dataset.md
@@ -19,7 +19,7 @@ Here's a full example of how to retrieve data from a pipeline and load it into a
# and you have loaded data to a table named 'items' in the destination

# Step 1: Get the readable dataset from the pipeline
dataset = pipeline._dataset()
dataset = pipeline.dataset()

# Step 2: Access a table as a ReadableRelation
items_relation = dataset.items # Or dataset["items"]
@@ -39,7 +39,7 @@ Assuming you have a `Pipeline` object (let's call it `pipeline`), you can obtain

```py
# Get the readable dataset from the pipeline
dataset = pipeline._dataset()
dataset = pipeline.dataset()
```

### Access tables as `ReadableRelation`
@@ -28,7 +28,7 @@ pip install ibis-framework[duckdb]

```py
# get the dataset from the pipeline
dataset = pipeline._dataset()
dataset = pipeline.dataset()
dataset_name = pipeline.dataset_name

# get the native ibis connection from the dataset
8 changes: 4 additions & 4 deletions tests/destinations/test_readable_dbapi_dataset.py
@@ -9,7 +9,7 @@


def test_query_builder() -> None:
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline")._dataset()
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline").dataset()

# default query for a table
assert dataset.my_table.query.strip() == 'SELECT * FROM "pipeline_dataset"."my_table"' # type: ignore[attr-defined]
@@ -55,7 +55,7 @@ def test_query_builder() -> None:


def test_copy_and_chaining() -> None:
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline")._dataset()
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline").dataset()

# create relation and set some stuff on it
relation = dataset.items
@@ -80,7 +80,7 @@ def test_copy_and_chaining() -> None:


def test_computed_schema_columns() -> None:
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline")._dataset()
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline").dataset()
relation = dataset.items

# no schema present
@@ -107,7 +107,7 @@ def test_computed_schema_columns() -> None:


def test_prevent_changing_relation_with_query() -> None:
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline")._dataset()
dataset = dlt.pipeline(destination="duckdb", pipeline_name="pipeline").dataset()
relation = dataset("SELECT * FROM something")

with pytest.raises(ReadableRelationHasQueryException):
8 changes: 4 additions & 4 deletions tests/extract/test_incremental.py
@@ -228,7 +228,7 @@ def test_pandas_index_as_dedup_key() -> None:
no_index_r = some_data.with_name(new_name="no_index")
p.run(no_index_r)
p.run(no_index_r)
data_ = p._dataset().no_index.arrow()
data_ = p.dataset().no_index.arrow()
assert data_.schema.names == ["created_at", "id"]
assert data_["id"].to_pylist() == ["a", "b", "c", "d", "e", "f", "g"]

@@ -240,7 +240,7 @@ def test_pandas_index_as_dedup_key() -> None:
unnamed_index_r.incremental.primary_key = "__index_level_0__"
p.run(unnamed_index_r)
p.run(unnamed_index_r)
data_ = p._dataset().unnamed_index.arrow()
data_ = p.dataset().unnamed_index.arrow()
assert data_.schema.names == ["created_at", "id", "index_level_0"]
# indexes 2 and 3 are removed from second batch because they were in the previous batch
# and the created_at overlapped so they got deduplicated
@@ -258,7 +258,7 @@ def _make_named_index(df_: pd.DataFrame) -> pd.DataFrame:
named_index_r.incremental.primary_key = "order_id"
p.run(named_index_r)
p.run(named_index_r)
data_ = p._dataset().named_index.arrow()
data_ = p.dataset().named_index.arrow()
assert data_.schema.names == ["created_at", "id", "order_id"]
assert data_["order_id"].to_pylist() == [0, 1, 2, 3, 4, 0, 1, 4]

@@ -268,7 +268,7 @@ def _make_named_index(df_: pd.DataFrame) -> pd.DataFrame:
)
p.run(named_index_impl_r)
p.run(named_index_impl_r)
data_ = p._dataset().named_index_impl.arrow()
data_ = p.dataset().named_index_impl.arrow()
assert data_.schema.names == ["created_at", "id"]
assert data_["id"].to_pylist() == ["a", "b", "c", "d", "e", "f", "g"]

4 changes: 2 additions & 2 deletions tests/load/duckdb/test_duckdb_client.py
@@ -282,14 +282,14 @@ def test_drops_pipeline_changes_bound() -> None:
p = dlt.pipeline(pipeline_name="quack_pipeline", destination="duckdb")
p.run([1, 2, 3], table_name="p_table")
p = p.drop()
assert len(p._dataset().p_table.fetchall()) == 3
assert len(p.dataset().p_table.fetchall()) == 3

# drops internal duckdb
p = dlt.pipeline(pipeline_name="quack_pipeline", destination=duckdb(":pipeline:"))
p.run([1, 2, 3], table_name="p_table")
p = p.drop()
with pytest.raises(DatabaseUndefinedRelation):
p._dataset().p_table.fetchall()
p.dataset().p_table.fetchall()


def test_duckdb_database_delete() -> None:
6 changes: 3 additions & 3 deletions tests/load/filesystem/test_sql_client.py
@@ -349,7 +349,7 @@ def items():

pipeline.run([items()], loader_file_format=destination_config.file_format)

df = pipeline._dataset().items.df()
df = pipeline.dataset().items.df()
assert len(df.index) == 20

@dlt.resource(table_name="items")
@@ -359,5 +359,5 @@ def items2():
pipeline.run([items2()], loader_file_format=destination_config.file_format)

# check df and arrow access
assert len(pipeline._dataset().items.df().index) == 50
assert pipeline._dataset().items.arrow().num_rows == 50
assert len(pipeline.dataset().items.df().index) == 50
assert pipeline.dataset().items.arrow().num_rows == 50
8 changes: 4 additions & 4 deletions tests/load/pipeline/test_bigquery.py
@@ -384,8 +384,8 @@ def resource():
bigquery_adapter(resource, autodetect_schema=True)
pipeline.run(resource)

assert len(pipeline._dataset().items.df()) == 5
assert len(pipeline._dataset().items__nested.df()) == 5
assert len(pipeline.dataset().items.df()) == 5
assert len(pipeline.dataset().items__nested.df()) == 5

@dlt.resource(primary_key="id", table_name="items", write_disposition="merge")
def resource2():
@@ -395,5 +395,5 @@ def resource2():
bigquery_adapter(resource2, autodetect_schema=True)
pipeline.run(resource2)

assert len(pipeline._dataset().items.df()) == 7
assert len(pipeline._dataset().items__nested.df()) == 7
assert len(pipeline.dataset().items.df()) == 7
assert len(pipeline.dataset().items__nested.df()) == 7
4 changes: 2 additions & 2 deletions tests/load/pipeline/test_duckdb.py
@@ -273,10 +273,10 @@ def test_duckdb_credentials_separation(
p2 = dlt.pipeline("p2", destination=duckdb(credentials=":pipeline:"))

p1.run([1, 2, 3], table_name="p1_data")
p1_dataset = p1._dataset()
p1_dataset = p1.dataset()

p2.run([1, 2, 3], table_name="p2_data")
p2_dataset = p2._dataset()
p2_dataset = p2.dataset()

# both dataset should have independent duckdb databases
# destinations should be bounded to pipelines still
28 changes: 14 additions & 14 deletions tests/load/test_read_interfaces.py
@@ -161,7 +161,7 @@ def double_items():
ids=lambda x: x.name,
)
def test_arrow_access(populated_pipeline: Pipeline) -> None:
table_relationship = populated_pipeline._dataset().items
table_relationship = populated_pipeline.dataset().items
total_records = _total_records(populated_pipeline)
chunk_size = _chunk_size(populated_pipeline)
expected_chunk_counts = _expected_chunk_count(populated_pipeline)
@@ -194,7 +194,7 @@ def test_arrow_access(populated_pipeline: Pipeline) -> None:
)
def test_dataframe_access(populated_pipeline: Pipeline) -> None:
# access via key
table_relationship = populated_pipeline._dataset()["items"]
table_relationship = populated_pipeline.dataset()["items"]
total_records = _total_records(populated_pipeline)
chunk_size = _chunk_size(populated_pipeline)
expected_chunk_counts = _expected_chunk_count(populated_pipeline)
@@ -234,7 +234,7 @@ def test_dataframe_access(populated_pipeline: Pipeline) -> None:
)
def test_db_cursor_access(populated_pipeline: Pipeline) -> None:
# check fetch accessors
table_relationship = populated_pipeline._dataset().items
table_relationship = populated_pipeline.dataset().items
total_records = _total_records(populated_pipeline)
chunk_size = _chunk_size(populated_pipeline)
expected_chunk_counts = _expected_chunk_count(populated_pipeline)
@@ -280,11 +280,11 @@ def test_ibis_dataset_access(populated_pipeline: Pipeline) -> None:
# check correct error if not supported
if populated_pipeline.destination.destination_type not in SUPPORTED_DESTINATIONS:
with pytest.raises(NotImplementedError):
populated_pipeline._dataset().ibis()
populated_pipeline.dataset().ibis()
return

total_records = _total_records(populated_pipeline)
ibis_connection = populated_pipeline._dataset().ibis()
ibis_connection = populated_pipeline.dataset().ibis()

map_i = lambda x: x
if populated_pipeline.destination.destination_type == "dlt.destinations.snowflake":
@@ -333,7 +333,7 @@ def test_ibis_dataset_access(populated_pipeline: Pipeline) -> None:
ids=lambda x: x.name,
)
def test_hint_preservation(populated_pipeline: Pipeline) -> None:
table_relationship = populated_pipeline._dataset().items
table_relationship = populated_pipeline.dataset().items
# check that hints are carried over to arrow table
expected_decimal_precision = 10
expected_decimal_precision_2 = 12
@@ -361,7 +361,7 @@ def test_hint_preservation(populated_pipeline: Pipeline) -> None:
)
def test_loads_table_access(populated_pipeline: Pipeline) -> None:
# check loads table access, we should have one entry
loads_table = populated_pipeline._dataset()[populated_pipeline.default_schema.loads_table_name]
loads_table = populated_pipeline.dataset()[populated_pipeline.default_schema.loads_table_name]
assert len(loads_table.fetchall()) == 1


@@ -376,7 +376,7 @@ def test_loads_table_access(populated_pipeline: Pipeline) -> None:
def test_sql_queries(populated_pipeline: Pipeline) -> None:
# simple check that query also works
tname = populated_pipeline.sql_client().make_qualified_table_name("items")
query_relationship = populated_pipeline._dataset()(f"select * from {tname} where id < 20")
query_relationship = populated_pipeline.dataset()(f"select * from {tname} where id < 20")

# we selected the first 20
table = query_relationship.arrow()
@@ -388,7 +388,7 @@ def test_sql_queries(populated_pipeline: Pipeline) -> None:
f"SELECT i.id, di.double_id FROM {tname} as i JOIN {tdname} as di ON (i.id = di.id) WHERE"
" i.id < 20 ORDER BY i.id ASC"
)
join_relationship = populated_pipeline._dataset()(query)
join_relationship = populated_pipeline.dataset()(query)
table = join_relationship.fetchall()
assert len(table) == 20
assert list(table[0]) == [0, 0]
@@ -405,7 +405,7 @@ def test_sql_queries(populated_pipeline: Pipeline) -> None:
ids=lambda x: x.name,
)
def test_limit_and_head(populated_pipeline: Pipeline) -> None:
table_relationship = populated_pipeline._dataset().items
table_relationship = populated_pipeline.dataset().items

assert len(table_relationship.head().fetchall()) == 5
assert len(table_relationship.limit(24).fetchall()) == 24
@@ -426,7 +426,7 @@ def test_limit_and_head(populated_pipeline: Pipeline) -> None:
ids=lambda x: x.name,
)
def test_column_selection(populated_pipeline: Pipeline) -> None:
table_relationship = populated_pipeline._dataset().items
table_relationship = populated_pipeline.dataset().items

columns = ["_dlt_load_id", "other_decimal"]
data_frame = table_relationship.select(*columns).head().df()
@@ -464,18 +464,18 @@ def test_schema_arg(populated_pipeline: Pipeline) -> None:
"""Simple test to ensure schemas may be selected via schema arg"""

# if there is no arg, the default schema is used
dataset = populated_pipeline._dataset()
dataset = populated_pipeline.dataset()
assert dataset.schema.name == populated_pipeline.default_schema_name
assert "items" in dataset.schema.tables

# setting a different schema name will try to load that schema,
# not find one and create an empty schema with that name
dataset = populated_pipeline._dataset(schema="unknown_schema")
dataset = populated_pipeline.dataset(schema="unknown_schema")
assert dataset.schema.name == "unknown_schema"
assert "items" not in dataset.schema.tables

# providing the schema name of the right schema will load it
dataset = populated_pipeline._dataset(schema=populated_pipeline.default_schema_name)
dataset = populated_pipeline.dataset(schema=populated_pipeline.default_schema_name)
assert dataset.schema.name == populated_pipeline.default_schema_name
assert "items" in dataset.schema.tables

2 changes: 1 addition & 1 deletion tests/pipeline/test_dlt_versions.py
@@ -538,5 +538,5 @@ def test_normalize_path_separator_legacy_behavior(test_storage: FileStorage) ->
"_dlt_load_id",
}
# datasets must be the same
data_ = pipeline._dataset().issues_2.select("issue_id", "id").fetchall()
data_ = pipeline.dataset().issues_2.select("issue_id", "id").fetchall()
print(data_)
8 changes: 4 additions & 4 deletions tests/pipeline/test_pipeline.py
@@ -1730,7 +1730,7 @@ def test_column_name_with_break_path() -> None:
# get data
assert_data_table_counts(pipeline, {"custom__path": 1})
# get data via dataset with dbapi
data_ = pipeline._dataset().custom__path[["example_custom_field__c", "reg_c"]].fetchall()
data_ = pipeline.dataset().custom__path[["example_custom_field__c", "reg_c"]].fetchall()
assert data_ == [("custom", "c")]


@@ -1754,7 +1754,7 @@ def test_column_name_with_break_path_legacy() -> None:
# get data
assert_data_table_counts(pipeline, {"custom_path": 1})
# get data via dataset with dbapi
data_ = pipeline._dataset().custom_path[["example_custom_field_c", "reg_c"]].fetchall()
data_ = pipeline.dataset().custom_path[["example_custom_field_c", "reg_c"]].fetchall()
assert data_ == [("custom", "c")]


@@ -1782,7 +1782,7 @@ def flattened_dict():
assert table["columns"]["value__timestamp"]["data_type"] == "timestamp"

# make sure data is there
data_ = pipeline._dataset().flattened__dict[["delta", "value__timestamp"]].limit(1).fetchall()
data_ = pipeline.dataset().flattened__dict[["delta", "value__timestamp"]].limit(1).fetchall()
assert data_ == [(0, now)]


@@ -1812,7 +1812,7 @@ def flattened_dict():
assert set(table["columns"]) == {"delta", "value__timestamp", "_dlt_id", "_dlt_load_id"}
assert table["columns"]["value__timestamp"]["data_type"] == "timestamp"
# make sure data is there
data_ = pipeline._dataset().flattened_dict[["delta", "value__timestamp"]].limit(1).fetchall()
data_ = pipeline.dataset().flattened_dict[["delta", "value__timestamp"]].limit(1).fetchall()
assert data_ == [(0, now)]


4 changes: 2 additions & 2 deletions tests/pipeline/test_pipeline_extra.py
@@ -521,7 +521,7 @@ def test_parquet_with_flattened_columns() -> None:
assert "issue__reactions__url" in pipeline.default_schema.tables["events"]["columns"]
assert "issue_reactions_url" not in pipeline.default_schema.tables["events"]["columns"]

events_table = pipeline._dataset().events.arrow()
events_table = pipeline.dataset().events.arrow()
assert "issue__reactions__url" in events_table.schema.names
assert "issue_reactions_url" not in events_table.schema.names

@@ -536,7 +536,7 @@ def test_parquet_with_flattened_columns() -> None:
info = pipeline.run(events_table, table_name="events", loader_file_format="parquet")
assert_load_info(info)

events_table_new = pipeline._dataset().events.arrow()
events_table_new = pipeline.dataset().events.arrow()
assert events_table.schema == events_table_new.schema
# double row count
assert events_table.num_rows * 2 == events_table_new.num_rows
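For reference, the docs and test changes in this diff converge on roughly the following usage pattern for the now-public accessor. This is a hedged summary sketch: names are illustrative, and the raw SQL table name may need destination-specific qualification (the tests use `sql_client().make_qualified_table_name`).

```py
# obtain the readable dataset (default schema, or pick one explicitly)
dataset = pipeline.dataset()
dataset = pipeline.dataset(schema=pipeline.default_schema_name)

# access a table as a readable relation, by attribute or by key
items = dataset.items          # or dataset["items"]

# read it back in different shapes
df = items.df()                # pandas DataFrame
table = items.arrow()          # pyarrow Table
rows = items[["id"]].limit(5).fetchall()  # column selection + limit

# run raw SQL against the dataset
relation = dataset("SELECT * FROM items WHERE id < 20")
print(relation.fetchall())
```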