diff --git a/airbyte-lib/airbyte_lib/datasets/_sql.py b/airbyte-lib/airbyte_lib/datasets/_sql.py
index a1c4b3d5195f..0472ddbf4e50 100644
--- a/airbyte-lib/airbyte_lib/datasets/_sql.py
+++ b/airbyte-lib/airbyte_lib/datasets/_sql.py
@@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, cast
from overrides import overrides
-from sqlalchemy import all_, text
+from sqlalchemy import and_, text
from airbyte_lib.datasets._base import DatasetBase
@@ -37,6 +37,10 @@ def __init__(
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement
+ @property
+ def stream_name(self) -> str:
+ return self._stream_name
+
def __iter__(self) -> Iterator[Mapping[str, Any]]:
with self._cache.get_sql_connection() as conn:
for row in conn.execute(self._query_statement):
@@ -44,24 +48,26 @@ def __iter__(self) -> Iterator[Mapping[str, Any]]:
# https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
yield cast(Mapping[str, Any], row._mapping) # noqa: SLF001
- def __eq__(self, __value: object) -> bool:
- if not isinstance(__value, SQLDataset):
- return False
-
- if self._cache is not __value._cache:
- return False
-
- if self._stream_name != __value._stream_name:
- return False
-
- if self._query_statement != __value._query_statement:
- return False
-
- return True
-
def to_pandas(self) -> DataFrame:
return self._cache.get_pandas_dataframe(self._stream_name)
+ def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
+ """Filter the dataset by a set of column values.
+
+ Filters can be specified as either a string or a SQLAlchemy expression.
+ """
+ # Convert all strings to TextClause objects.
+ filters: list[ClauseElement] = [
+ text(expression) if isinstance(expression, str) else expression
+ for expression in filter_expressions
+ ]
+ filtered_select = self._query_statement.where(and_(*filters))
+ return SQLDataset(
+ cache=self._cache,
+ stream_name=self._stream_name,
+ query_statement=filtered_select,
+ )
+
class CachedDataset(SQLDataset):
"""A dataset backed by a SQL table cache.
@@ -82,19 +88,24 @@ def to_pandas(self) -> DataFrame:
def to_sql_table(self) -> Table:
return self._cache.get_sql_table(self._stream_name)
- def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
- """Filter the dataset by a set of column values.
+ def __eq__(self, value: object) -> bool:
+ """Return True if the value is a CachedDataset with the same cache and stream name.
- Filters can be specified as either a string or a SQLAlchemy expression.
+ In the case of CachedDataset objects, we can simply compare the cache and stream name.
+
+ Note that this equality check is only supported on CachedDataset objects and not for
+ the base SQLDataset implementation. This is because of the complexity and computational
+ cost of comparing two arbitrary SQL queries that could be bound to different variables,
+ as well as the chance that two queries can be syntactically equivalent without being
+ text-wise equivalent.
"""
- # Convert all strings to TextClause objects.
- filters: list[ClauseElement] = [
- text(expression) if isinstance(expression, str) else expression
- for expression in filter_expressions
- ]
- filtered_select = self._query_statement.where(all_(*filters))
- return SQLDataset(
- cache=self._cache,
- stream_name=self._stream_name,
- query_statement=filtered_select,
- )
+ if not isinstance(value, SQLDataset):
+ return False
+
+ if self._cache is not value._cache:
+ return False
+
+ if self._stream_name != value._stream_name:
+ return False
+
+ return True
diff --git a/airbyte-lib/docs/generated/airbyte_lib.html b/airbyte-lib/docs/generated/airbyte_lib.html
index b8f3f534ec44..5d1a7539b0fb 100644
--- a/airbyte-lib/docs/generated/airbyte_lib.html
+++ b/airbyte-lib/docs/generated/airbyte_lib.html
@@ -58,22 +58,15 @@
-
-
-
- def
- with_filter( self, *filter_expressions: 'ClauseElement | str') -> airbyte_lib.datasets._sql.SQLDataset:
-
-
-
-
-
-
Filter the dataset by a set of column values.
-
-
Filters can be specified as either a string or a SQLAlchemy expression.
-
-
+
+
Inherited Members
+
+ - airbyte_lib.datasets._sql.SQLDataset
+ - stream_name
+ - with_filter
+
+
diff --git a/airbyte-lib/docs/generated/airbyte_lib/datasets.html b/airbyte-lib/docs/generated/airbyte_lib/datasets.html
index 82f34f033ee4..4d675640f905 100644
--- a/airbyte-lib/docs/generated/airbyte_lib/datasets.html
+++ b/airbyte-lib/docs/generated/airbyte_lib/datasets.html
@@ -58,22 +58,15 @@
-
-
-
-
def
-
with_filter( self, *filter_expressions: 'ClauseElement | str') -> SQLDataset:
-
-
-
-
-
-
Filter the dataset by a set of column values.
-
-
Filters can be specified as either a string or a SQLAlchemy expression.
-
-
+
+
Inherited Members
+
+
+
@@ -197,6 +190,17 @@ Inherited Members
+
+
+
+ stream_name: str
+
+
+
+
+
+
+
@@ -212,6 +216,23 @@
Inherited Members
+
+
+
+
+
def
+
with_filter( self, *filter_expressions: 'ClauseElement | str') -> SQLDataset:
+
+
+
+
+
+
Filter the dataset by a set of column values.
+
+
Filters can be specified as either a string or a SQLAlchemy expression.
+
+
+
diff --git a/airbyte-lib/tests/integration_tests/test_integration.py b/airbyte-lib/tests/integration_tests/test_integration.py
index 378d26e68888..ba8efff5bc51 100644
--- a/airbyte-lib/tests/integration_tests/test_integration.py
+++ b/airbyte-lib/tests/integration_tests/test_integration.py
@@ -1,12 +1,15 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.
+from collections.abc import Mapping
import os
import shutil
-import subprocess
+from typing import Any
from unittest.mock import Mock, call, patch
import tempfile
from pathlib import Path
+from sqlalchemy import column, text
+
import airbyte_lib as ab
from airbyte_lib.caches import SnowflakeCacheConfig, SnowflakeSQLCache
import pandas as pd
@@ -16,7 +19,7 @@
from airbyte_lib.registry import _update_cache
from airbyte_lib.version import get_version
from airbyte_lib.results import ReadResult
-from airbyte_lib.datasets import CachedDataset, DatasetBase, LazyDataset, LazySQLDataset
+from airbyte_lib.datasets import CachedDataset, LazyDataset, SQLDataset
import airbyte_lib as ab
from airbyte_lib.results import ReadResult
@@ -229,30 +232,42 @@ def test_cached_dataset():
result.cache.streams[not_a_stream_name]
-def test_lazy_dataset_from_source():
+def test_cached_dataset_filter():
source = ab.get_connector("source-test", config={"apiKey": "test"})
+ result: ReadResult = source.read(ab.new_local_cache())
stream_name = "stream1"
- not_a_stream_name = "not_a_stream"
- lazy_dataset_a = source.get_records(stream_name)
- lazy_dataset_b = source.get_records(stream_name)
+ # Check the many ways to add a filter:
+ cached_dataset: CachedDataset = result[stream_name]
+ filtered_dataset_a: SQLDataset = cached_dataset.with_filter("column2 == 1")
+ filtered_dataset_b: SQLDataset = cached_dataset.with_filter(text("column2 == 1"))
+ filtered_dataset_c: SQLDataset = cached_dataset.with_filter(column("column2") == 1)
- assert isinstance(lazy_dataset_a, LazyDataset)
+ assert isinstance(cached_dataset, CachedDataset)
+ all_records = list(cached_dataset)
+ assert len(all_records) == 2
- # Check that we can iterate over the stream
+ for filtered_dataset, case in [
+ (filtered_dataset_a, "a"),
+ (filtered_dataset_b, "b"),
+ (filtered_dataset_c, "c"),
+ ]:
+ assert isinstance(filtered_dataset, SQLDataset)
- list_from_iter_a = list(lazy_dataset_a)
- list_from_iter_b = [row for row in lazy_dataset_b]
+ # Check that we can iterate over each stream
- assert list_from_iter_a == list_from_iter_b
+ filtered_records: list[Mapping[str, Any]] = [row for row in filtered_dataset]
- # Make sure that we get a key error if we try to access a stream that doesn't exist
- with pytest.raises(KeyError):
- source.get_records(not_a_stream_name)
+ # Check that the filter worked
+ assert len(filtered_records) == 1, f"Case '{case}' had incorrect number of records."
+ # Assert the stream name still matches
+ assert filtered_dataset.stream_name == stream_name, \
+ f"Case '{case}' had incorrect stream name."
-def test_lazy_sql_dataset_from_cache():
+
+def test_lazy_dataset_from_source():
source = ab.get_connector("source-test", config={"apiKey": "test"})
stream_name = "stream1"
@@ -260,7 +275,6 @@ def test_lazy_sql_dataset_from_cache():
lazy_dataset_a = source.get_records(stream_name)
lazy_dataset_b = source.get_records(stream_name)
- lazy_dataset_c = source.get_records(stream_name)
assert isinstance(lazy_dataset_a, LazyDataset)