Skip to content

Commit

Permalink
add tests for sql filtering (now passing)
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronsteers committed Jan 22, 2024
1 parent 746a7ad commit 210c0bf
Show file tree
Hide file tree
Showing 4 changed files with 115 additions and 76 deletions.
71 changes: 41 additions & 30 deletions airbyte-lib/airbyte_lib/datasets/_sql.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
from typing import TYPE_CHECKING, Any, cast

from overrides import overrides
from sqlalchemy import all_, text
from sqlalchemy import and_, text

from airbyte_lib.datasets._base import DatasetBase

Expand Down Expand Up @@ -37,31 +37,37 @@ def __init__(
self._stream_name: str = stream_name
self._query_statement: Selectable = query_statement

@property
def stream_name(self) -> str:
    """Return the name of the source stream this dataset reads from."""
    return self._stream_name

def __iter__(self) -> Iterator[Mapping[str, Any]]:
    """Execute the dataset's query and yield one mapping per result row.

    Each yielded value maps column names to values for a single row.
    """
    with self._cache.get_sql_connection() as conn:
        result = conn.execute(self._query_statement)
        for row in result:
            # SQLAlchemy only exposes the mapping view through the private
            # `_mapping` attribute; there is no public API for it.
            # https://pydoc.dev/sqlalchemy/latest/sqlalchemy.engine.row.RowMapping.html
            yield cast(Mapping[str, Any], row._mapping)  # noqa: SLF001

def __eq__(self, __value: object) -> bool:
    """Return True if the other object is a SQLDataset over the same cache,
    stream, and query statement.
    """
    if not isinstance(__value, SQLDataset):
        return False

    # Identity comparison for the cache (same backing store), equality for
    # the stream name and the query itself.
    return (
        self._cache is __value._cache
        and self._stream_name == __value._stream_name
        and self._query_statement == __value._query_statement
    )

def to_pandas(self) -> DataFrame:
    """Return the stream's records as a pandas DataFrame, delegating to the cache."""
    return self._cache.get_pandas_dataframe(self._stream_name)

def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
    """Return a new dataset narrowed by the given filter conditions.

    Filters may be raw SQL strings or SQLAlchemy expressions; all of them
    are combined with AND. The original dataset is left unmodified.
    """
    clauses: list[ClauseElement] = []
    for expression in filter_expressions:
        # Wrap plain strings as TextClause objects; pass expressions through.
        if isinstance(expression, str):
            clauses.append(text(expression))
        else:
            clauses.append(expression)

    return SQLDataset(
        cache=self._cache,
        stream_name=self._stream_name,
        query_statement=self._query_statement.where(and_(*clauses)),
    )


class CachedDataset(SQLDataset):
"""A dataset backed by a SQL table cache.
Expand All @@ -82,19 +88,24 @@ def to_pandas(self) -> DataFrame:
def to_sql_table(self) -> Table:
    """Return the SQLAlchemy Table backing this stream, delegating to the cache."""
    return self._cache.get_sql_table(self._stream_name)

def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
"""Filter the dataset by a set of column values.
def __eq__(self, value: object) -> bool:
"""Return True if the value is a CachedDataset with the same cache and stream name.
Filters can be specified as either a string or a SQLAlchemy expression.
In the case of CachedDataset objects, we can simply compare the cache and stream name.
Note that this equality check is only supported on CachedDataset objects and not for
the base SQLDataset implementation. This is because of the complexity and computational
cost of comparing two arbitrary SQL queries that could be bound to different variables,
as well as the chance that two queries can be syntactically equivalent without being
text-wise equivalent.
"""
# Convert all strings to TextClause objects.
filters: list[ClauseElement] = [
text(expression) if isinstance(expression, str) else expression
for expression in filter_expressions
]
filtered_select = self._query_statement.where(all_(*filters))
return SQLDataset(
cache=self._cache,
stream_name=self._stream_name,
query_statement=filtered_select,
)
if not isinstance(value, SQLDataset):
return False

if self._cache is not value._cache:
return False

if self._stream_name != value._stream_name:
return False

return True
23 changes: 8 additions & 15 deletions airbyte-lib/docs/generated/airbyte_lib.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

51 changes: 36 additions & 15 deletions airbyte-lib/docs/generated/airbyte_lib/datasets.html

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 30 additions & 16 deletions airbyte-lib/tests/integration_tests/test_integration.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
# Copyright (c) 2023 Airbyte, Inc., all rights reserved.

from collections.abc import Mapping
import os
import shutil
import subprocess
from typing import Any
from unittest.mock import Mock, call, patch
import tempfile
from pathlib import Path

from sqlalchemy import column, text

import airbyte_lib as ab
from airbyte_lib.caches import SnowflakeCacheConfig, SnowflakeSQLCache
import pandas as pd
Expand All @@ -16,7 +19,7 @@
from airbyte_lib.registry import _update_cache
from airbyte_lib.version import get_version
from airbyte_lib.results import ReadResult
from airbyte_lib.datasets import CachedDataset, DatasetBase, LazyDataset, LazySQLDataset
from airbyte_lib.datasets import CachedDataset, LazyDataset, SQLDataset
import airbyte_lib as ab

from airbyte_lib.results import ReadResult
Expand Down Expand Up @@ -229,38 +232,49 @@ def test_cached_dataset():
result.cache.streams[not_a_stream_name]


def test_lazy_dataset_from_source():
def test_cached_dataset_filter():
source = ab.get_connector("source-test", config={"apiKey": "test"})
result: ReadResult = source.read(ab.new_local_cache())

stream_name = "stream1"
not_a_stream_name = "not_a_stream"

lazy_dataset_a = source.get_records(stream_name)
lazy_dataset_b = source.get_records(stream_name)
# Check the many ways to add a filter:
cached_dataset: CachedDataset = result[stream_name]
filtered_dataset_a: SQLDataset = cached_dataset.with_filter("column2 == 1")
filtered_dataset_b: SQLDataset = cached_dataset.with_filter(text("column2 == 1"))
filtered_dataset_c: SQLDataset = cached_dataset.with_filter(column("column2") == 1)

assert isinstance(lazy_dataset_a, LazyDataset)
assert isinstance(cached_dataset, CachedDataset)
all_records = list(cached_dataset)
assert len(all_records) == 2

# Check that we can iterate over the stream
for filtered_dataset, case in [
(filtered_dataset_a, "a"),
(filtered_dataset_b, "b"),
(filtered_dataset_c, "c"),
]:
assert isinstance(filtered_dataset, SQLDataset)

list_from_iter_a = list(lazy_dataset_a)
list_from_iter_b = [row for row in lazy_dataset_b]
# Check that we can iterate over each stream

assert list_from_iter_a == list_from_iter_b
filtered_records: list[Mapping[str, Any]] = [row for row in filtered_dataset]

# Make sure that we get a key error if we try to access a stream that doesn't exist
with pytest.raises(KeyError):
source.get_records(not_a_stream_name)
# Check that the filter worked
assert len(filtered_records) == 1, f"Case '{case}' had incorrect number of records."

# Assert the stream name still matches
assert filtered_dataset.stream_name == stream_name, \
f"Case '{case}' had incorrect stream name."

def test_lazy_sql_dataset_from_cache():

def test_lazy_dataset_from_source():
source = ab.get_connector("source-test", config={"apiKey": "test"})

stream_name = "stream1"
not_a_stream_name = "not_a_stream"

lazy_dataset_a = source.get_records(stream_name)
lazy_dataset_b = source.get_records(stream_name)
lazy_dataset_c = source.get_records(stream_name)

assert isinstance(lazy_dataset_a, LazyDataset)

Expand Down

0 comments on commit 210c0bf

Please sign in to comment.