chore: update comments and tests for label counts. #1182

Open · wants to merge 4 commits into base: main
16 changes: 13 additions & 3 deletions bigframes/session/_io/bigquery/__init__.py
@@ -204,7 +204,12 @@ def format_option(key: str, value: Union[bool, str]) -> str:
    return f"{key}={repr(value)}"


-def add_labels(job_config, api_name: Optional[str] = None):
+def add_and_trim_labels(job_config, api_name: Optional[str] = None):
+    """
+    Add additional labels to the job configuration and trim the total number of labels
+    to ensure they do not exceed the maximum limit allowed by BigQuery, which is 64
+    labels per job.
+    """
    api_methods = log_adapter.get_and_reset_api_methods(dry_run=job_config.dry_run)
    job_config.labels = create_job_configs_labels(
        job_configs_labels=job_config.labels,
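
For readers skimming the diff: a rough, self-contained sketch of the trimming rule this rename advertises. The helper name trim_labels_sketch, the key format recent-bigframes-api-{i}, and the constant name MAX_LABELS_COUNT are illustrative assumptions, not code from this PR; in the real module the work happens inside create_job_configs_labels.

    from typing import Dict, List

    MAX_LABELS_COUNT = 64  # BigQuery allows at most 64 labels per job.

    def trim_labels_sketch(labels: Dict[str, str], api_methods: List[str]) -> Dict[str, str]:
        """Keep the caller's labels; append only as many of the most recent
        API-method labels as still fit under the 64-label cap."""
        merged = dict(labels)
        room = max(MAX_LABELS_COUNT - len(merged), 0)
        recent = api_methods[-room:] if room else []
        for i, method in enumerate(recent):
            merged[f"recent-bigframes-api-{i}"] = method  # hypothetical key format
        return merged

Keeping the most recent methods matches the new tests below, which assert that the earliest recorded call (dataframe-max) is trimmed away while later dataframe-head calls survive.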
@@ -225,7 +230,9 @@ def start_query_with_client(
    """
    Starts query job and waits for results.
    """
-    add_labels(job_config, api_name=api_name)
+    # Note: Ensure no additional labels are added to job_config after this point,
+    # as `add_and_trim_labels` ensures the label count does not exceed 64.
+    add_and_trim_labels(job_config, api_name=api_name)

    try:
        query_job = bq_client.query(sql, job_config=job_config, timeout=timeout)
@@ -304,7 +311,10 @@ def create_bq_dataset_reference(
        bigquery.DatasetReference: The constructed reference to the anonymous dataset.
    """
    job_config = google.cloud.bigquery.QueryJobConfig()
-    add_labels(job_config, api_name=api_name)
+
+    # Note: Ensure no additional labels are added to job_config after this point,
+    # as `add_and_trim_labels` ensures the label count does not exceed 64.
+    add_and_trim_labels(job_config, api_name=api_name)
    query_job = bq_client.query(
        "SELECT 1", location=location, project=project, job_config=job_config
    )
15 changes: 11 additions & 4 deletions bigframes/session/executor.py
@@ -348,7 +348,10 @@ def export_gcs(
            export_options=dict(export_options),
        )
        job_config = bigquery.QueryJobConfig()
-        bq_io.add_labels(job_config, api_name=f"dataframe-to_{format.lower()}")
+
+        # Note: Ensure no additional labels are added to job_config after this point,
+        # as `add_and_trim_labels` ensures the label count does not exceed 64.
@shobsi (Contributor) commented on Dec 3, 2024:

The repetition of this note in multiple places makes me think: can we not call start_query_with_client instead of bqclient.query from everywhere? That way any preprocessing (including labels) would be centralized, and we would have only one note like this in our code.

Author (Collaborator) replied:

The logic after bqclient.query in start_query_with_client doesn't seem to match the other functions, so this may not work.

We can make a new function that adds labels and executes, to replace bqclient.query. What do you think? That was my original thought, but I'm not sure it's necessary to add a new function just for two lines of code.
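
A minimal sketch of the helper being floated here (the name query_with_labels is hypothetical, not part of this PR):

    def query_with_labels(bq_client, sql, job_config, api_name=None, **kwargs):
        # Centralize label handling so no caller can add labels afterwards.
        add_and_trim_labels(job_config, api_name=api_name)
        return bq_client.query(sql, job_config=job_config, **kwargs)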

@shobsi (Contributor) replied:

Not a great refactoring, but better than leaving repetitive notes IMHO.

@shobsi (Contributor) commented on Dec 7, 2024:

Or we could add a hook in ClientsProvider._create_bigquery_client, something like:

...
original_query = bq_client.query

def query(query, *args, job_config=None, api_name=None, **kwargs):
    add_and_trim_labels(job_config, api_name=api_name)
    return original_query(query, *args, job_config=job_config, **kwargs)

bq_client.query = query
...
return bq_client

@tswast would this be too hacky? What would be the most pythonic way to ensure labels are within the max limit?

Collaborator replied:

I would discourage us from monkey patching like this, though it should work in Python. I think your first instinct of using our central helper function makes the most sense to me.

I don't fully understand what you mean by "The logic after bqclient.query in start_query_with_client doesn't seem to match the other functions, so this may not work."

We should be able to put start_query_with_client in a try/except block. Also, it's probably a bug that we aren't showing that a query job is running here.
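
For illustration, a rough sketch of the refactor suggested above, applied to export_gcs. It assumes start_query_with_client keeps its current parameters and returns the results iterator together with the job; that return shape is an assumption, not something this PR establishes.

    # Hypothetical: inside export_gcs, instead of calling self.bqclient.query directly:
    results_iterator, export_job = bq_io.start_query_with_client(
        self.bqclient,
        export_data_statement,
        job_config=job_config,
        api_name=f"dataframe-to_{format.lower()}",  # labels trimmed centrally
    )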

+        bq_io.add_and_trim_labels(job_config, api_name=f"dataframe-to_{format.lower()}")
        export_job = self.bqclient.query(export_data_statement, job_config=job_config)
        self._wait_on_job(export_job)
        return query_job
@@ -358,7 +361,9 @@ def dry_run(
    ) -> bigquery.QueryJob:
        sql = self.to_sql(array_value, ordered=ordered)
        job_config = bigquery.QueryJobConfig(dry_run=True)
-        bq_io.add_labels(job_config)
+        # Note: Ensure no additional labels are added to job_config after this point,
+        # as `add_and_trim_labels` ensures the label count does not exceed 64.
+        bq_io.add_and_trim_labels(job_config)
Collaborator commented:

I don't think we actually create a job resource in the backend when we do a dry run, so this might actually lose analytics. It's probably better not to call add_and_trim_labels at all in this case.

        query_job = self.bqclient.query(sql, job_config=job_config)
        _ = query_job.result()
Collaborator commented:

I'm pretty sure result() is a no-op for dry-run.
        return query_job
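
Taking the two review comments above together, a sketch of the dry-run path they suggest: skip label trimming (no job resource is created, so trimming would only consume the logged API methods) and drop the result() call (a no-op for dry runs). This is the reviewers' suggestion, not what the PR currently does.

    import google.cloud.bigquery as bigquery

    def dry_run_sketch(bqclient: bigquery.Client, sql: str) -> bigquery.QueryJob:
        job_config = bigquery.QueryJobConfig(dry_run=True)
        # Intentionally no add_and_trim_labels: a dry run creates no job
        # resource server-side, so labels would never surface in analytics.
        return bqclient.query(sql, job_config=job_config)  # result() omitted: no-op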
@@ -486,8 +491,10 @@ def _run_execute_query(
        if not self.strictly_ordered:
            job_config.labels["bigframes-mode"] = "unordered"

-        # Note: add_labels is global scope which may have unexpected effects
-        bq_io.add_labels(job_config, api_name=api_name)
+        # Note: add_and_trim_labels is global scope which may have unexpected effects
+        # Ensure no additional labels are added to job_config after this point,
+        # as `add_and_trim_labels` ensures the label count does not exceed 64.
+        bq_io.add_and_trim_labels(job_config, api_name=api_name)
        try:
            query_job = self.bqclient.query(sql, job_config=job_config)
            return (
304 changes: 304 additions & 0 deletions tests/unit/session/test_io_bigquery.py
@@ -17,15 +17,66 @@
from typing import Iterable

import google.cloud.bigquery as bigquery
import google.cloud.bigquery_storage_v1
import pytest

import bigframes
from bigframes.core import log_adapter
import bigframes.pandas as bpd
import bigframes.session._io.bigquery as io_bq
import bigframes.session.executor as bf_exe
from tests.unit import resources


@pytest.fixture(scope="function")
def mock_bq_client(mocker):
mock_client = mocker.Mock(spec=bigquery.Client)
mock_query_job = mocker.Mock(spec=bigquery.QueryJob)
mock_row_iterator = mocker.Mock(spec=bigquery.table.RowIterator)

mock_query_job.result.return_value = mock_row_iterator

mock_destination = bigquery.DatasetReference(
project="mock_project", dataset_id="mock_dataset"
)
mock_query_job.destination = mock_destination

mock_client.query.return_value = mock_query_job

return mock_client


@pytest.fixture(scope="function")
def mock_storage_manager(mocker):
return mocker.Mock(spec=bigframes.session.temp_storage.TemporaryGbqStorageManager)


@pytest.fixture(scope="function")
def mock_bq_storage_read_client(mocker):
return mocker.Mock(spec=google.cloud.bigquery_storage_v1.BigQueryReadClient)


@pytest.fixture(scope="function")
def mock_array_value(mocker):
return mocker.Mock(spec=bigframes.core.ArrayValue)


@pytest.fixture(scope="function")
def patch_bq_caching_executor(mocker):
mock_execute_result = mocker.Mock(spec=bf_exe.ExecuteResult)
mock_execute_result.query_job.destination.project = "some_project"
mock_execute_result.query_job.destination.dataset_id = "some_dataset"
mock_execute_result.query_job.destination.table_id = "some_table"

with mocker.patch.object(
bf_exe.BigQueryCachingExecutor, "to_sql", return_value='select * from "abc"'
):
with mocker.patch.object(
bf_exe.BigQueryCachingExecutor, "execute", return_value=mock_execute_result
):
yield


def test_create_job_configs_labels_is_none():
    api_methods = ["agg", "series-mode"]
    labels = io_bq.create_job_configs_labels(
@@ -148,6 +199,259 @@ def test_create_job_configs_labels_length_limit_met():
assert "source" in labels.keys()


def test_add_and_trim_labels_length_limit_met():
    log_adapter.get_and_reset_api_methods()
    cur_labels = {
        "bigframes-api": "read_pandas",
        "source": "bigquery-dataframes-temp",
    }
    for i in range(10):
        key = f"bigframes-api-test-{i}"
        value = f"test{i}"
        cur_labels[key] = value

    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    job_config = bigquery.job.QueryJobConfig()
    job_config.labels = cur_labels

    df.max()
    for _ in range(60):
        df.head()

    io_bq.add_and_trim_labels(job_config=job_config)
    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()
    assert "source" in job_config.labels.keys()


@pytest.mark.parametrize(
    ("max_results", "timeout", "api_name"),
    [(None, None, None), (100, 30.0, "test_api")],
)
def test_start_query_with_client_labels_length_limit_met(
    mock_bq_client, max_results, timeout, api_name
):
    sql = "select * from abc"
    cur_labels = {
        "bigframes-api": "read_pandas",
        "source": "bigquery-dataframes-temp",
    }
    for i in range(10):
        key = f"bigframes-api-test-{i}"
        value = f"test{i}"
        cur_labels[key] = value

    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    job_config = bigquery.job.QueryJobConfig()
    job_config.labels = cur_labels

    df.max()
    for _ in range(60):
        df.head()

    io_bq.start_query_with_client(
        mock_bq_client,
        sql,
        job_config,
        max_results=max_results,
        timeout=timeout,
        api_name=api_name,
    )

    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()
    assert "source" in job_config.labels.keys()


@pytest.mark.parametrize(
    ("location", "project", "api_name"),
    [(None, None, None), ("us", "abc", "test_api")],
)
def test_create_bq_dataset_reference_length_limit_met(
    mock_bq_client, location, project, api_name
):
    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    df.max()
    for _ in range(64):
        df.head()

    io_bq.create_bq_dataset_reference(
        mock_bq_client,
        location=location,
        project=project,
        api_name=api_name,
    )
    _, kwargs = mock_bq_client.query.call_args
    job_config = kwargs["job_config"]

    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()


@pytest.mark.parametrize(
    ("strictly_ordered", "format"),
    [(True, "json"), (False, "csv"), (True, "json")],
)
def test_export_gcs_length_limit_met(
    mock_bq_client,
    mock_storage_manager,
    mock_bq_storage_read_client,
    mock_array_value,
    strictly_ordered,
    format,
    patch_bq_caching_executor,
):
    bigquery_caching_executor = bf_exe.BigQueryCachingExecutor(
        mock_bq_client,
        mock_storage_manager,
        mock_bq_storage_read_client,
        strictly_ordered=strictly_ordered,
    )

    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    df.max()
    for _ in range(63):
        df.head()

    bigquery_caching_executor.export_gcs(
        mock_array_value,
        col_id_overrides={"a": "b", "c": "d"},
        uri="abc",
        format=format,
        export_options={"aa": True, "bb": "cc"},
    )

    _, kwargs = mock_bq_client.query.call_args
    job_config = kwargs["job_config"]

    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()


@pytest.mark.parametrize(
    ("strictly_ordered", "ordered"),
    [(True, False), (False, True)],
)
def test_dry_run_length_limit_met(
    mock_bq_client,
    mock_storage_manager,
    mock_bq_storage_read_client,
    mock_array_value,
    strictly_ordered,
    ordered,
    patch_bq_caching_executor,
):
    bigquery_caching_executor = bf_exe.BigQueryCachingExecutor(
        mock_bq_client,
        mock_storage_manager,
        mock_bq_storage_read_client,
        strictly_ordered=strictly_ordered,
    )

    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    df.max()
    for _ in range(64):
        df.head()

    bigquery_caching_executor.dry_run(mock_array_value, ordered=ordered)

    _, kwargs = mock_bq_client.query.call_args
    job_config = kwargs["job_config"]

    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()


@pytest.mark.parametrize(
    ("strictly_ordered", "api_name", "page_size", "max_results"),
    [(True, None, None, None), (False, "test_api", 100, 10)],
)
def test__run_execute_query_length_limit_met(
    mock_bq_client,
    mock_storage_manager,
    mock_bq_storage_read_client,
    strictly_ordered,
    api_name,
    page_size,
    max_results,
):
    sql = "select * from abc"

    bigquery_caching_executor = bf_exe.BigQueryCachingExecutor(
        mock_bq_client,
        mock_storage_manager,
        mock_bq_storage_read_client,
        strictly_ordered=strictly_ordered,
    )

    cur_labels = {
        "bigframes-api": "read_pandas",
        "source": "bigquery-dataframes-temp",
    }
    for i in range(10):
        key = f"bigframes-api-test-{i}"
        value = f"test{i}"
        cur_labels[key] = value

    job_config = bigquery.job.QueryJobConfig()
    job_config.labels = cur_labels

    df = bpd.DataFrame(
        {"col1": [1, 2], "col2": [3, 4]}, session=resources.create_bigquery_session()
    )

    df.max()
    for _ in range(60):
        df.head()

    bigquery_caching_executor._run_execute_query(
        sql,
        job_config=job_config,
        api_name=api_name,
        page_size=page_size,
        max_results=max_results,
    )

    _, kwargs = mock_bq_client.query.call_args
    job_config = kwargs["job_config"]

    assert job_config.labels is not None
    assert len(job_config.labels) == 64
    assert "dataframe-max" not in job_config.labels.values()
    assert "dataframe-head" in job_config.labels.values()
    assert "bigframes-api" in job_config.labels.keys()


def test_create_temp_table_default_expiration():
"""Make sure the created table has an expiration."""
expiration = datetime.datetime(