
AirbyteLib: Add Lazy Datasets and iterator syntax support for datasets, caches, and read results #34429

Merged: 17 commits, Jan 23, 2024. Changes shown below are from all commits.
2 changes: 1 addition & 1 deletion .github/workflows/airbyte-ci-tests.yml
@@ -79,7 +79,7 @@ jobs:
- 'airbyte-ci/connectors/metadata_service/orchestrator/**'
- '!**/*.md'
airbyte_lib:
-      - 'airbyte_lib/**'
+      - 'airbyte-lib/**'
- '!**/*.md'

- name: Run airbyte-ci/connectors/connector_ops tests
5 changes: 4 additions & 1 deletion airbyte-lib/airbyte_lib/__init__.py
@@ -2,16 +2,19 @@

 from airbyte_lib._factories.cache_factories import get_default_cache, new_local_cache
 from airbyte_lib._factories.connector_factories import get_connector
+from airbyte_lib.caches import DuckDBCache, DuckDBCacheConfig
 from airbyte_lib.datasets import CachedDataset
 from airbyte_lib.results import ReadResult
 from airbyte_lib.source import Source


 __all__ = [
+    "CachedDataset",
+    "DuckDBCache",
+    "DuckDBCacheConfig",
     "get_connector",
     "get_default_cache",
     "new_local_cache",
-    "CachedDataset",
     "ReadResult",
     "Source",
 ]
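
For context, a usage sketch of the expanded top-level API. Only the exported names above are confirmed by this diff; the source name, config, and the `read()` call are illustrative assumptions.

```python
import airbyte_lib as ab

# Hypothetical end-to-end flow; source name, config keys, and the read()
# signature are assumptions for illustration, not confirmed by this diff.
source = ab.get_connector("source-faker", config={"count": 100})
cache = ab.new_local_cache()
result = source.read(cache)  # assumed to return a ReadResult

# The new iterator syntax: caches act like mappings of stream names
# to datasets.
for stream_name in cache:
    for record in cache[stream_name]:
        print(stream_name, record)
```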
5 changes: 5 additions & 0 deletions airbyte-lib/airbyte_lib/_processors.py
@@ -99,6 +99,11 @@ def register_source(
_ = source_name
self.source_catalog = source_catalog

+    @property
+    def _streams_with_data(self) -> set[str]:
+        """Return the set of known streams."""
+        return self._pending_batches.keys() | self._finalized_batches.keys()

Comment on lines +102 to +106 (Collaborator, Author):

This is new. Basically, this tells us all the streams we've seen so far.
@final
def process_stdin(
self,
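
`_streams_with_data` relies on dict key views being set-like. A standalone sketch (batch payload values are made up):

```python
# dict.keys() returns a set-like view, so "|" yields the union as a set.
pending_batches = {"users": ["batch-1"], "orders": ["batch-2"]}
finalized_batches = {"orders": ["batch-3"], "products": ["batch-4"]}

streams_with_data = pending_batches.keys() | finalized_batches.keys()
assert streams_with_data == {"users", "orders", "products"}
```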
2 changes: 1 addition & 1 deletion airbyte-lib/airbyte_lib/_util/protocol_util.py
@@ -66,6 +66,6 @@ def get_primary_keys_from_stream(
None,
)
if stream is None:
-        raise ValueError(f"Stream {stream_name} not found in catalog.")
+        raise KeyError(f"Stream {stream_name} not found in catalog.")

return set(stream.stream.source_defined_primary_key or [])
36 changes: 20 additions & 16 deletions airbyte-lib/airbyte_lib/caches/base.py
@@ -5,7 +5,6 @@

import abc
import enum
-from collections.abc import Generator, Iterator, Mapping
from contextlib import contextmanager
from functools import cached_property
from typing import TYPE_CHECKING, Any, cast, final
@@ -22,10 +21,12 @@
from airbyte_lib._file_writers.base import FileWriterBase, FileWriterBatchHandle
from airbyte_lib._processors import BatchHandle, RecordProcessor
from airbyte_lib.config import CacheConfigBase
+from airbyte_lib.datasets._sql import CachedDataset
from airbyte_lib.types import SQLTypeConverter


if TYPE_CHECKING:
+    from collections.abc import Generator, Iterator
from pathlib import Path

from sqlalchemy.engine import Connection, Engine
@@ -118,6 +119,15 @@ def __init__(
self.file_writer = file_writer or self.file_writer_class(config)
self.type_converter = self.type_converter_class()

+    def __getitem__(self, stream: str) -> DatasetBase:
+        return self.streams[stream]
+
+    def __contains__(self, stream: str) -> bool:
+        return stream in self._streams_with_data
+
+    def __iter__(self) -> Iterator[str]:
+        return iter(self._streams_with_data)

# Public interface:

def get_sql_alchemy_url(self) -> str:
@@ -211,28 +221,22 @@ def get_sql_table(
@property
def streams(
self,
-    ) -> dict[str, DatasetBase]:
+    ) -> dict[str, CachedDataset]:
         """Return a mapping of stream names to cached datasets."""
-        # TODO: Add support for streams map, based on the cached catalog.
-        raise NotImplementedError("Streams map is not yet supported.")
+        result = {}
+        for stream_name in self._streams_with_data:
+            result[stream_name] = CachedDataset(self, stream_name)
+
+        return result

# Read methods:

def get_records(
self,
stream_name: str,
-    ) -> Iterator[Mapping[str, Any]]:
-        """Uses SQLAlchemy to select all rows from the table.
-
-        # TODO: Refactor to return a LazyDataset here.
-        """
-        table_ref = self.get_sql_table(stream_name)
-        stmt = table_ref.select()
-        with self.get_sql_connection() as conn:
-            for row in conn.execute(stmt):
-                # Access to private member required because SQLAlchemy doesn't expose a public API.
-                # https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
-                yield cast(Mapping[str, Any], row._mapping)  # noqa: SLF001
+    ) -> CachedDataset:
+        """Return a CachedDataset selecting all rows from the stream's table."""
+        return CachedDataset(self, stream_name)

def get_pandas_dataframe(
self,
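
Taken together, `__getitem__`, `__contains__`, and `__iter__` give caches dict-like ergonomics. A sketch, assuming a cache already populated with a `users` stream (the stream name is illustrative):

```python
import airbyte_lib as ab

cache = ab.get_default_cache()

if "users" in cache:           # __contains__: does the stream have data?
    users = cache["users"]     # __getitem__: returns a CachedDataset
    print(users.to_pandas())

for stream_name in cache:      # __iter__: names of streams with data
    print(stream_name)
```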
4 changes: 1 addition & 3 deletions airbyte-lib/airbyte_lib/caches/postgres.py
@@ -28,9 +28,7 @@ class PostgresCacheConfig(SQLCacheConfigBase, ParquetWriterConfig):
@overrides
def get_sql_alchemy_url(self) -> str:
"""Return the SQLAlchemy URL to use."""
-        return (
-            f"postgresql://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"
-        )
+        return f"postgresql+psycopg2://{self.username}:{self.password}@{self.host}:{self.port}/{self.database}"

def get_database_name(self) -> str:
"""Return the name of the database."""
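
The `+psycopg2` suffix pins the DBAPI driver explicitly rather than relying on SQLAlchemy's default for the `postgresql` dialect. A minimal sketch with placeholder credentials:

```python
from sqlalchemy import create_engine

# Explicit driver: dialect+driver://user:password@host:port/database
engine = create_engine(
    "postgresql+psycopg2://airbyte:secret@localhost:5432/airbyte_db"
)
with engine.connect() as conn:
    print(conn.closed)  # False while the connection is open
```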
5 changes: 4 additions & 1 deletion airbyte-lib/airbyte_lib/datasets/__init__.py
@@ -1,10 +1,13 @@
 from airbyte_lib.datasets._base import DatasetBase
-from airbyte_lib.datasets._cached import CachedDataset
+from airbyte_lib.datasets._lazy import LazyDataset
 from airbyte_lib.datasets._map import DatasetMap
+from airbyte_lib.datasets._sql import CachedDataset, SQLDataset


 __all__ = [
     "CachedDataset",
     "DatasetBase",
     "DatasetMap",
+    "LazyDataset",
+    "SQLDataset",
 ]
16 changes: 7 additions & 9 deletions airbyte-lib/airbyte_lib/datasets/_base.py
@@ -5,23 +5,21 @@
from typing import Any, cast

from pandas import DataFrame
-from typing_extensions import Self


-class DatasetBase(ABC, Iterator[Mapping[str, Any]]):
+class DatasetBase(ABC):
     """Base implementation for all datasets."""

-    def __iter__(self) -> Self:
-        """Return the iterator object (usually self)."""
-        return self
-
     @abstractmethod
-    def __next__(self) -> Mapping[str, Any]:
-        """Return the next value from the iterator."""
+    def __iter__(self) -> Iterator[Mapping[str, Any]]:
+        """Return the iterator of records."""
         raise NotImplementedError

     def to_pandas(self) -> DataFrame:
-        """Return a pandas DataFrame representation of the dataset."""
+        """Return a pandas DataFrame representation of the dataset.
+
+        The base implementation simply passes the record iterator to pandas' DataFrame constructor.
+        """
# Technically, we return an iterator of Mapping objects. However, pandas
# expects an iterator of dict objects. This cast is safe because we know
# duck typing is correct for this use case.
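
Under the new contract, subclasses implement `__iter__` (a generator method works) instead of `__next__`, and inherit `to_pandas()` for free. A hypothetical minimal subclass, not part of the library:

```python
from collections.abc import Iterator, Mapping
from typing import Any

from airbyte_lib.datasets import DatasetBase


class InMemoryDataset(DatasetBase):  # illustrative subclass only
    def __init__(self, records: list[dict[str, Any]]) -> None:
        self._records = records

    def __iter__(self) -> Iterator[Mapping[str, Any]]:
        yield from self._records


dataset = InMemoryDataset([{"id": 1}, {"id": 2}])
print(dataset.to_pandas())  # base class feeds the record iterator to DataFrame()
```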
34 changes: 0 additions & 34 deletions airbyte-lib/airbyte_lib/datasets/_cached.py

This file was deleted.

32 changes: 7 additions & 25 deletions airbyte-lib/airbyte_lib/datasets/_lazy.py
@@ -4,44 +4,26 @@
from typing import TYPE_CHECKING, Any

from overrides import overrides
-from typing_extensions import Self

 from airbyte_lib.datasets import DatasetBase


 if TYPE_CHECKING:
-    from collections.abc import Callable, Iterator
+    from collections.abc import Iterator, Mapping


 class LazyDataset(DatasetBase):
-    """A dataset that is loaded incrementally from a source or a SQL query.
-
-    TODO: Test and debug this. It is not yet implemented anywhere in the codebase.
-    For now it servers as a placeholder.
-    """
+    """A dataset that is loaded incrementally from a source or a SQL query."""

     def __init__(
         self,
-        iterator: Iterator,
-        on_open: Callable | None = None,
-        on_close: Callable | None = None,
+        iterator: Iterator[Mapping[str, Any]],
     ) -> None:
-        self._iterator = iterator
-        self._on_open = on_open
-        self._on_close = on_close
-        raise NotImplementedError("This class is not implemented yet.")
+        self._iterator: Iterator[Mapping[str, Any]] = iterator

     @overrides
-    def __iter__(self) -> Self:
-        raise NotImplementedError("This class is not implemented yet.")
-        # Pseudocode:
-        # if self._on_open is not None:
-        #     self._on_open()
-
-        # yield from self._iterator
-
-        # if self._on_close is not None:
-        #     self._on_close()
+    def __iter__(self) -> Iterator[Mapping[str, Any]]:
+        return self._iterator

-    def __next__(self) -> dict[str, Any]:
+    def __next__(self) -> Mapping[str, Any]:
         return next(self._iterator)
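
A sketch of the now-working class: it wraps any record iterator, so records are pulled on demand and the dataset is single-pass:

```python
from airbyte_lib.datasets import LazyDataset


def generate_records():  # stand-in for a source's record stream
    for i in range(3):
        yield {"id": i, "name": f"user-{i}"}


dataset = LazyDataset(generate_records())
first = next(dataset)   # __next__ delegates to the wrapped iterator
rest = list(dataset)    # __iter__ hands back the same iterator
assert len(rest) == 2   # the first record was already consumed
```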
119 changes: 119 additions & 0 deletions airbyte-lib/airbyte_lib/datasets/_sql.py
@@ -0,0 +1,119 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
from __future__ import annotations

from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, cast

from overrides import overrides
from sqlalchemy import and_, text

from airbyte_lib.datasets._base import DatasetBase


if TYPE_CHECKING:
from collections.abc import Iterator

from pandas import DataFrame
from sqlalchemy import Selectable, Table
from sqlalchemy.sql import ClauseElement

from airbyte_lib.caches import SQLCacheBase


class SQLDataset(DatasetBase):
"""A dataset that is loaded incrementally from a SQL query.

The CachedDataset class is a subclass of this class, which simply passes a SELECT over the full
table as the query statement.
"""

def __init__(
self,
cache: SQLCacheBase,
stream_name: str,
query_statement: Selectable,
) -> None:
self._cache: SQLCacheBase = cache
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement

@property
def stream_name(self) -> str:
return self._stream_name

def __iter__(self) -> Iterator[Mapping[str, Any]]:
with self._cache.get_sql_connection() as conn:
for row in conn.execute(self._query_statement):
# Access to private member required because SQLAlchemy doesn't expose a public API.
# https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
yield cast(Mapping[str, Any], row._mapping) # noqa: SLF001

def to_pandas(self) -> DataFrame:
return self._cache.get_pandas_dataframe(self._stream_name)

def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
"""Filter the dataset by a set of column values.

Filters can be specified as either a string or a SQLAlchemy expression.

Filters are lazily applied to the dataset, so they can be chained together. For example:

dataset.with_filter("id > 5").with_filter("id < 10")

is equivalent to:

dataset.with_filter("id > 5", "id < 10")
"""
# Convert all strings to TextClause objects.
filters: list[ClauseElement] = [
text(expression) if isinstance(expression, str) else expression
for expression in filter_expressions
]
filtered_select = self._query_statement.where(and_(*filters))
return SQLDataset(
cache=self._cache,
stream_name=self._stream_name,
query_statement=filtered_select,
)


class CachedDataset(SQLDataset):
"""A dataset backed by a SQL table cache.

Because this dataset includes all records from the underlying table, we also expose the
underlying table as a SQLAlchemy Table object.
"""

def __init__(self, cache: SQLCacheBase, stream_name: str) -> None:
self._cache: SQLCacheBase = cache
self._stream_name: str = stream_name
self._query_statement: Selectable = self.to_sql_table().select()

@overrides
def to_pandas(self) -> DataFrame:
return self._cache.get_pandas_dataframe(self._stream_name)

def to_sql_table(self) -> Table:
return self._cache.get_sql_table(self._stream_name)

def __eq__(self, value: object) -> bool:
"""Return True if the value is a CachedDataset with the same cache and stream name.

In the case of CachedDataset objects, we can simply compare the cache and stream name.

Note that this equality check is only supported on CachedDataset objects and not for
the base SQLDataset implementation. This is because of the complexity and computational
cost of comparing two arbitrary SQL queries that could be bound to different variables,
as well as the chance that two queries can be syntactically equivalent without being
text-wise equivalent.
"""
if not isinstance(value, SQLDataset):
return False

if self._cache is not value._cache:
return False

if self._stream_name != value._stream_name:
return False

return True
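
A sketch of how the filtering and equality pieces compose, assuming a cache already holding a `users` stream with `age` and `name` columns (all names illustrative):

```python
users = cache["users"]                        # CachedDataset via __getitem__
adults = users.with_filter("age >= 18")       # returns a SQLDataset; lazy
named = adults.with_filter("name LIKE 'A%'")  # filters chain and AND together

for record in named:                          # SQL executes only on iteration
    print(record["name"])

# Equality holds when cache object and stream name match.
assert cache["users"] == cache.streams["users"]
```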