Commit 28ee1c6

clean up interfaces a bit (more to come?)

remove pipeline dependency from dataset

sh-rp committed Aug 8, 2024
1 parent 9fcbd00 commit 28ee1c6
Showing 6 changed files with 59 additions and 48 deletions.
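In short: `Dataset` and `Relation` no longer hold the pipeline and resolve the destination client themselves; the caller passes a `JobClientBase` in. A condensed sketch of the two call shapes (the `pipeline` variable is illustrative; the full changes are in the diffs below):

```python
# before this commit: the dataset kept a reference to the whole pipeline
# and called pipeline.destination_client() internally
dataset = Dataset(pipeline)

# after this commit: the caller resolves the client once up front, so
# dlt.destinations.dataset no longer needs to import from dlt.pipeline
dataset = Dataset(pipeline.destination_client())
```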
20 changes: 15 additions & 5 deletions dlt/common/destination/reference.py
@@ -570,10 +570,10 @@ def should_truncate_table_before_load_on_staging_destination(self, table: TTable
         return True


-class SupportsDataAccess(Protocol):
-    """Add support accessing data items"""
+class SupportsReadRelation(Protocol):
+    """Add support accessing data items on a relation"""

-    def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]:
+    def df(self, chunk_size: int = None) -> Optional[DataFrame]:
         """Fetches the results as data frame. For large queries the results may be chunked
         Fetches the results into a data frame. The default implementation uses helpers in `pandas.io.sql` to generate Pandas data frame.
@@ -589,7 +589,7 @@ def df(self, chunk_size: int = None, **kwargs: None) -> Optional[DataFrame]:
         """
         ...

-    def arrow(self, *, chunk_size: int = None) -> Optional[ArrowTable]: ...
+    def arrow(self, chunk_size: int = None) -> Optional[ArrowTable]: ...

     def iter_df(self, chunk_size: int) -> Generator[DataFrame, None, None]: ...
@@ -604,6 +604,16 @@ def iter_fetchmany(self, chunk_size: int) -> Generator[List[Tuple[Any, ...]], An
     def fetchone(self) -> Optional[Tuple[Any, ...]]: ...


+class SupportsReadDataset(Protocol):
+    """Add support for read access on a dataset"""
+
+    def sql(self, sql: str, prepare_tables: List[str] = None) -> SupportsReadRelation: ...
+
+    def __getitem__(self, table: str) -> SupportsReadRelation: ...
+
+    def __getattr__(self, table: str) -> SupportsReadRelation: ...
+
+
 class SupportsRelationshipAccess(ABC):
     """Add support for accessing a cursor for a given relationship or query"""
@@ -614,7 +624,7 @@ def cursor_for_relation(
         table: str = None,
         sql: str = None,
         prepare_tables: List[str] = None,
-    ) -> ContextManager[SupportsDataAccess]: ...
+    ) -> ContextManager[SupportsReadRelation]: ...


 # TODO: type Destination properly
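Taken together, `SupportsReadRelation` covers reads on a single relation while `SupportsReadDataset` covers table lookup and ad-hoc SQL on a dataset. A minimal sketch of consumer code written against the new protocols; the function names and the chunk size are illustrative assumptions, not part of this commit:

```python
from typing import Optional

from dlt.common.destination.reference import SupportsReadDataset, SupportsReadRelation
from dlt.common.typing import DataFrame


def first_chunk(dataset: SupportsReadDataset, table: str) -> Optional[DataFrame]:
    # __getitem__ returns a lazily evaluated relation; no query runs yet
    relation: SupportsReadRelation = dataset[table]
    # df() executes the query and returns the first chunk as a data frame
    return relation.df(chunk_size=10_000)


def row_count(dataset: SupportsReadDataset, table: str) -> int:
    # sql() builds a relation over an arbitrary query on the same dataset
    row = dataset.sql(f"SELECT COUNT(*) FROM {table}").fetchone()
    return 0 if row is None else int(row[0])
```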
73 changes: 36 additions & 37 deletions dlt/dataset.py → dlt/destinations/dataset.py
@@ -1,47 +1,52 @@
-from typing import cast, Any, TYPE_CHECKING, Generator, List, Tuple, Optional
+from typing import Any, Generator, List, Tuple, Optional

 from contextlib import contextmanager
+from dlt.destinations.job_client_impl import SqlJobClientBase
+from dlt.destinations.fs_client import FSClientBase

-from dlt.common.destination.reference import SupportsRelationshipAccess, SupportsDataAccess
+from dlt.common.destination.reference import (
+    SupportsRelationshipAccess,
+    SupportsReadRelation,
+    JobClientBase,
+    SupportsReadDataset,
+)

 from dlt.common.typing import DataFrame, ArrowTable


-class Relation:
+class Relation(SupportsReadRelation):
     def __init__(
-        self, *, pipeline: Any, table: str = None, sql: str = None, prepare_tables: List[str] = None
+        self,
+        *,
+        job_client: JobClientBase,
+        table: str = None,
+        sql: str = None,
+        prepare_tables: List[str] = None,
     ) -> None:
-        """Create a lazy evaluated relation to for the dataset of a pipeline"""
-        from dlt.pipeline import Pipeline
-
-        self.pipeline: Pipeline = cast(Pipeline, pipeline)
+        """Create a lazy evaluated relation to for the dataset of a destination"""
+        self.job_client = job_client
         self.prepare_tables = prepare_tables
         self.sql = sql
         self.table = table

     @contextmanager
     def _client(self) -> Generator[SupportsRelationshipAccess, Any, Any]:
-        from dlt.destinations.job_client_impl import SqlJobClientBase
-        from dlt.destinations.fs_client import FSClientBase
-
-        client = self.pipeline.destination_client()
-
-        if isinstance(client, SqlJobClientBase):
-            with client.sql_client as sql_client:
+        if isinstance(self.job_client, SqlJobClientBase):
+            with self.job_client.sql_client as sql_client:
                 yield sql_client
                 return

-        if isinstance(client, FSClientBase):
-            yield client
+        if isinstance(self.job_client, FSClientBase):
+            yield self.job_client
             return

         raise Exception(
-            f"Destination {client.config.destination_type} does not support data access via"
-            " dataset."
+            f"Destination {self.job_client.config.destination_type} does not support data access"
+            " via dataset."
         )

     @contextmanager
-    def cursor(self) -> Generator[SupportsDataAccess, Any, Any]:
+    def cursor(self) -> Generator[SupportsReadRelation, Any, Any]:
         """Gets a DBApiCursor for the current relation"""
         with self._client() as client:
             with client.cursor_for_relation(
@@ -51,25 +56,22 @@ def cursor(self) -> Generator[SupportsDataAccess, Any, Any]:

     def df(
         self,
-        *,
         chunk_size: int = None,
-    ) -> DataFrame:
+    ) -> Optional[DataFrame]:
         """Get first batch of table as dataframe"""
         with self.cursor() as cursor:
             return cursor.df(chunk_size=chunk_size)

     def arrow(
         self,
-        *,
         chunk_size: int = None,
-    ) -> ArrowTable:
+    ) -> Optional[ArrowTable]:
         """Get first batch of table as arrow table"""
         with self.cursor() as cursor:
             return cursor.arrow(chunk_size=chunk_size)

     def iter_df(
         self,
-        *,
         chunk_size: int,
     ) -> Generator[DataFrame, None, None]:
         """iterates over the whole table in dataframes of the given chunk_size, chunk_size of -1 will return the full table in the first batch"""
@@ -80,7 +82,6 @@ def iter_df(

     def iter_arrow(
         self,
-        *,
         chunk_size: int,
     ) -> Generator[ArrowTable, None, None]:
         """iterates over the whole table in arrow tables of the given chunk_size, chunk_size of -1 will return the full table in the first batch"""
@@ -108,21 +109,19 @@ def fetchone(self) -> Optional[Tuple[Any, ...]]:
             return cursor.fetchone()


-class Dataset:
+class Dataset(SupportsReadDataset):
     """Access to dataframes and arrowtables in the destination dataset"""

-    def __init__(self, pipeline: Any) -> None:
-        from dlt.pipeline import Pipeline
-
-        self.pipeline: Pipeline = cast(Pipeline, pipeline)
+    def __init__(self, job_client: JobClientBase) -> None:
+        self.job_client = job_client

-    def sql(self, sql: str, prepare_tables: List[str] = None) -> Relation:
-        return Relation(pipeline=self.pipeline, sql=sql, prepare_tables=prepare_tables)
+    def sql(self, sql: str, prepare_tables: List[str] = None) -> SupportsReadRelation:
+        return Relation(job_client=self.job_client, sql=sql, prepare_tables=prepare_tables)

-    def __getitem__(self, table: str) -> Relation:
+    def __getitem__(self, table: str) -> SupportsReadRelation:
         """access of table via dict notation"""
-        return Relation(pipeline=self.pipeline, table=table)
+        return Relation(job_client=self.job_client, table=table)

-    def __getattr__(self, table: str) -> Relation:
+    def __getattr__(self, table: str) -> SupportsReadRelation:
         """access of table via property notation"""
-        return Relation(pipeline=self.pipeline, table=table)
+        return Relation(job_client=self.job_client, table=table)
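Because `Relation` only stores the job client and query parameters, nothing executes until a cursor is opened, so iteration stays chunked end to end. A small usage sketch, assuming a `dataset` obtained via the `pipeline.dataset` property changed below and a hypothetical `events` table:

```python
# stream a large table chunk-wise instead of materializing it at once;
# per the docstrings above, chunk_size=-1 would return the full table
# in the first batch instead
total_rows = 0
for frame in dataset["events"].iter_df(chunk_size=50_000):
    total_rows += len(frame)
print(f"streamed {total_rows} rows")
```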
4 changes: 2 additions & 2 deletions dlt/destinations/typing.py
@@ -1,7 +1,7 @@
 from typing import Any, AnyStr, List, Type, Optional, Protocol, Tuple, TypeVar, Generator

 from dlt.common.typing import DataFrame, ArrowTable
-from dlt.common.destination.reference import SupportsDataAccess
+from dlt.common.destination.reference import SupportsReadRelation

 # native connection
 TNativeConn = TypeVar("TNativeConn", bound=Any)
@@ -19,7 +19,7 @@ class DBApi(Protocol):
     paramstyle: str


-class DBApiCursor(SupportsDataAccess):
+class DBApiCursor(SupportsReadRelation):
     """Protocol for DBAPI cursor"""

     description: Tuple[Any, ...]
1 change: 0 additions & 1 deletion dlt/extract/resource.py
@@ -58,7 +58,6 @@
     ResourceNotATransformer,
 )
 from dlt.extract.wrappers import wrap_additional_type
-from dlt.dataset import Dataset


 def with_table_name(item: TDataItems, table_name: str) -> DataItemWithMeta:
8 changes: 5 additions & 3 deletions dlt/pipeline/pipeline.py
@@ -81,6 +81,7 @@
     DestinationClientStagingConfiguration,
     DestinationClientStagingConfiguration,
     DestinationClientDwhWithStagingConfiguration,
+    SupportsReadDataset,
 )
 from dlt.common.normalizers.naming import NamingConvention
 from dlt.common.pipeline import (
@@ -147,7 +148,6 @@
 )
 from dlt.common.storages.load_package import TLoadPackageState
 from dlt.pipeline.helpers import refresh_source
-from dlt.dataset import Dataset


 def with_state_sync(may_extract_state: bool = False) -> Callable[[TFun], TFun]:
@@ -1702,6 +1702,8 @@ def __getstate__(self) -> Any:
         return {"pipeline_name": self.pipeline_name}

     @property
-    def dataset(self) -> Dataset:
+    def dataset(self) -> SupportsReadDataset:
         """Access helper to dataset"""
-        return Dataset(self)
+        from dlt.destinations.dataset import Dataset
+
+        return Dataset(self.destination_client())
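With the property now returning a `SupportsReadDataset` built from `destination_client()`, end-user access might look like this; the pipeline name, `duckdb` destination, and table name are illustrative assumptions:

```python
import dlt

pipeline = dlt.pipeline(pipeline_name="read_demo", destination="duckdb")
pipeline.run([{"id": 1}, {"id": 2}, {"id": 3}], table_name="items")

# Dataset is constructed from the destination client, so it no longer
# holds a reference to the pipeline object itself
print(pipeline.dataset["items"].df())  # dict-style access, first chunk as DataFrame
print(pipeline.dataset.items.arrow())  # attribute-style access, arrow output
```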
1 change: 1 addition & 0 deletions tests/load/test_sql_client.py
@@ -298,6 +298,7 @@ def test_execute_query(client: SqlJobClientBase) -> None:
         rows = curr.fetchall()
         assert len(rows) == 0

+
 @pytest.mark.parametrize(
     "client", destinations_configs(default_sql_configs=True), indirect=True, ids=lambda x: x.name
 )
