Feat: Adds interop with Arrow library using new method Dataset.to_arrow() #281

Merged: 22 commits, Jul 9, 2024
Changes from 19 commits
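For orientation, here is a minimal usage sketch of the new `Dataset.to_arrow()` API introduced in this PR. The `source-faker` connector, the `users` stream, and the chunk size below are illustrative assumptions, not part of this diff:

```python
import airbyte as ab

# Illustrative source and stream; any connector works the same way.
source = ab.get_source(
    "source-faker",
    config={"count": 10_000},
    install_if_missing=True,
)
source.select_all_streams()
result = source.read()

# New in this PR: materialize a stream as a PyArrow Dataset.
arrow_dataset = result["users"].to_arrow(max_chunk_size=50_000)

# Standard PyArrow operations are available from here.
arrow_table = arrow_dataset.to_table()
print(arrow_table.num_rows)
```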
35 changes: 35 additions & 0 deletions airbyte/caches/base.py
@@ -7,6 +7,8 @@
from typing import TYPE_CHECKING, Any, Optional, final

import pandas as pd
import pyarrow as pa
import pyarrow.dataset as ds
from pydantic import Field, PrivateAttr

from airbyte_protocol.models import ConfiguredAirbyteCatalog
@@ -19,6 +21,7 @@
from airbyte._future_cdk.state_writers import StdOutStateWriter
from airbyte.caches._catalog_backend import CatalogBackendBase, SqlCatalogBackend
from airbyte.caches._state_backend import SqlStateBackend
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.datasets._sql import CachedDataset


@@ -146,6 +149,38 @@ def get_pandas_dataframe(
engine = self.get_sql_engine()
return pd.read_sql_table(table_name, engine, schema=self.schema_name)

def get_arrow_dataset(
self,
stream_name: str,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> ds.Dataset:
"""Return an Arrow Dataset with the stream's data."""
table_name = self._read_processor.get_sql_table_name(stream_name)
engine = self.get_sql_engine()

# Read the table in chunks to handle large tables that don't fit in memory
pandas_chunks = pd.read_sql_table(
table_name=table_name,
con=engine,
schema=self.schema_name,
chunksize=max_chunk_size,
)

arrow_batches_list = []
arrow_schema = None

for pandas_chunk in pandas_chunks:
if arrow_schema is None:
# Initialize the schema with the first chunk
arrow_schema = pa.Schema.from_pandas(pandas_chunk)

# Convert each pandas chunk to an Arrow record batch
arrow_table = pa.RecordBatch.from_pandas(pandas_chunk, schema=arrow_schema)
arrow_batches_list.append(arrow_table)

return ds.dataset(arrow_batches_list)

@final
@property
def streams(self) -> dict[str, CachedDataset]:
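Since the cache-level `get_arrow_dataset()` above returns a standard `pyarrow.dataset.Dataset`, column projection and row filtering can be expressed directly against the Arrow layer instead of round-tripping through pandas. A rough sketch, assuming the default cache already holds a `users` stream with `id`, `name`, and `age` columns (all hypothetical names):

```python
import airbyte as ab
import pyarrow.dataset as ds

cache = ab.get_default_cache()  # assumes a prior read has populated this cache

users = cache.get_arrow_dataset("users")

# Project columns and filter rows via the Arrow dataset API.
adults = users.to_table(
    columns=["id", "name", "age"],
    filter=ds.field("age") >= 18,
)
print(adults.num_rows)
```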
14 changes: 14 additions & 0 deletions airbyte/caches/bigquery.py
@@ -23,9 +23,23 @@
from airbyte.caches.base import (
CacheBase,
)
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE


class BigQueryCache(BigQueryConfig, CacheBase):
"""The BigQuery cache implementation."""

_sql_processor_class: type[BigQuerySqlProcessor] = PrivateAttr(default=BigQuerySqlProcessor)

def get_arrow_dataset(
self,
stream_name: str,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> None:

# Raises a NotImplementedError as BigQuery doesn't support pd.read_sql_table
raise NotImplementedError(
"BigQuery doesn't currently support to_arrow"
"Please consider using a different cache implementation for these functionalities."
)
3 changes: 3 additions & 0 deletions airbyte/constants.py
@@ -38,3 +38,6 @@

Specific caches may override this value with a different schema name.
"""

DEFAULT_ARROW_MAX_CHUNK_SIZE = 100_000
"""The default number of records to include in each batch of an Arrow dataset."""
15 changes: 15 additions & 0 deletions airbyte/datasets/_base.py
@@ -6,11 +6,15 @@
from typing import TYPE_CHECKING, Any, cast

from pandas import DataFrame
from pyarrow.dataset import Dataset

from airbyte._util.document_rendering import DocumentRenderer
from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE


if TYPE_CHECKING:
from pyarrow.dataset import Dataset

from airbyte_protocol.models import ConfiguredAirbyteStream

from airbyte.documents import Document
@@ -37,6 +41,17 @@ def to_pandas(self) -> DataFrame:
# duck typing is correct for this use case.
return DataFrame(cast(Iterator[dict[str, Any]], self))

def to_arrow(
self,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
"""Return an Arrow Dataset representation of the dataset.

This method should be implemented by subclasses.
"""
raise NotImplementedError("Not implemented in base class")

def to_documents(
self,
title_property: str | None = None,
29 changes: 29 additions & 0 deletions airbyte/datasets/_sql.py
@@ -12,13 +12,15 @@

from airbyte_protocol.models.airbyte_protocol import ConfiguredAirbyteStream

from airbyte.constants import DEFAULT_ARROW_MAX_CHUNK_SIZE
from airbyte.datasets._base import DatasetBase


if TYPE_CHECKING:
from collections.abc import Iterator

from pandas import DataFrame
from pyarrow.dataset import Dataset
from sqlalchemy import Table
from sqlalchemy.sql import ClauseElement
from sqlalchemy.sql.selectable import Selectable
@@ -102,6 +104,13 @@ def __len__(self) -> int:
def to_pandas(self) -> DataFrame:
return self._cache.get_pandas_dataframe(self._stream_name)

def to_arrow(
self,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
return self._cache.get_arrow_dataset(self._stream_name, max_chunk_size=max_chunk_size)

def with_filter(self, *filter_expressions: ClauseElement | str) -> SQLDataset:
"""Filter the dataset by a set of column values.

@@ -166,6 +175,26 @@ def to_pandas(self) -> DataFrame:
"""Return the underlying dataset data as a pandas DataFrame."""
return self._cache.get_pandas_dataframe(self._stream_name)

@overrides
def to_arrow(
self,
*,
max_chunk_size: int = DEFAULT_ARROW_MAX_CHUNK_SIZE,
) -> Dataset:
"""Return an Arrow Dataset containing the data from the specified stream.

Args:
stream_name (str): Name of the stream to retrieve data from.
max_chunk_size (int): max number of records to include in each batch of pyarrow dataset.

Returns:
pa.dataset.Dataset: Arrow Dataset containing the stream's data.
"""
return self._cache.get_arrow_dataset(
stream_name=self._stream_name,
max_chunk_size=max_chunk_size,
)

def to_sql_table(self) -> Table:
"""Return the underlying SQL table as a SQLAlchemy Table object."""
return self._cache.processor.get_sql_table(self.stream_name)