From 210c0bf1e0aebb4bfd6d6edebc12bc292ef4f8b1 Mon Sep 17 00:00:00 2001 From: Aaron Steers Date: Mon, 22 Jan 2024 14:01:22 -0800 Subject: [PATCH] add tests for sql filtering (now passing) --- airbyte-lib/airbyte_lib/datasets/_sql.py | 71 +++++++++++-------- airbyte-lib/docs/generated/airbyte_lib.html | 23 +++--- .../docs/generated/airbyte_lib/datasets.html | 51 +++++++++---- .../integration_tests/test_integration.py | 46 +++++++----- 4 files changed, 115 insertions(+), 76 deletions(-) diff --git a/airbyte-lib/airbyte_lib/datasets/_sql.py b/airbyte-lib/airbyte_lib/datasets/_sql.py index a1c4b3d5195f..0472ddbf4e50 100644 --- a/airbyte-lib/airbyte_lib/datasets/_sql.py +++ b/airbyte-lib/airbyte_lib/datasets/_sql.py @@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, cast from overrides import overrides -from sqlalchemy import all_, text +from sqlalchemy import and_, text from airbyte_lib.datasets._base import DatasetBase @@ -37,6 +37,10 @@ def __init__( self._stream_name: str = stream_name self._query_statement: Selectable = query_statement + @property + def stream_name(self) -> str: + return self._stream_name + def __iter__(self) -> Iterator[Mapping[str, Any]]: with self._cache.get_sql_connection() as conn: for row in conn.execute(self._query_statement): @@ -44,24 +48,26 @@ def __iter__(self) -> Iterator[Mapping[str, Any]]: # https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html yield cast(Mapping[str, Any], row._mapping) # noqa: SLF001 - def __eq__(self, __value: object) -> bool: - if not isinstance(__value, SQLDataset): - return False - - if self._cache is not __value._cache: - return False - - if self._stream_name != __value._stream_name: - return False - - if self._query_statement != __value._query_statement: - return False - - return True - def to_pandas(self) -> DataFrame: return self._cache.get_pandas_dataframe(self._stream_name) + def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset: + """Filter the dataset by a set of column values. + + Filters can be specified as either a string or a SQLAlchemy expression. + """ + # Convert all strings to TextClause objects. + filters: list[ClauseElement] = [ + text(expression) if isinstance(expression, str) else expression + for expression in filter_expressions + ] + filtered_select = self._query_statement.where(and_(*filters)) + return SQLDataset( + cache=self._cache, + stream_name=self._stream_name, + query_statement=filtered_select, + ) + class CachedDataset(SQLDataset): """A dataset backed by a SQL table cache. @@ -82,19 +88,24 @@ def to_pandas(self) -> DataFrame: def to_sql_table(self) -> Table: return self._cache.get_sql_table(self._stream_name) - def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset: - """Filter the dataset by a set of column values. + def __eq__(self, value: object) -> bool: + """Return True if the value is a CachedDataset with the same cache and stream name. - Filters can be specified as either a string or a SQLAlchemy expression. + In the case of CachedDataset objects, we can simply compare the cache and stream name. + + Note that this equality check is only supported on CachedDataset objects and not for + the base SQLDataset implementation. This is because of the complexity and computational + cost of comparing two arbitrary SQL queries that could be bound to different variables, + as well as the chance that two queries can be syntactically equivalent without being + text-wise equivalent. """ - # Convert all strings to TextClause objects. - filters: list[ClauseElement] = [ - text(expression) if isinstance(expression, str) else expression - for expression in filter_expressions - ] - filtered_select = self._query_statement.where(all_(*filters)) - return SQLDataset( - cache=self._cache, - stream_name=self._stream_name, - query_statement=filtered_select, - ) + if not isinstance(value, SQLDataset): + return False + + if self._cache is not value._cache: + return False + + if self._stream_name != value._stream_name: + return False + + return True diff --git a/airbyte-lib/docs/generated/airbyte_lib.html b/airbyte-lib/docs/generated/airbyte_lib.html index b8f3f534ec44..5d1a7539b0fb 100644 --- a/airbyte-lib/docs/generated/airbyte_lib.html +++ b/airbyte-lib/docs/generated/airbyte_lib.html @@ -58,22 +58,15 @@ -
-
- - def - with_filter( self, *filter_expressions: 'ClauseElement | str') -> airbyte_lib.datasets._sql.SQLDataset: - - -
- - -

Filter the dataset by a set of column values.

- -

Filters can be specified as either a string or a SQLAlchemy expression.

-
- +
+
Inherited Members
+
+
airbyte_lib.datasets._sql.SQLDataset
+
stream_name
+
with_filter
+
+
diff --git a/airbyte-lib/docs/generated/airbyte_lib/datasets.html b/airbyte-lib/docs/generated/airbyte_lib/datasets.html index 82f34f033ee4..4d675640f905 100644 --- a/airbyte-lib/docs/generated/airbyte_lib/datasets.html +++ b/airbyte-lib/docs/generated/airbyte_lib/datasets.html @@ -58,22 +58,15 @@
-
-
- - def - with_filter( self, *filter_expressions: 'ClauseElement | str') -> SQLDataset: - - -
- - -

Filter the dataset by a set of column values.

- -

Filters can be specified as either a string or a SQLAlchemy expression.

-
- +
+
Inherited Members
+
+
SQLDataset
+
stream_name
+
with_filter
+
+
@@ -197,6 +190,17 @@
Inherited Members
+
+
+
+ stream_name: str + + +
+ + + +
@@ -212,6 +216,23 @@
Inherited Members
+
+
+
+ + def + with_filter( self, *filter_expressions: 'ClauseElement | str') -> SQLDataset: + + +
+ + +

Filter the dataset by a set of column values.

+ +

Filters can be specified as either a string or a SQLAlchemy expression.

+
+ +
diff --git a/airbyte-lib/tests/integration_tests/test_integration.py b/airbyte-lib/tests/integration_tests/test_integration.py index 378d26e68888..ba8efff5bc51 100644 --- a/airbyte-lib/tests/integration_tests/test_integration.py +++ b/airbyte-lib/tests/integration_tests/test_integration.py @@ -1,12 +1,15 @@ # Copyright (c) 2023 Airbyte, Inc., all rights reserved. +from collections.abc import Mapping import os import shutil -import subprocess +from typing import Any from unittest.mock import Mock, call, patch import tempfile from pathlib import Path +from sqlalchemy import column, text + import airbyte_lib as ab from airbyte_lib.caches import SnowflakeCacheConfig, SnowflakeSQLCache import pandas as pd @@ -16,7 +19,7 @@ from airbyte_lib.registry import _update_cache from airbyte_lib.version import get_version from airbyte_lib.results import ReadResult -from airbyte_lib.datasets import CachedDataset, DatasetBase, LazyDataset, LazySQLDataset +from airbyte_lib.datasets import CachedDataset, LazyDataset, SQLDataset import airbyte_lib as ab from airbyte_lib.results import ReadResult @@ -229,30 +232,42 @@ def test_cached_dataset(): result.cache.streams[not_a_stream_name] -def test_lazy_dataset_from_source(): +def test_cached_dataset_filter(): source = ab.get_connector("source-test", config={"apiKey": "test"}) + result: ReadResult = source.read(ab.new_local_cache()) stream_name = "stream1" - not_a_stream_name = "not_a_stream" - lazy_dataset_a = source.get_records(stream_name) - lazy_dataset_b = source.get_records(stream_name) + # Check the many ways to add a filter: + cached_dataset: CachedDataset = result[stream_name] + filtered_dataset_a: SQLDataset = cached_dataset.with_filter("column2 == 1") + filtered_dataset_b: SQLDataset = cached_dataset.with_filter(text("column2 == 1")) + filtered_dataset_c: SQLDataset = cached_dataset.with_filter(column("column2") == 1) - assert isinstance(lazy_dataset_a, LazyDataset) + assert isinstance(cached_dataset, CachedDataset) + all_records = list(cached_dataset) + assert len(all_records) == 2 - # Check that we can iterate over the stream + for filtered_dataset, case in [ + (filtered_dataset_a, "a"), + (filtered_dataset_b, "b"), + (filtered_dataset_c, "c"), + ]: + assert isinstance(filtered_dataset, SQLDataset) - list_from_iter_a = list(lazy_dataset_a) - list_from_iter_b = [row for row in lazy_dataset_b] + # Check that we can iterate over each stream - assert list_from_iter_a == list_from_iter_b + filtered_records: list[Mapping[str, Any]] = [row for row in filtered_dataset] - # Make sure that we get a key error if we try to access a stream that doesn't exist - with pytest.raises(KeyError): - source.get_records(not_a_stream_name) + # Check that the filter worked + assert len(filtered_records) == 1, f"Case '{case}' had incorrect number of records." + # Assert the stream name still matches + assert filtered_dataset.stream_name == stream_name, \ + f"Case '{case}' had incorrect stream name." -def test_lazy_sql_dataset_from_cache(): + +def test_lazy_dataset_from_source(): source = ab.get_connector("source-test", config={"apiKey": "test"}) stream_name = "stream1" @@ -260,7 +275,6 @@ def test_lazy_sql_dataset_from_cache(): lazy_dataset_a = source.get_records(stream_name) lazy_dataset_b = source.get_records(stream_name) - lazy_dataset_c = source.get_records(stream_name) assert isinstance(lazy_dataset_a, LazyDataset)