
Commit

re-organize existing dataset code to prepare ibis relation integration
sh-rp committed Nov 25, 2024
1 parent b6850e8 commit 34323da
Showing 11 changed files with 483 additions and 421 deletions.
416 changes: 0 additions & 416 deletions dlt/destinations/dataset.py

This file was deleted.

19 changes: 19 additions & 0 deletions dlt/destinations/dataset/__init__.py
@@ -0,0 +1,19 @@
from dlt.destinations.dataset.factory import (
dataset,
)
from dlt.destinations.dataset.dataset import (
ReadableDBAPIDataset,
get_destination_clients,
)
from dlt.destinations.dataset.utils import (
get_destination_clients,
get_destination_client_initial_config,
)


__all__ = [
"dataset",
"ReadableDBAPIDataset",
"get_destination_client_initial_config",
"get_destination_clients",
]
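
For orientation, a brief usage sketch that is not part of this commit: the re-exports above keep the previous flat-module import path working, assuming downstream code imported these same names from the old dlt/destinations/dataset.py module.

# hedged sketch: the package __init__ preserves the old import path
from dlt.destinations.dataset import (
    dataset,
    ReadableDBAPIDataset,
    get_destination_clients,
    get_destination_client_initial_config,
)
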
126 changes: 126 additions & 0 deletions dlt/destinations/dataset/dataset.py
@@ -0,0 +1,126 @@
from typing import Any, Union, TYPE_CHECKING

from dlt.common.json import json

from dlt.common.exceptions import MissingDependencyException

from dlt.common.destination.reference import (
SupportsReadableRelation,
SupportsReadableDataset,
TDestinationReferenceArg,
Destination,
JobClientBase,
WithStateSync,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase, WithSqlClient
from dlt.common.schema import Schema
from dlt.destinations.dataset.relation import ReadableDBAPIRelation
from dlt.destinations.dataset.utils import get_destination_clients

if TYPE_CHECKING:
try:
from dlt.common.libs.ibis import BaseBackend as IbisBackend
except MissingDependencyException:
IbisBackend = Any
else:
IbisBackend = Any


class ReadableDBAPIDataset(SupportsReadableDataset):
"""Access to dataframes and arrowtables in the destination dataset via dbapi"""

def __init__(
self,
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
) -> None:
self._destination = Destination.from_reference(destination)
self._provided_schema = schema
self._dataset_name = dataset_name
self._sql_client: SqlClientBase[Any] = None
self._schema: Schema = None

def ibis(self) -> IbisBackend:
"""return a connected ibis backend"""
from dlt.common.libs.ibis import create_ibis_backend

self._ensure_client_and_schema()
return create_ibis_backend(
self._destination,
self._destination_client(self.schema),
)

@property
def schema(self) -> Schema:
self._ensure_client_and_schema()
return self._schema

@property
def sql_client(self) -> SqlClientBase[Any]:
self._ensure_client_and_schema()
return self._sql_client

def _destination_client(self, schema: Schema) -> JobClientBase:
return get_destination_clients(
schema, destination=self._destination, destination_dataset_name=self._dataset_name
)[0]

def _ensure_client_and_schema(self) -> None:
"""Lazy load schema and client"""

# full schema given, use it as-is
if not self._schema and isinstance(self._provided_schema, Schema):
self._schema = self._provided_schema

# schema name given, resolve it from destination by name
elif not self._schema and isinstance(self._provided_schema, str):
with self._destination_client(Schema(self._provided_schema)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema(self._provided_schema)
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))
else:
self._schema = Schema(self._provided_schema)

# no schema name given, load newest schema from destination
elif not self._schema:
with self._destination_client(Schema(self._dataset_name)) as client:
if isinstance(client, WithStateSync):
stored_schema = client.get_stored_schema()
if stored_schema:
self._schema = Schema.from_stored_schema(json.loads(stored_schema.schema))

# default to empty schema with dataset name
if not self._schema:
self._schema = Schema(self._dataset_name)

# here we create the client bound to the resolved schema
if not self._sql_client:
destination_client = self._destination_client(self._schema)
if isinstance(destination_client, WithSqlClient):
self._sql_client = destination_client.sql_client
else:
raise Exception(
f"Destination {destination_client.config.destination_type} does not support"
" SqlClient."
)

def __call__(self, query: Any) -> ReadableDBAPIRelation:
return ReadableDBAPIRelation(readable_dataset=self, provided_query=query) # type: ignore[abstract]

def table(self, table_name: str) -> SupportsReadableRelation:
return ReadableDBAPIRelation(
readable_dataset=self,
table_name=table_name,
) # type: ignore[abstract]

def __getitem__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via dict notation"""
return self.table(table_name)

def __getattr__(self, table_name: str) -> SupportsReadableRelation:
"""access of table via property notation"""
return self.table(table_name)
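
To illustrate the class above, a hedged usage sketch that is not part of the diff; the "duckdb" destination, dataset name and table name are illustrative assumptions.

# hedged usage sketch for ReadableDBAPIDataset; destination, dataset and table names are illustrative
from dlt.destinations.dataset import dataset

ds = dataset("duckdb", "my_pipeline_dataset")   # schema and sql client are resolved lazily
items_rel = ds.table("items")                   # explicit table access
same_rel = ds["items"]                          # dict notation via __getitem__
also_rel = ds.items                             # attribute notation via __getattr__
query_rel = ds("SELECT id, name FROM items")    # raw query relation via __call__
ibis_con = ds.ibis()                            # connected ibis backend, if the ibis dependency is installed
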
22 changes: 22 additions & 0 deletions dlt/destinations/dataset/exceptions.py
@@ -0,0 +1,22 @@
from dlt.common.exceptions import DltException


class DatasetException(DltException):
pass


class ReadableRelationHasQueryException(DatasetException):
def __init__(self, attempted_change: str) -> None:
msg = (
"This readable relation was created with a provided sql query. You cannot change"
f" {attempted_change}. Please change the orignal sql query."
)
super().__init__(msg)


class ReadableRelationUnknownColumnException(DatasetException):
def __init__(self, column_name: str) -> None:
msg = (
f"The selected column {column_name} is not known in the dlt schema for this releation."
)
super().__init__(msg)
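
A hedged sketch, not part of the diff, of where these exceptions surface, based on the relation code further below; ds is the illustrative dataset from the previous sketch.

# hedged sketch: both exceptions are raised by ReadableDBAPIRelation (see relation.py below)
from dlt.destinations.dataset.exceptions import (
    ReadableRelationHasQueryException,
    ReadableRelationUnknownColumnException,
)

try:
    ds("SELECT * FROM items").limit(5)   # relations built from a raw query cannot be modified
except ReadableRelationHasQueryException:
    pass

try:
    ds.table("items").select("no_such_column")   # selected columns are checked against the known dlt schema
except ReadableRelationUnknownColumnException:
    pass
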
24 changes: 24 additions & 0 deletions dlt/destinations/dataset/factory.py
@@ -0,0 +1,24 @@
from typing import Union


from dlt.common.destination import AnyDestination
from dlt.common.destination.reference import (
SupportsReadableDataset,
TDatasetType,
TDestinationReferenceArg,
)

from dlt.common.schema import Schema

from dlt.destinations.dataset.dataset import ReadableDBAPIDataset


def dataset(
destination: TDestinationReferenceArg,
dataset_name: str,
schema: Union[Schema, str, None] = None,
dataset_type: TDatasetType = "dbapi",
) -> SupportsReadableDataset:
if dataset_type == "dbapi":
return ReadableDBAPIDataset(destination, dataset_name, schema)
raise NotImplementedError(f"Dataset of type {dataset_type} not implemented")
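
A hedged sketch, not part of the commit, of the three schema modes the factory accepts, as resolved by ReadableDBAPIDataset._ensure_client_and_schema above; the destination and names are illustrative.

# hedged sketch of the dataset() factory; only dataset_type="dbapi" is implemented
from dlt.common.schema import Schema
from dlt.destinations.dataset import dataset

ds1 = dataset("duckdb", "my_pipeline_dataset")                       # newest stored schema loaded lazily from the destination
ds2 = dataset("duckdb", "my_pipeline_dataset", schema="my_schema")   # schema resolved from the destination by name
ds3 = dataset("duckdb", "my_pipeline_dataset", schema=Schema("my_schema"))  # provided Schema instance used as-is
# any other dataset_type raises NotImplementedError
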
Empty file.
192 changes: 192 additions & 0 deletions dlt/destinations/dataset/relation.py
@@ -0,0 +1,192 @@
from typing import Any, Generator, Sequence, Union, TYPE_CHECKING

from contextlib import contextmanager


from dlt.common.destination.reference import (
SupportsReadableRelation,
)

from dlt.destinations.dataset.exceptions import (
ReadableRelationHasQueryException,
ReadableRelationUnknownColumnException,
)

from dlt.common.schema.typing import TTableSchemaColumns
from dlt.destinations.sql_client import SqlClientBase
from dlt.common.schema import Schema

if TYPE_CHECKING:
from dlt.destinations.dataset.dataset import ReadableDBAPIDataset
else:
ReadableDBAPIDataset = Any


class ReadableDBAPIRelation(SupportsReadableRelation):
def __init__(
self,
*,
readable_dataset: "ReadableDBAPIDataset",
provided_query: Any = None,
table_name: str = None,
limit: int = None,
selected_columns: Sequence[str] = None,
) -> None:
"""Create a lazy evaluated relation to for the dataset of a destination"""

# NOTE: we can keep an assertion here, this class will not be created by the user
assert bool(table_name) != bool(
provided_query
), "Please provide either an sql query OR a table_name"

self._dataset = readable_dataset

self._provided_query = provided_query
self._table_name = table_name
self._limit = limit
self._selected_columns = selected_columns

# wire protocol functions
self.df = self._wrap_func("df") # type: ignore
self.arrow = self._wrap_func("arrow") # type: ignore
self.fetchall = self._wrap_func("fetchall") # type: ignore
self.fetchmany = self._wrap_func("fetchmany") # type: ignore
self.fetchone = self._wrap_func("fetchone") # type: ignore

self.iter_df = self._wrap_iter("iter_df") # type: ignore
self.iter_arrow = self._wrap_iter("iter_arrow") # type: ignore
self.iter_fetch = self._wrap_iter("iter_fetch") # type: ignore

@property
def sql_client(self) -> SqlClientBase[Any]:
return self._dataset.sql_client

@property
def schema(self) -> Schema:
return self._dataset.schema

@property
def query(self) -> Any:
"""build the query"""
if self._provided_query:
return self._provided_query

table_name = self.sql_client.make_qualified_table_name(
self.schema.naming.normalize_path(self._table_name)
)

maybe_limit_clause_1 = ""
maybe_limit_clause_2 = ""
if self._limit:
maybe_limit_clause_1, maybe_limit_clause_2 = self.sql_client._limit_clause_sql(
self._limit
)

selector = "*"
if self._selected_columns:
selector = ",".join(
[
self.sql_client.escape_column_name(self.schema.naming.normalize_path(c))
for c in self._selected_columns
]
)

return f"SELECT {maybe_limit_clause_1} {selector} FROM {table_name} {maybe_limit_clause_2}"

@property
def columns_schema(self) -> TTableSchemaColumns:
return self.compute_columns_schema()

@columns_schema.setter
def columns_schema(self, new_value: TTableSchemaColumns) -> None:
raise NotImplementedError("columns schema in ReadableDBAPIRelation can only be computed")

def compute_columns_schema(self) -> TTableSchemaColumns:
"""provide schema columns for the cursor, may be filtered by selected columns"""

columns_schema = (
self.schema.tables.get(self._table_name, {}).get("columns", {}) if self.schema else {}
)

if not columns_schema:
return None
if not self._selected_columns:
return columns_schema

filtered_columns: TTableSchemaColumns = {}
for sc in self._selected_columns:
sc = self.schema.naming.normalize_path(sc)
if sc not in columns_schema.keys():
raise ReadableRelationUnknownColumnException(sc)
filtered_columns[sc] = columns_schema[sc]

return filtered_columns

@contextmanager
def cursor(self) -> Generator[SupportsReadableRelation, Any, Any]:
"""Gets a DBApiCursor for the current relation"""
with self.sql_client as client:
# this hacky code is needed for mssql to disable autocommit, read iterators
# will not work otherwise. in the future we should be able to create a read-only
# client which will do this automatically
if hasattr(self.sql_client, "_conn") and hasattr(self.sql_client._conn, "autocommit"):
self.sql_client._conn.autocommit = False
with client.execute_query(self.query) as cursor:
if columns_schema := self.columns_schema:
cursor.columns_schema = columns_schema
yield cursor

def _wrap_iter(self, func_name: str) -> Any:
"""wrap SupportsReadableRelation generators in cursor context"""

def _wrap(*args: Any, **kwargs: Any) -> Any:
with self.cursor() as cursor:
yield from getattr(cursor, func_name)(*args, **kwargs)

return _wrap

def _wrap_func(self, func_name: str) -> Any:
"""wrap SupportsReadableRelation functions in cursor context"""

def _wrap(*args: Any, **kwargs: Any) -> Any:
with self.cursor() as cursor:
return getattr(cursor, func_name)(*args, **kwargs)

return _wrap

def __copy__(self) -> "ReadableDBAPIRelation":
return self.__class__(
readable_dataset=self._dataset,
provided_query=self._provided_query,
table_name=self._table_name,
limit=self._limit,
selected_columns=self._selected_columns,
)

def limit(self, limit: int) -> "ReadableDBAPIRelation":
if self._provided_query:
raise ReadableRelationHasQueryException("limit")
rel = self.__copy__()
rel._limit = limit
return rel

def select(self, *columns: str) -> "ReadableDBAPIRelation":
if self._provided_query:
raise ReadableRelationHasQueryException("select")
rel = self.__copy__()
rel._selected_columns = columns
# NOTE: the line below will ensure that no unknown columns are selected if
# schema is known
rel.compute_columns_schema()
return rel

def __getitem__(self, columns: Union[str, Sequence[str]]) -> "SupportsReadableRelation":
if isinstance(columns, str):
return self.select(columns)
elif isinstance(columns, Sequence):
return self.select(*columns)
else:
raise TypeError(f"Invalid argument type: {type(columns).__name__}")

def head(self, limit: int = 5) -> "ReadableDBAPIRelation":
return self.limit(limit)
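
Finally, a hedged sketch, not part of the diff, of the lazy relation API defined above; the table and column names, the chunk_size value and the process() helper are illustrative assumptions.

# hedged sketch of ReadableDBAPIRelation chaining; each step returns a modified copy
rel = ds.table("items").select("id", "name").limit(100)
print(rel.query)        # SELECT with a qualified table name, escaped columns and a dialect-specific limit clause
df = rel.df()           # df/arrow/fetchall/fetchmany/fetchone each open a cursor, run the query and return
for batch in rel.iter_arrow(chunk_size=10_000):   # iterator variants stream from a cursor held open per call
    process(batch)                                # process() is a hypothetical consumer
first_five = ds.items.head()                      # head() is shorthand for limit(5)
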
