Convert automatically to arrow strings #86

Merged 7 commits on Jul 25, 2024

51 changes: 50 additions & 1 deletion dask_bigquery/core.py
@@ -10,6 +10,8 @@
import pandas as pd
import pyarrow
from dask.base import tokenize
from dask.dataframe._compat import PANDAS_GE_220
from dask.dataframe.utils import pyarrow_strings_enabled
from google.api_core import client_info as rest_client_info
from google.api_core import exceptions
from google.api_core.gapic_v1 import client_info as grpc_client_info
@@ -95,6 +97,7 @@ def bigquery_read(
    read_kwargs: dict,
    arrow_options: dict,
    credentials: dict = None,
    convert_string: bool = False,
) -> pd.DataFrame:
    """Read a single batch of rows via BQ Storage API, in Arrow binary format.

@@ -114,7 +117,15 @@
        BigQuery Storage API Stream "name"
        NOTE: Please set if reading from Storage API without any `row_restriction`.
        https://cloud.google.com/bigquery/docs/reference/storage/rpc/google.cloud.bigquery.storage.v1beta1#stream
    convert_string: bool
        Whether to convert strings directly to arrow strings in the output DataFrame
    """
    arrow_options = arrow_options.copy()
    if convert_string:
        types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
        if types_mapper is not None:
            arrow_options["types_mapper"] = types_mapper

    with bigquery_clients(project_id, credentials=credentials) as (_, bqs_client):
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
@@ -130,6 +141,37 @@
        return pd.concat(shards)


def _get_types_mapper(user_mapper):
    type_mappers = []

    # always use the user-defined mapper first, if available
    if user_mapper is not None:
        type_mappers.append(user_mapper)

    type_mappers.append({pyarrow.string(): pd.StringDtype("pyarrow")}.get)
    if PANDAS_GE_220:
        type_mappers.append({pyarrow.large_string(): pd.StringDtype("pyarrow")}.get)
    type_mappers.append({pyarrow.date32(): pd.ArrowDtype(pyarrow.date32())}.get)
    type_mappers.append({pyarrow.date64(): pd.ArrowDtype(pyarrow.date64())}.get)

    def _convert_decimal_type(type):
        if pyarrow.types.is_decimal(type):
            return pd.ArrowDtype(type)
        return None

    type_mappers.append(_convert_decimal_type)

    def default_types_mapper(pyarrow_dtype):
        """Try all type mappers in order, starting from the user type mapper."""
        for type_converter in type_mappers:
            converted_type = type_converter(pyarrow_dtype)
            if converted_type is not None:
                return converted_type

    if len(type_mappers) > 0:
        return default_types_mapper


def read_gbq(
    project_id: str,
    dataset_id: str,
@@ -196,13 +238,19 @@ def make_create_read_session_request():
                ),
            )

        arrow_options_meta = arrow_options.copy()
        if pyarrow_strings_enabled():
            types_mapper = _get_types_mapper(arrow_options.get("types_mapper", {}.get))
            if types_mapper is not None:
                arrow_options_meta["types_mapper"] = types_mapper
Comment on lines +241 to +245

Member:
We have this twice, once here and once in bigquery_read. Thoughts on keeping this here and passing arrow_options (with the correct types_mapper) through to dd.from_map below? That way we could drop the convert_string= parameter.

Contributor (PR author):
I thought about this too.

I'd like to avoid serialising that stuff in the graph; just passing a flag seems a lot easier.

Member:
Ah, I see. Are the types_mapper entries we're adding particularly large?

Contributor (PR author):
It adds a few callables, which doesn't seem like a good idea (I didn't do any profiling).


        # Create a read session in order to detect the schema.
        # Read sessions are light weight and will be auto-deleted after 24 hours.
        session = bqs_client.create_read_session(make_create_read_session_request())
        schema = pyarrow.ipc.read_schema(
            pyarrow.py_buffer(session.arrow_schema.serialized_schema)
        )
        meta = schema.empty_table().to_pandas(**arrow_options)
        meta = schema.empty_table().to_pandas(**arrow_options_meta)

        return dd.from_map(
            partial(
@@ -212,6 +260,7 @@ def make_create_read_session_request():
                read_kwargs=read_kwargs,
                arrow_options=arrow_options,
                credentials=credentials,
                convert_string=pyarrow_strings_enabled(),
            ),
            [stream.name for stream in session.streams],
            meta=meta,
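
For context, the following is a minimal standalone sketch (not part of this PR) of the mapper-chaining pattern that _get_types_mapper implements: candidate mappers are tried in order, the first non-None dtype wins, and None falls through to pyarrow's default conversion. The table and column names below are made up for illustration.

# Standalone sketch of the mapper-chaining pattern used by _get_types_mapper.
# Only pandas and pyarrow are required; the data is invented for the example.
from decimal import Decimal

import pandas as pd
import pyarrow as pa


def chained_types_mapper(pyarrow_dtype):
    # Try each mapper in order; the first non-None result wins, so a
    # user-supplied mapper placed first would always take precedence.
    mappers = [
        {pa.string(): pd.StringDtype("pyarrow")}.get,
        lambda t: pd.ArrowDtype(t) if pa.types.is_decimal(t) else None,
    ]
    for mapper in mappers:
        dtype = mapper(pyarrow_dtype)
        if dtype is not None:
            return dtype
    return None  # None -> pyarrow falls back to its default conversion


table = pa.table(
    {
        "name": ["a", "b"],
        "amount": pa.array([Decimal("1.00"), Decimal("2.00")], type=pa.decimal128(10, 2)),
    }
)
df = table.to_pandas(types_mapper=chained_types_mapper)
print(df.dtypes)  # name -> string[pyarrow], amount -> decimal128(10, 2)[pyarrow]
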
Empty file added dask_bigquery/tests/__init__.py
32 changes: 27 additions & 5 deletions dask_bigquery/tests/test_core.py
@@ -5,13 +5,14 @@
import uuid
from datetime import datetime, timedelta, timezone

import dask
import dask.dataframe as dd
import gcsfs
import google.auth
import pandas as pd
import pyarrow as pa
import pytest
from dask.dataframe.utils import assert_eq
from dask.dataframe.utils import assert_eq, pyarrow_strings_enabled
from distributed.utils_test import cleanup # noqa: F401
from distributed.utils_test import client # noqa: F401
from distributed.utils_test import cluster_fixture # noqa: F401
@@ -380,11 +381,32 @@ def test_arrow_options(table):
        project_id=project_id,
        dataset_id=dataset_id,
        table_id=table_id,
        arrow_options={
            "types_mapper": {pa.string(): pd.StringDtype(storage="pyarrow")}.get
        },
        arrow_options={"types_mapper": {pa.int64(): pd.Float32Dtype()}.get},
    )
    assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
    assert ddf.dtypes["number"] == pd.Float32Dtype()


@pytest.mark.parametrize("convert_string", [True, False, None])
def test_convert_string(table, convert_string, df):
    project_id, dataset_id, table_id = table
    config = {}
    if convert_string is not None:
        config = {"dataframe.convert-string": convert_string}
    with dask.config.set(config):
        ddf = read_gbq(
            project_id=project_id,
            dataset_id=dataset_id,
            table_id=table_id,
        )
        # Roundtrip through `dd.from_pandas` to check consistent
        # behavior with Dask DataFrame
        result = dd.from_pandas(df, npartitions=1)
        if convert_string is True or (convert_string is None and pyarrow_strings_enabled()):
            assert ddf.dtypes["name"] == pd.StringDtype(storage="pyarrow")
        else:
            assert ddf.dtypes["name"] == object

        assert assert_eq(ddf.set_index("idx"), result.set_index("idx"))


@pytest.mark.skipif(sys.platform == "darwin", reason="Segfaults on macOS")
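
For completeness, a hypothetical usage sketch of the behaviour added in this PR (the project, dataset and table names below are placeholders): with the dataframe.convert-string option enabled, read_gbq returns pyarrow-backed string columns without any extra arguments.

# Hypothetical usage; "my-project", "my_dataset" and "my_table" are placeholders.
import dask
from dask_bigquery import read_gbq

with dask.config.set({"dataframe.convert-string": True}):
    ddf = read_gbq(
        project_id="my-project",
        dataset_id="my_dataset",
        table_id="my_table",
    )

# String columns should come back as string[pyarrow] rather than object,
# matching what dd.from_pandas produces under the same setting.
print(ddf.dtypes)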