diff --git a/.changes/unreleased/Under the Hood-20241107-143856.yaml b/.changes/unreleased/Under the Hood-20241107-143856.yaml
new file mode 100644
index 000000000..db8557bf0
--- /dev/null
+++ b/.changes/unreleased/Under the Hood-20241107-143856.yaml
@@ -0,0 +1,6 @@
+kind: Under the Hood
+body: Create a retry factory to simplify retry strategies across dbt-bigquery
+time: 2024-11-07T14:38:56.210445-05:00
+custom:
+  Author: mikealfare osalama
+  Issue: "1395"
diff --git a/dbt/adapters/bigquery/clients.py b/dbt/adapters/bigquery/clients.py
new file mode 100644
index 000000000..18c59fc12
--- /dev/null
+++ b/dbt/adapters/bigquery/clients.py
@@ -0,0 +1,69 @@
+from google.api_core.client_info import ClientInfo
+from google.api_core.client_options import ClientOptions
+from google.api_core.retry import Retry
+from google.auth.exceptions import DefaultCredentialsError
+from google.cloud.bigquery import Client as BigQueryClient
+from google.cloud.dataproc_v1 import BatchControllerClient, JobControllerClient
+from google.cloud.storage import Client as StorageClient
+
+from dbt.adapters.events.logging import AdapterLogger
+
+import dbt.adapters.bigquery.__version__ as dbt_version
+from dbt.adapters.bigquery.credentials import (
+    BigQueryCredentials,
+    create_google_credentials,
+    set_default_credentials,
+)
+
+
+_logger = AdapterLogger("BigQuery")
+
+
+def create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient:
+    try:
+        return _create_bigquery_client(credentials)
+    except DefaultCredentialsError:
+        _logger.info("Please log into GCP to continue")
+        set_default_credentials()
+        return _create_bigquery_client(credentials)
+
+
+@Retry()  # google decorator. retries on transient errors with exponential backoff
+def create_gcs_client(credentials: BigQueryCredentials) -> StorageClient:
+    return StorageClient(
+        project=credentials.execution_project,
+        credentials=create_google_credentials(credentials),
+    )
+
+
+@Retry()  # google decorator. retries on transient errors with exponential backoff
+def create_dataproc_job_controller_client(credentials: BigQueryCredentials) -> JobControllerClient:
+    return JobControllerClient(
+        credentials=create_google_credentials(credentials),
+        client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)),
+    )
+
+
+@Retry()  # google decorator. retries on transient errors with exponential backoff
+def create_dataproc_batch_controller_client(
+    credentials: BigQueryCredentials,
+) -> BatchControllerClient:
+    return BatchControllerClient(
+        credentials=create_google_credentials(credentials),
+        client_options=ClientOptions(api_endpoint=_dataproc_endpoint(credentials)),
+    )
+
+
+@Retry()  # google decorator. retries on transient errors with exponential backoff
+def _create_bigquery_client(credentials: BigQueryCredentials) -> BigQueryClient:
+    return BigQueryClient(
+        credentials.execution_project,
+        create_google_credentials(credentials),
+        location=getattr(credentials, "location", None),
+        client_info=ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}"),
+        client_options=ClientOptions(quota_project_id=credentials.quota_project),
+    )
+
+
+def _dataproc_endpoint(credentials: BigQueryCredentials) -> str:
+    return f"{credentials.dataproc_region}-dataproc.googleapis.com:443"
diff --git a/dbt/adapters/bigquery/connections.py b/dbt/adapters/bigquery/connections.py
index bda54080b..61fa87d40 100644
--- a/dbt/adapters/bigquery/connections.py
+++ b/dbt/adapters/bigquery/connections.py
@@ -8,17 +8,20 @@
 from typing import Dict, Hashable, List, Optional, Tuple, TYPE_CHECKING
 import uuid
 
-from google.api_core import client_info, client_options, retry
-import google.auth
-from google.auth import impersonated_credentials
-import google.auth.exceptions
-import google.cloud.bigquery
-import google.cloud.exceptions
-from google.oauth2 import (
-    credentials as GoogleCredentials,
-    service_account as GoogleServiceAccountCredentials,
+from google.auth.exceptions import RefreshError
+from google.cloud.bigquery import (
+    Client,
+    CopyJobConfig,
+    Dataset,
+    DatasetReference,
+    LoadJobConfig,
+    QueryJobConfig,
+    QueryPriority,
+    SchemaField,
+    Table,
+    TableReference,
 )
-from requests.exceptions import ConnectionError
+from google.cloud.exceptions import BadRequest, Forbidden, NotFound
 
 from dbt_common.events.contextvars import get_node_info
 from dbt_common.events.functions import fire_event
@@ -34,14 +37,9 @@
 from dbt.adapters.events.types import SQLQuery
 from dbt.adapters.exceptions.connection import FailedToConnectError
 
-import dbt.adapters.bigquery.__version__ as dbt_version
-from dbt.adapters.bigquery.credentials import (
-    BigQueryConnectionMethod,
-    Priority,
-    get_bigquery_defaults,
-    setup_default_credentials,
-)
-from dbt.adapters.bigquery.utility import is_base64, base64_to_string
+from dbt.adapters.bigquery.clients import create_bigquery_client
+from dbt.adapters.bigquery.credentials import Priority
+from dbt.adapters.bigquery.retry import RetryFactory
 
 if TYPE_CHECKING:
     # Indirectly imported via agate_helper, which is lazy loaded further downfile.
@@ -51,22 +49,8 @@ logger = AdapterLogger("BigQuery") -BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" - -WRITE_TRUNCATE = google.cloud.bigquery.job.WriteDisposition.WRITE_TRUNCATE -REOPENABLE_ERRORS = ( - ConnectionResetError, - ConnectionError, -) - -RETRYABLE_ERRORS = ( - google.cloud.exceptions.ServerError, - google.cloud.exceptions.BadRequest, - google.cloud.exceptions.BadGateway, - ConnectionResetError, - ConnectionError, -) +BQ_QUERY_JOB_SPLIT = "-----Query Job SQL Follows-----" @dataclass @@ -82,12 +66,10 @@ class BigQueryAdapterResponse(AdapterResponse): class BigQueryConnectionManager(BaseConnectionManager): TYPE = "bigquery" - DEFAULT_INITIAL_DELAY = 1.0 # Seconds - DEFAULT_MAXIMUM_DELAY = 3.0 # Seconds - def __init__(self, profile: AdapterRequiredConfig, mp_context: SpawnContext): super().__init__(profile, mp_context) self.jobs_by_thread: Dict[Hashable, List[str]] = defaultdict(list) + self._retry = RetryFactory(profile.credentials) @classmethod def handle_error(cls, error, message): @@ -108,19 +90,19 @@ def exception_handler(self, sql): try: yield - except google.cloud.exceptions.BadRequest as e: + except BadRequest as e: message = "Bad request while running query" self.handle_error(e, message) - except google.cloud.exceptions.Forbidden as e: + except Forbidden as e: message = "Access denied while running query" self.handle_error(e, message) - except google.cloud.exceptions.NotFound as e: + except NotFound as e: message = "Not found while running query" self.handle_error(e, message) - except google.auth.exceptions.RefreshError as e: + except RefreshError as e: message = ( "Unable to generate access token, if you're using " "impersonate_service_account, make sure your " @@ -153,15 +135,15 @@ def cancel_open(self): for thread_id, connection in self.thread_connections.items(): if connection is this_connection: continue + if connection.handle is not None and connection.state == ConnectionState.OPEN: - client = connection.handle + client: Client = connection.handle for job_id in self.jobs_by_thread.get(thread_id, []): - - def fn(): - return client.cancel_job(job_id) - - self._retry_and_handle(msg=f"Cancel job: {job_id}", conn=connection, fn=fn) - + with self.exception_handler(f"Cancel job: {job_id}"): + client.cancel_job( + job_id, + retry=self._retry.create_reopen_with_deadline(connection), + ) self.close(connection) if connection.name is not None: @@ -203,121 +185,23 @@ def format_rows_number(self, rows_number): rows_number *= 1000.0 return f"{rows_number:3.1f}{unit}".strip() - @classmethod - def get_google_credentials(cls, profile_credentials) -> GoogleCredentials: - method = profile_credentials.method - creds = GoogleServiceAccountCredentials.Credentials - - if method == BigQueryConnectionMethod.OAUTH: - credentials, _ = get_bigquery_defaults(scopes=profile_credentials.scopes) - return credentials - - elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT: - keyfile = profile_credentials.keyfile - return creds.from_service_account_file(keyfile, scopes=profile_credentials.scopes) - - elif method == BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: - details = profile_credentials.keyfile_json - if is_base64(profile_credentials.keyfile_json): - details = base64_to_string(details) - return creds.from_service_account_info(details, scopes=profile_credentials.scopes) - - elif method == BigQueryConnectionMethod.OAUTH_SECRETS: - return GoogleCredentials.Credentials( - token=profile_credentials.token, - refresh_token=profile_credentials.refresh_token, - 
client_id=profile_credentials.client_id, - client_secret=profile_credentials.client_secret, - token_uri=profile_credentials.token_uri, - scopes=profile_credentials.scopes, - ) - - error = 'Invalid `method` in profile: "{}"'.format(method) - raise FailedToConnectError(error) - - @classmethod - def get_impersonated_credentials(cls, profile_credentials): - source_credentials = cls.get_google_credentials(profile_credentials) - return impersonated_credentials.Credentials( - source_credentials=source_credentials, - target_principal=profile_credentials.impersonate_service_account, - target_scopes=list(profile_credentials.scopes), - ) - - @classmethod - def get_credentials(cls, profile_credentials): - if profile_credentials.impersonate_service_account: - return cls.get_impersonated_credentials(profile_credentials) - else: - return cls.get_google_credentials(profile_credentials) - - @classmethod - @retry.Retry() # google decorator. retries on transient errors with exponential backoff - def get_bigquery_client(cls, profile_credentials): - creds = cls.get_credentials(profile_credentials) - execution_project = profile_credentials.execution_project - quota_project = profile_credentials.quota_project - location = getattr(profile_credentials, "location", None) - - info = client_info.ClientInfo(user_agent=f"dbt-bigquery-{dbt_version.version}") - options = client_options.ClientOptions(quota_project_id=quota_project) - return google.cloud.bigquery.Client( - execution_project, - creds, - location=location, - client_info=info, - client_options=options, - ) - @classmethod def open(cls, connection): - if connection.state == "open": + if connection.state == ConnectionState.OPEN: logger.debug("Connection is already open, skipping open.") return connection try: - handle = cls.get_bigquery_client(connection.credentials) - - except google.auth.exceptions.DefaultCredentialsError: - logger.info("Please log into GCP to continue") - setup_default_credentials() - - handle = cls.get_bigquery_client(connection.credentials) + connection.handle = create_bigquery_client(connection.credentials) + connection.state = ConnectionState.OPEN + return connection except Exception as e: - logger.debug( - "Got an error when attempting to create a bigquery " "client: '{}'".format(e) - ) - + logger.debug(f"""Got an error when attempting to create a bigquery " "client: '{e}'""") connection.handle = None - connection.state = "fail" - + connection.state = ConnectionState.FAIL raise FailedToConnectError(str(e)) - connection.handle = handle - connection.state = "open" - return connection - - @classmethod - def get_job_execution_timeout_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_execution_timeout_seconds - - @classmethod - def get_job_retries(cls, conn) -> int: - credentials = conn.credentials - return credentials.job_retries - - @classmethod - def get_job_creation_timeout_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_creation_timeout_seconds - - @classmethod - def get_job_retry_deadline_seconds(cls, conn): - credentials = conn.credentials - return credentials.job_retry_deadline_seconds - @classmethod def get_table_from_response(cls, resp) -> "agate.Table": from dbt_common.clients import agate_helper @@ -357,7 +241,6 @@ def raw_execute( dry_run: bool = False, ): conn = self.get_thread_connection() - client = conn.handle fire_event(SQLQuery(conn_name=conn.name, sql=sql, node_info=get_node_info())) @@ -373,34 +256,25 @@ def raw_execute( priority = conn.credentials.priority if 
priority == Priority.Batch: - job_params["priority"] = google.cloud.bigquery.QueryPriority.BATCH + job_params["priority"] = QueryPriority.BATCH else: - job_params["priority"] = google.cloud.bigquery.QueryPriority.INTERACTIVE + job_params["priority"] = QueryPriority.INTERACTIVE maximum_bytes_billed = conn.credentials.maximum_bytes_billed if maximum_bytes_billed is not None and maximum_bytes_billed != 0: job_params["maximum_bytes_billed"] = maximum_bytes_billed - job_creation_timeout = self.get_job_creation_timeout_seconds(conn) - job_execution_timeout = self.get_job_execution_timeout_seconds(conn) - - def fn(): + with self.exception_handler(sql): job_id = self.generate_job_id() return self._query_and_results( - client, + conn, sql, job_params, job_id, - job_creation_timeout=job_creation_timeout, - job_execution_timeout=job_execution_timeout, limit=limit, ) - query_job, iterator = self._retry_and_handle(msg=sql, conn=conn, fn=fn) - - return query_job, iterator - def execute( self, sql, auto_begin=False, fetch=None, limit: Optional[int] = None ) -> Tuple[BigQueryAdapterResponse, "agate.Table"]: @@ -528,9 +402,9 @@ def standard_to_legacy(table): _, iterator = self.raw_execute(sql, use_legacy_sql=True) return self.get_table_from_response(iterator) - def copy_bq_table(self, source, destination, write_disposition): + def copy_bq_table(self, source, destination, write_disposition) -> None: conn = self.get_thread_connection() - client = conn.handle + client: Client = conn.handle # ------------------------------------------------------------------------------- # BigQuery allows to use copy API using two different formats: @@ -558,89 +432,149 @@ def copy_bq_table(self, source, destination, write_disposition): write_disposition, ) - def copy_and_results(): - job_config = google.cloud.bigquery.CopyJobConfig(write_disposition=write_disposition) - copy_job = client.copy_table(source_ref_array, destination_ref, job_config=job_config) - timeout = self.get_job_execution_timeout_seconds(conn) or 300 - iterator = copy_job.result(timeout=timeout) - return copy_job, iterator - - self._retry_and_handle( - msg='copy table "{}" to "{}"'.format( - ", ".join(source_ref.path for source_ref in source_ref_array), - destination_ref.path, - ), - conn=conn, - fn=copy_and_results, + msg = 'copy table "{}" to "{}"'.format( + ", ".join(source_ref.path for source_ref in source_ref_array), + destination_ref.path, + ) + with self.exception_handler(msg): + copy_job = client.copy_table( + source_ref_array, + destination_ref, + job_config=CopyJobConfig(write_disposition=write_disposition), + retry=self._retry.create_reopen_with_deadline(conn), + ) + copy_job.result(timeout=self._retry.create_job_execution_timeout(fallback=300)) + + def write_dataframe_to_table( + self, + client: Client, + file_path: str, + database: str, + schema: str, + identifier: str, + table_schema: List[SchemaField], + field_delimiter: str, + fallback_timeout: Optional[float] = None, + ) -> None: + load_config = LoadJobConfig( + skip_leading_rows=1, + schema=table_schema, + field_delimiter=field_delimiter, ) + table = self.table_ref(database, schema, identifier) + self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) + + def write_file_to_table( + self, + client: Client, + file_path: str, + database: str, + schema: str, + identifier: str, + fallback_timeout: Optional[float] = None, + **kwargs, + ) -> None: + config = kwargs["kwargs"] + if "schema" in config: + config["schema"] = json.load(config["schema"]) + load_config = 
LoadJobConfig(**config) + table = self.table_ref(database, schema, identifier) + self._write_file_to_table(client, file_path, table, load_config, fallback_timeout) + + def _write_file_to_table( + self, + client: Client, + file_path: str, + table: TableReference, + config: LoadJobConfig, + fallback_timeout: Optional[float] = None, + ) -> None: + + with self.exception_handler("LOAD TABLE"): + with open(file_path, "rb") as f: + job = client.load_table_from_file(f, table, rewind=True, job_config=config) + + response = job.result(retry=self._retry.create_retry(fallback=fallback_timeout)) + + if response.state != "DONE": + raise DbtRuntimeError("BigQuery Timeout Exceeded") + + elif response.error_result: + message = "\n".join(error["message"].strip() for error in response.errors) + raise DbtRuntimeError(message) @staticmethod def dataset_ref(database, schema): - return google.cloud.bigquery.DatasetReference(project=database, dataset_id=schema) + return DatasetReference(project=database, dataset_id=schema) @staticmethod def table_ref(database, schema, table_name): - dataset_ref = google.cloud.bigquery.DatasetReference(database, schema) - return google.cloud.bigquery.TableReference(dataset_ref, table_name) + dataset_ref = DatasetReference(database, schema) + return TableReference(dataset_ref, table_name) - def get_bq_table(self, database, schema, identifier): + def get_bq_table(self, database, schema, identifier) -> Table: """Get a bigquery table for a schema/model.""" conn = self.get_thread_connection() + client: Client = conn.handle # backwards compatibility: fill in with defaults if not specified database = database or conn.credentials.database schema = schema or conn.credentials.schema - table_ref = self.table_ref(database, schema, identifier) - return conn.handle.get_table(table_ref) + return client.get_table( + table=self.table_ref(database, schema, identifier), + retry=self._retry.create_reopen_with_deadline(conn), + ) - def drop_dataset(self, database, schema): + def drop_dataset(self, database, schema) -> None: conn = self.get_thread_connection() - dataset_ref = self.dataset_ref(database, schema) - client = conn.handle - - def fn(): - return client.delete_dataset(dataset_ref, delete_contents=True, not_found_ok=True) - - self._retry_and_handle(msg="drop dataset", conn=conn, fn=fn) + client: Client = conn.handle + with self.exception_handler("drop dataset"): + client.delete_dataset( + dataset=self.dataset_ref(database, schema), + delete_contents=True, + not_found_ok=True, + retry=self._retry.create_reopen_with_deadline(conn), + ) - def create_dataset(self, database, schema): + def create_dataset(self, database, schema) -> Dataset: conn = self.get_thread_connection() - client = conn.handle - dataset_ref = self.dataset_ref(database, schema) - - def fn(): - return client.create_dataset(dataset_ref, exists_ok=True) - - self._retry_and_handle(msg="create dataset", conn=conn, fn=fn) + client: Client = conn.handle + with self.exception_handler("create dataset"): + return client.create_dataset( + dataset=self.dataset_ref(database, schema), + exists_ok=True, + retry=self._retry.create_reopen_with_deadline(conn), + ) def list_dataset(self, database: str): - # the database string we get here is potentially quoted. Strip that off - # for the API call. - database = database.strip("`") + # The database string we get here is potentially quoted. + # Strip that off for the API call. 
conn = self.get_thread_connection() - client = conn.handle - - def query_schemas(): + client: Client = conn.handle + with self.exception_handler("list dataset"): # this is similar to how we have to deal with listing tables - all_datasets = client.list_datasets(project=database, max_results=10000) + all_datasets = client.list_datasets( + project=database.strip("`"), + max_results=10000, + retry=self._retry.create_reopen_with_deadline(conn), + ) return [ds.dataset_id for ds in all_datasets] - return self._retry_and_handle(msg="list dataset", conn=conn, fn=query_schemas) - def _query_and_results( self, - client, + conn, sql, job_params, job_id, - job_creation_timeout=None, - job_execution_timeout=None, limit: Optional[int] = None, ): + client: Client = conn.handle """Query the client and wait for results.""" # Cannot reuse job_config if destination is set and ddl is used - job_config = google.cloud.bigquery.QueryJobConfig(**job_params) query_job = client.query( - query=sql, job_config=job_config, job_id=job_id, timeout=job_creation_timeout + query=sql, + job_config=QueryJobConfig(**job_params), + job_id=job_id, # note, this disables retry since the job_id will have been used + timeout=self._retry.create_job_creation_timeout(), ) if ( query_job.location is not None @@ -650,37 +584,14 @@ def _query_and_results( logger.debug( self._bq_job_link(query_job.location, query_job.project, query_job.job_id) ) + + timeout = self._retry.create_job_execution_timeout() try: - iterator = query_job.result(max_results=limit, timeout=job_execution_timeout) - return query_job, iterator + iterator = query_job.result(max_results=limit, timeout=timeout) except TimeoutError: - exc = f"Operation did not complete within the designated timeout of {job_execution_timeout} seconds." + exc = f"Operation did not complete within the designated timeout of {timeout} seconds." 
raise TimeoutError(exc) - - def _retry_and_handle(self, msg, conn, fn): - """retry a function call within the context of exception_handler.""" - - def reopen_conn_on_error(error): - if isinstance(error, REOPENABLE_ERRORS): - logger.warning("Reopening connection after {!r}".format(error)) - self.close(conn) - self.open(conn) - return - - with self.exception_handler(msg): - return retry.retry_target( - target=fn, - predicate=_ErrorCounter(self.get_job_retries(conn)).count_error, - sleep_generator=self._retry_generator(), - deadline=self.get_job_retry_deadline_seconds(conn), - on_error=reopen_conn_on_error, - ) - - def _retry_generator(self): - """Generates retry intervals that exponentially back off.""" - return retry.exponential_sleep_generator( - initial=self.DEFAULT_INITIAL_DELAY, maximum=self.DEFAULT_MAXIMUM_DELAY - ) + return query_job, iterator def _labels_from_query_comment(self, comment: str) -> Dict: try: @@ -693,39 +604,6 @@ def _labels_from_query_comment(self, comment: str) -> Dict: } -class _ErrorCounter(object): - """Counts errors seen up to a threshold then raises the next error.""" - - def __init__(self, retries): - self.retries = retries - self.error_count = 0 - - def count_error(self, error): - if self.retries == 0: - return False # Don't log - self.error_count += 1 - if _is_retryable(error) and self.error_count <= self.retries: - logger.debug( - "Retry attempt {} of {} after error: {}".format( - self.error_count, self.retries, repr(error) - ) - ) - return True - else: - return False - - -def _is_retryable(error): - """Return true for errors that are unlikely to occur again if retried.""" - if isinstance(error, RETRYABLE_ERRORS): - return True - elif isinstance(error, google.api_core.exceptions.Forbidden) and any( - e["reason"] == "rateLimitExceeded" for e in error.errors - ): - return True - return False - - _SANITIZE_LABEL_PATTERN = re.compile(r"[^a-z0-9_-]") _VALIDATE_LABEL_LENGTH_LIMIT = 63 diff --git a/dbt/adapters/bigquery/credentials.py b/dbt/adapters/bigquery/credentials.py index 32f172dac..94d70a931 100644 --- a/dbt/adapters/bigquery/credentials.py +++ b/dbt/adapters/bigquery/credentials.py @@ -1,9 +1,14 @@ +import base64 +import binascii from dataclasses import dataclass, field from functools import lru_cache -from typing import Any, Dict, Optional, Tuple +from typing import Any, Dict, Iterable, Optional, Tuple, Union -import google.auth +from google.auth import default from google.auth.exceptions import DefaultCredentialsError +from google.auth.impersonated_credentials import Credentials as ImpersonatedCredentials +from google.oauth2.credentials import Credentials as GoogleCredentials +from google.oauth2.service_account import Credentials as ServiceAccountCredentials from mashumaro import pass_through from dbt_common.clients.system import run_cmd @@ -11,6 +16,7 @@ from dbt_common.exceptions import DbtConfigError, DbtRuntimeError from dbt.adapters.contracts.connection import Credentials from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions.connection import FailedToConnectError _logger = AdapterLogger("BigQuery") @@ -21,59 +27,22 @@ class Priority(StrEnum): Batch = "batch" -class BigQueryConnectionMethod(StrEnum): - OAUTH = "oauth" - SERVICE_ACCOUNT = "service-account" - SERVICE_ACCOUNT_JSON = "service-account-json" - OAUTH_SECRETS = "oauth-secrets" - - @dataclass class DataprocBatchConfig(ExtensibleDbtClassMixin): def __init__(self, batch_config): self.batch_config = batch_config -@lru_cache() -def get_bigquery_defaults(scopes=None) 
-> Tuple[Any, Optional[str]]: - """ - Returns (credentials, project_id) - - project_id is returned available from the environment; otherwise None - """ - # Cached, because the underlying implementation shells out, taking ~1s - try: - credentials, _ = google.auth.default(scopes=scopes) - return credentials, _ - except DefaultCredentialsError as e: - raise DbtConfigError(f"Failed to authenticate with supplied credentials\nerror:\n{e}") - - -def setup_default_credentials(): - if _gcloud_installed(): - run_cmd(".", ["gcloud", "auth", "application-default", "login"]) - else: - msg = """ - dbt requires the gcloud SDK to be installed to authenticate with BigQuery. - Please download and install the SDK, or use a Service Account instead. - - https://cloud.google.com/sdk/ - """ - raise DbtRuntimeError(msg) - - -def _gcloud_installed(): - try: - run_cmd(".", ["gcloud", "--version"]) - return True - except OSError as e: - _logger.debug(e) - return False +class _BigQueryConnectionMethod(StrEnum): + OAUTH = "oauth" + OAUTH_SECRETS = "oauth-secrets" + SERVICE_ACCOUNT = "service-account" + SERVICE_ACCOUNT_JSON = "service-account-json" @dataclass class BigQueryCredentials(Credentials): - method: BigQueryConnectionMethod = None # type: ignore + method: _BigQueryConnectionMethod = None # type: ignore # BigQuery allows an empty database / project, where it defers to the # environment for the project @@ -179,9 +148,122 @@ def __pre_deserialize__(cls, d: Dict[Any, Any]) -> Dict[Any, Any]: # `database` is an alias of `project` in BigQuery if "database" not in d: - _, database = get_bigquery_defaults() + _, database = _create_bigquery_defaults() d["database"] = database # `execution_project` default to dataset/project if "execution_project" not in d: d["execution_project"] = d["database"] return d + + +def set_default_credentials() -> None: + try: + run_cmd(".", ["gcloud", "--version"]) + except OSError as e: + _logger.debug(e) + msg = """ + dbt requires the gcloud SDK to be installed to authenticate with BigQuery. + Please download and install the SDK, or use a Service Account instead. 
+ + https://cloud.google.com/sdk/ + """ + raise DbtRuntimeError(msg) + + run_cmd(".", ["gcloud", "auth", "application-default", "login"]) + + +def create_google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: + if credentials.impersonate_service_account: + return _create_impersonated_credentials(credentials) + return _create_google_credentials(credentials) + + +def _create_impersonated_credentials(credentials: BigQueryCredentials) -> ImpersonatedCredentials: + if credentials.scopes and isinstance(credentials.scopes, Iterable): + target_scopes = list(credentials.scopes) + else: + target_scopes = [] + + return ImpersonatedCredentials( + source_credentials=_create_google_credentials(credentials), + target_principal=credentials.impersonate_service_account, + target_scopes=target_scopes, + ) + + +def _create_google_credentials(credentials: BigQueryCredentials) -> GoogleCredentials: + + if credentials.method == _BigQueryConnectionMethod.OAUTH: + creds, _ = _create_bigquery_defaults(scopes=credentials.scopes) + + elif credentials.method == _BigQueryConnectionMethod.SERVICE_ACCOUNT: + creds = ServiceAccountCredentials.from_service_account_file( + credentials.keyfile, scopes=credentials.scopes + ) + + elif credentials.method == _BigQueryConnectionMethod.SERVICE_ACCOUNT_JSON: + details = credentials.keyfile_json + if _is_base64(details): # type:ignore + details = _base64_to_string(details) + creds = ServiceAccountCredentials.from_service_account_info( + details, scopes=credentials.scopes + ) + + elif credentials.method == _BigQueryConnectionMethod.OAUTH_SECRETS: + creds = GoogleCredentials( + token=credentials.token, + refresh_token=credentials.refresh_token, + client_id=credentials.client_id, + client_secret=credentials.client_secret, + token_uri=credentials.token_uri, + scopes=credentials.scopes, + ) + + else: + raise FailedToConnectError(f"Invalid `method` in profile: '{credentials.method}'") + + return creds + + +@lru_cache() +def _create_bigquery_defaults(scopes=None) -> Tuple[Any, Optional[str]]: + """ + Returns (credentials, project_id) + + project_id is returned available from the environment; otherwise None + """ + # Cached, because the underlying implementation shells out, taking ~1s + try: + return default(scopes=scopes) + except DefaultCredentialsError as e: + raise DbtConfigError(f"Failed to authenticate with supplied credentials\nerror:\n{e}") + + +def _is_base64(s: Union[str, bytes]) -> bool: + """ + Checks if the given string or bytes object is valid Base64 encoded. + + Args: + s: The string or bytes object to check. + + Returns: + True if the input is valid Base64, False otherwise. 
+ """ + + if isinstance(s, str): + # For strings, ensure they consist only of valid Base64 characters + if not s.isascii(): + return False + # Convert to bytes for decoding + s = s.encode("ascii") + + try: + # Use the 'validate' parameter to enforce strict Base64 decoding rules + base64.b64decode(s, validate=True) + return True + except (TypeError, binascii.Error): + return False + + +def _base64_to_string(b): + return base64.b64decode(b).decode("utf-8") diff --git a/dbt/adapters/bigquery/dataproc/__init__.py b/dbt/adapters/bigquery/dataproc/__init__.py deleted file mode 100644 index e69de29bb..000000000 diff --git a/dbt/adapters/bigquery/dataproc/batch.py b/dbt/adapters/bigquery/dataproc/batch.py deleted file mode 100644 index 59f40d246..000000000 --- a/dbt/adapters/bigquery/dataproc/batch.py +++ /dev/null @@ -1,68 +0,0 @@ -from datetime import datetime -import time -from typing import Dict, Union - -from google.cloud.dataproc_v1 import ( - Batch, - BatchControllerClient, - CreateBatchRequest, - GetBatchRequest, -) -from google.protobuf.json_format import ParseDict - -from dbt.adapters.bigquery.credentials import DataprocBatchConfig - - -_BATCH_RUNNING_STATES = [Batch.State.PENDING, Batch.State.RUNNING] -DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar" - - -def create_batch_request( - batch: Batch, batch_id: str, project: str, region: str -) -> CreateBatchRequest: - return CreateBatchRequest( - parent=f"projects/{project}/locations/{region}", - batch_id=batch_id, - batch=batch, - ) - - -def poll_batch_job( - parent: str, batch_id: str, job_client: BatchControllerClient, timeout: int -) -> Batch: - batch_name = "".join([parent, "/batches/", batch_id]) - state = Batch.State.PENDING - response = None - run_time = 0 - while state in _BATCH_RUNNING_STATES and run_time < timeout: - time.sleep(1) - response = job_client.get_batch( - request=GetBatchRequest(name=batch_name), - ) - run_time = datetime.now().timestamp() - response.create_time.timestamp() - state = response.state - if not response: - raise ValueError("No response from Dataproc") - if state != Batch.State.SUCCEEDED: - if run_time >= timeout: - raise ValueError( - f"Operation did not complete within the designated timeout of {timeout} seconds." - ) - else: - raise ValueError(response.state_message) - return response - - -def update_batch_from_config(config_dict: Union[Dict, DataprocBatchConfig], target: Batch): - try: - # updates in place - ParseDict(config_dict, target._pb) - except Exception as e: - docurl = ( - "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1" - "#google.cloud.dataproc.v1.Batch" - ) - raise ValueError( - f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. 
{str(e)}" - ) from e - return target diff --git a/dbt/adapters/bigquery/impl.py b/dbt/adapters/bigquery/impl.py index f6470e7f7..51c457129 100644 --- a/dbt/adapters/bigquery/impl.py +++ b/dbt/adapters/bigquery/impl.py @@ -1,9 +1,7 @@ from dataclasses import dataclass from datetime import datetime -import json from multiprocessing.context import SpawnContext import threading -import time from typing import ( Any, Dict, @@ -22,7 +20,7 @@ import google.auth import google.oauth2 import google.cloud.bigquery -from google.cloud.bigquery import AccessEntry, SchemaField, Table as BigQueryTable +from google.cloud.bigquery import AccessEntry, Client, SchemaField, Table as BigQueryTable import google.cloud.exceptions import pytz @@ -454,22 +452,6 @@ def get_columns_in_select_sql(self, select_sql: str) -> List[BigQueryColumn]: logger.debug("get_columns_in_select_sql error: {}".format(e)) return [] - @classmethod - def poll_until_job_completes(cls, job, timeout): - retry_count = timeout - - while retry_count > 0 and job.state != "DONE": - retry_count -= 1 - time.sleep(1) - job.reload() - - if job.state != "DONE": - raise dbt_common.exceptions.DbtRuntimeError("BigQuery Timeout Exceeded") - - elif job.error_result: - message = "\n".join(error["message"].strip() for error in job.errors) - raise dbt_common.exceptions.DbtRuntimeError(message) - def _bq_table_to_relation(self, bq_table) -> Union[BigQueryRelation, None]: if bq_table is None: return None @@ -669,55 +651,50 @@ def alter_table_add_columns(self, relation, columns): @available.parse_none def load_dataframe( self, - database, - schema, - table_name, + database: str, + schema: str, + table_name: str, agate_table: "agate.Table", - column_override, - field_delimiter, - ): - bq_schema = self._agate_to_schema(agate_table, column_override) - conn = self.connections.get_thread_connection() - client = conn.handle - - table_ref = self.connections.table_ref(database, schema, table_name) - - load_config = google.cloud.bigquery.LoadJobConfig() - load_config.skip_leading_rows = 1 - load_config.schema = bq_schema - load_config.field_delimiter = field_delimiter - job_id = self.connections.generate_job_id() - with open(agate_table.original_abspath, "rb") as f: # type: ignore - job = client.load_table_from_file( - f, table_ref, rewind=True, job_config=load_config, job_id=job_id - ) - - timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300 - with self.connections.exception_handler("LOAD TABLE"): - self.poll_until_job_completes(job, timeout) + column_override: Dict[str, str], + field_delimiter: str, + ) -> None: + connection = self.connections.get_thread_connection() + client: Client = connection.handle + table_schema = self._agate_to_schema(agate_table, column_override) + file_path = agate_table.original_abspath # type: ignore + + self.connections.write_dataframe_to_table( + client, + file_path, + database, + schema, + table_name, + table_schema, + field_delimiter, + fallback_timeout=300, + ) @available.parse_none def upload_file( - self, local_file_path: str, database: str, table_schema: str, table_name: str, **kwargs + self, + local_file_path: str, + database: str, + table_schema: str, + table_name: str, + **kwargs, ) -> None: - conn = self.connections.get_thread_connection() - client = conn.handle - - table_ref = self.connections.table_ref(database, table_schema, table_name) - - load_config = google.cloud.bigquery.LoadJobConfig() - for k, v in kwargs["kwargs"].items(): - if k == "schema": - setattr(load_config, k, json.loads(v)) - else: - 
setattr(load_config, k, v) - - with open(local_file_path, "rb") as f: - job = client.load_table_from_file(f, table_ref, rewind=True, job_config=load_config) - - timeout = self.connections.get_job_execution_timeout_seconds(conn) or 300 - with self.connections.exception_handler("LOAD TABLE"): - self.poll_until_job_completes(job, timeout) + connection = self.connections.get_thread_connection() + client: Client = connection.handle + + self.connections.write_file_to_table( + client, + local_file_path, + database, + table_schema, + table_name, + fallback_timeout=300, + **kwargs, + ) @classmethod def _catalog_filter_table( @@ -753,7 +730,7 @@ def calculate_freshness_from_metadata( macro_resolver: Optional[MacroResolverProtocol] = None, ) -> Tuple[Optional[AdapterResponse], FreshnessResponse]: conn = self.connections.get_thread_connection() - client: google.cloud.bigquery.Client = conn.handle + client: Client = conn.handle table_ref = self.get_table_ref_from_relation(source) table = client.get_table(table_ref) diff --git a/dbt/adapters/bigquery/python_submissions.py b/dbt/adapters/bigquery/python_submissions.py index 93c82ca92..cd7f7d86f 100644 --- a/dbt/adapters/bigquery/python_submissions.py +++ b/dbt/adapters/bigquery/python_submissions.py @@ -1,187 +1,165 @@ -import uuid from typing import Dict, Union +import uuid -from google.api_core import retry -from google.api_core.client_options import ClientOptions -from google.api_core.future.polling import POLLING_PREDICATE -from google.cloud import storage, dataproc_v1 -from google.cloud.dataproc_v1.types.batches import Batch +from google.cloud.dataproc_v1 import Batch, CreateBatchRequest, Job, RuntimeConfig from dbt.adapters.base import PythonJobHelper from dbt.adapters.events.logging import AdapterLogger +from google.protobuf.json_format import ParseDict -from dbt.adapters.bigquery.connections import BigQueryConnectionManager -from dbt.adapters.bigquery.credentials import BigQueryCredentials -from dbt.adapters.bigquery.dataproc.batch import ( - DEFAULT_JAR_FILE_URI, - create_batch_request, - poll_batch_job, - update_batch_from_config, +from dbt.adapters.bigquery.credentials import BigQueryCredentials, DataprocBatchConfig +from dbt.adapters.bigquery.clients import ( + create_dataproc_batch_controller_client, + create_dataproc_job_controller_client, + create_gcs_client, ) +from dbt.adapters.bigquery.retry import RetryFactory + -OPERATION_RETRY_TIME = 10 -logger = AdapterLogger("BigQuery") +_logger = AdapterLogger("BigQuery") -class BaseDataProcHelper(PythonJobHelper): - def __init__(self, parsed_model: Dict, credential: BigQueryCredentials) -> None: - """_summary_ +_DEFAULT_JAR_FILE_URI = "gs://spark-lib/bigquery/spark-bigquery-with-dependencies_2.13-0.34.0.jar" - Args: - credential (_type_): _description_ - """ + +class _BaseDataProcHelper(PythonJobHelper): + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: # validate all additional stuff for python is set - schema = parsed_model["schema"] - identifier = parsed_model["alias"] - self.parsed_model = parsed_model - python_required_configs = [ - "dataproc_region", - "gcs_bucket", - ] - for required_config in python_required_configs: - if not getattr(credential, required_config): + for required_config in ["dataproc_region", "gcs_bucket"]: + if not getattr(credentials, required_config): raise ValueError( f"Need to supply {required_config} in profile to submit python job" ) - self.model_file_name = f"{schema}/{identifier}.py" - self.credential = credential - 
self.GoogleCredentials = BigQueryConnectionManager.get_credentials(credential) - self.storage_client = storage.Client( - project=self.credential.execution_project, credentials=self.GoogleCredentials - ) - self.gcs_location = "gs://{}/{}".format(self.credential.gcs_bucket, self.model_file_name) + + self._storage_client = create_gcs_client(credentials) + self._project = credentials.execution_project + self._region = credentials.dataproc_region + + schema = parsed_model["schema"] + identifier = parsed_model["alias"] + self._model_file_name = f"{schema}/{identifier}.py" + self._gcs_bucket = credentials.gcs_bucket + self._gcs_path = f"gs://{credentials.gcs_bucket}/{self._model_file_name}" # set retry policy, default to timeout after 24 hours - self.timeout = self.parsed_model["config"].get( - "timeout", self.credential.job_execution_timeout_seconds or 60 * 60 * 24 - ) - self.result_polling_policy = retry.Retry( - predicate=POLLING_PREDICATE, maximum=10.0, timeout=self.timeout - ) - self.client_options = ClientOptions( - api_endpoint="{}-dataproc.googleapis.com:443".format(self.credential.dataproc_region) + retry = RetryFactory(credentials) + self._polling_retry = retry.create_polling( + model_timeout=parsed_model["config"].get("timeout") ) - self.job_client = self._get_job_client() - def _upload_to_gcs(self, filename: str, compiled_code: str) -> None: - bucket = self.storage_client.get_bucket(self.credential.gcs_bucket) - blob = bucket.blob(filename) + def _write_to_gcs(self, compiled_code: str) -> None: + bucket = self._storage_client.get_bucket(self._gcs_bucket) + blob = bucket.blob(self._model_file_name) blob.upload_from_string(compiled_code) - def submit(self, compiled_code: str) -> dataproc_v1.types.jobs.Job: - # upload python file to GCS - self._upload_to_gcs(self.model_file_name, compiled_code) - # submit dataproc job - return self._submit_dataproc_job() - - def _get_job_client( - self, - ) -> Union[dataproc_v1.JobControllerClient, dataproc_v1.BatchControllerClient]: - raise NotImplementedError("_get_job_client not implemented") - - def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: - raise NotImplementedError("_submit_dataproc_job not implemented") +class ClusterDataprocHelper(_BaseDataProcHelper): + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: + super().__init__(parsed_model, credentials) + self._job_controller_client = create_dataproc_job_controller_client(credentials) + self._cluster_name = parsed_model["config"].get( + "dataproc_cluster_name", credentials.dataproc_cluster_name + ) -class ClusterDataprocHelper(BaseDataProcHelper): - def _get_job_client(self) -> dataproc_v1.JobControllerClient: - if not self._get_cluster_name(): + if not self._cluster_name: raise ValueError( "Need to supply dataproc_cluster_name in profile or config to submit python job with cluster submission method" ) - return dataproc_v1.JobControllerClient( - client_options=self.client_options, credentials=self.GoogleCredentials - ) - def _get_cluster_name(self) -> str: - return self.parsed_model["config"].get( - "dataproc_cluster_name", self.credential.dataproc_cluster_name - ) + def submit(self, compiled_code: str) -> Job: + _logger.debug(f"Submitting cluster job to: {self._cluster_name}") - def _submit_dataproc_job(self) -> dataproc_v1.types.jobs.Job: - job = { - "placement": {"cluster_name": self._get_cluster_name()}, - "pyspark_job": { - "main_python_file_uri": self.gcs_location, + self._write_to_gcs(compiled_code) + + request = { + "project_id": self._project, 
+ "region": self._region, + "job": { + "placement": {"cluster_name": self._cluster_name}, + "pyspark_job": { + "main_python_file_uri": self._gcs_path, + }, }, } - operation = self.job_client.submit_job_as_operation( - request={ - "project_id": self.credential.execution_project, - "region": self.credential.dataproc_region, - "job": job, - } - ) - # check if job failed - response = operation.result(polling=self.result_polling_policy) + + # submit the job + operation = self._job_controller_client.submit_job_as_operation(request) + + # wait for the job to complete + response: Job = operation.result(polling=self._polling_retry) + if response.status.state == 6: raise ValueError(response.status.details) + return response -class ServerlessDataProcHelper(BaseDataProcHelper): - def _get_job_client(self) -> dataproc_v1.BatchControllerClient: - return dataproc_v1.BatchControllerClient( - client_options=self.client_options, credentials=self.GoogleCredentials - ) +class ServerlessDataProcHelper(_BaseDataProcHelper): + def __init__(self, parsed_model: Dict, credentials: BigQueryCredentials) -> None: + super().__init__(parsed_model, credentials) + self._batch_controller_client = create_dataproc_batch_controller_client(credentials) + self._batch_id = parsed_model["config"].get("batch_id", str(uuid.uuid4())) + self._jar_file_uri = parsed_model["config"].get("jar_file_uri", _DEFAULT_JAR_FILE_URI) + self._dataproc_batch = credentials.dataproc_batch - def _get_batch_id(self) -> str: - model = self.parsed_model - default_batch_id = str(uuid.uuid4()) - return model["config"].get("batch_id", default_batch_id) - - def _submit_dataproc_job(self) -> Batch: - batch_id = self._get_batch_id() - logger.info(f"Submitting batch job with id: {batch_id}") - request = create_batch_request( - batch=self._configure_batch(), - batch_id=batch_id, - region=self.credential.dataproc_region, # type: ignore - project=self.credential.execution_project, # type: ignore - ) - # make the request - self.job_client.create_batch(request=request) - return poll_batch_job( - parent=request.parent, - batch_id=batch_id, - job_client=self.job_client, - timeout=self.timeout, + def submit(self, compiled_code: str) -> Batch: + _logger.debug(f"Submitting batch job with id: {self._batch_id}") + + self._write_to_gcs(compiled_code) + + request = CreateBatchRequest( + parent=f"projects/{self._project}/locations/{self._region}", + batch=self._create_batch(), + batch_id=self._batch_id, ) - # there might be useful results here that we can parse and return - # Dataproc job output is saved to the Cloud Storage bucket - # allocated to the job. Use regex to obtain the bucket and blob info. 
- # matches = re.match("gs://(.*?)/(.*)", response.driver_output_resource_uri) - # output = ( - # self.storage_client - # .get_bucket(matches.group(1)) - # .blob(f"{matches.group(2)}.000000000") - # .download_as_string() - # ) - - def _configure_batch(self): + + # submit the batch + operation = self._batch_controller_client.create_batch(request) + + # wait for the batch to complete + response: Batch = operation.result(polling=self._polling_retry) + + return response + + def _create_batch(self) -> Batch: # create the Dataproc Serverless job config # need to pin dataproc version to 1.1 as it now defaults to 2.0 # https://cloud.google.com/dataproc-serverless/docs/concepts/properties # https://cloud.google.com/dataproc-serverless/docs/reference/rest/v1/projects.locations.batches#runtimeconfig - batch = dataproc_v1.Batch( + batch = Batch( { - "runtime_config": dataproc_v1.RuntimeConfig( + "runtime_config": RuntimeConfig( version="1.1", properties={ "spark.executor.instances": "2", }, - ) + ), + "pyspark_batch": { + "main_python_file_uri": self._gcs_path, + "jar_file_uris": [self._jar_file_uri], + }, } ) - # Apply defaults - batch.pyspark_batch.main_python_file_uri = self.gcs_location - jar_file_uri = self.parsed_model["config"].get( - "jar_file_uri", - DEFAULT_JAR_FILE_URI, - ) - batch.pyspark_batch.jar_file_uris = [jar_file_uri] # Apply configuration from dataproc_batch key, possibly overriding defaults. - if self.credential.dataproc_batch: - batch = update_batch_from_config(self.credential.dataproc_batch, batch) + if self._dataproc_batch: + batch = _update_batch_from_config(self._dataproc_batch, batch) + return batch + + +def _update_batch_from_config( + config_dict: Union[Dict, DataprocBatchConfig], target: Batch +) -> Batch: + try: + # updates in place + ParseDict(config_dict, target._pb) + except Exception as e: + docurl = ( + "https://cloud.google.com/dataproc-serverless/docs/reference/rpc/google.cloud.dataproc.v1" + "#google.cloud.dataproc.v1.Batch" + ) + raise ValueError( + f"Unable to parse dataproc_batch as valid batch specification. See {docurl}. 
{str(e)}" + ) from e + return target diff --git a/dbt/adapters/bigquery/retry.py b/dbt/adapters/bigquery/retry.py new file mode 100644 index 000000000..391c00e46 --- /dev/null +++ b/dbt/adapters/bigquery/retry.py @@ -0,0 +1,128 @@ +from typing import Callable, Optional + +from google.api_core.exceptions import Forbidden +from google.api_core.future.polling import DEFAULT_POLLING +from google.api_core.retry import Retry +from google.cloud.bigquery.retry import DEFAULT_RETRY +from google.cloud.exceptions import BadGateway, BadRequest, ServerError +from requests.exceptions import ConnectionError + +from dbt.adapters.contracts.connection import Connection, ConnectionState +from dbt.adapters.events.logging import AdapterLogger +from dbt.adapters.exceptions.connection import FailedToConnectError + +from dbt.adapters.bigquery.clients import create_bigquery_client +from dbt.adapters.bigquery.credentials import BigQueryCredentials + + +_logger = AdapterLogger("BigQuery") + + +_SECOND = 1.0 +_MINUTE = 60 * _SECOND +_HOUR = 60 * _MINUTE +_DAY = 24 * _HOUR +_DEFAULT_INITIAL_DELAY = _SECOND +_DEFAULT_MAXIMUM_DELAY = 3 * _SECOND +_DEFAULT_POLLING_MAXIMUM_DELAY = 10 * _SECOND + + +class RetryFactory: + + def __init__(self, credentials: BigQueryCredentials) -> None: + self._retries = credentials.job_retries or 0 + self._job_creation_timeout = credentials.job_creation_timeout_seconds + self._job_execution_timeout = credentials.job_execution_timeout_seconds + self._job_deadline = credentials.job_retry_deadline_seconds + + def create_job_creation_timeout(self, fallback: float = _MINUTE) -> float: + return ( + self._job_creation_timeout or fallback + ) # keep _MINUTE here so it's not overridden by passing fallback=None + + def create_job_execution_timeout(self, fallback: float = _DAY) -> float: + return ( + self._job_execution_timeout or fallback + ) # keep _DAY here so it's not overridden by passing fallback=None + + def create_retry(self, fallback: Optional[float] = None) -> Retry: + return DEFAULT_RETRY.with_timeout(self._job_execution_timeout or fallback or _DAY) + + def create_polling(self, model_timeout: Optional[float] = None) -> Retry: + return DEFAULT_POLLING.with_timeout(model_timeout or self._job_execution_timeout or _DAY) + + def create_reopen_with_deadline(self, connection: Connection) -> Retry: + """ + This strategy mimics what was accomplished with _retry_and_handle + """ + return Retry( + predicate=_DeferredException(self._retries), + initial=_DEFAULT_INITIAL_DELAY, + maximum=_DEFAULT_MAXIMUM_DELAY, + deadline=self._job_deadline, + on_error=_create_reopen_on_error(connection), + ) + + +class _DeferredException: + """ + Count ALL errors, not just retryable errors, up to a threshold. + Raise the next error, regardless of whether it is retryable. 
+ """ + + def __init__(self, retries: int) -> None: + self._retries: int = retries + self._error_count = 0 + + def __call__(self, error: Exception) -> bool: + # exit immediately if the user does not want retries + if self._retries == 0: + return False + + # count all errors + self._error_count += 1 + + # if the error is retryable, and we haven't breached the threshold, log and continue + if _is_retryable(error) and self._error_count <= self._retries: + _logger.debug( + f"Retry attempt {self._error_count} of {self._retries} after error: {repr(error)}" + ) + return True + + # otherwise raise + return False + + +def _create_reopen_on_error(connection: Connection) -> Callable[[Exception], None]: + + def on_error(error: Exception): + if isinstance(error, (ConnectionResetError, ConnectionError)): + _logger.warning("Reopening connection after {!r}".format(error)) + connection.handle.close() + + try: + connection.handle = create_bigquery_client(connection.credentials) + connection.state = ConnectionState.OPEN + + except Exception as e: + _logger.debug( + f"""Got an error when attempting to create a bigquery " "client: '{e}'""" + ) + connection.handle = None + connection.state = ConnectionState.FAIL + raise FailedToConnectError(str(e)) + + return on_error + + +def _is_retryable(error: Exception) -> bool: + """Return true for errors that are unlikely to occur again if retried.""" + if isinstance( + error, (BadGateway, BadRequest, ConnectionError, ConnectionResetError, ServerError) + ): + return True + elif isinstance(error, Forbidden) and any( + e["reason"] == "rateLimitExceeded" for e in error.errors + ): + return True + return False diff --git a/dbt/adapters/bigquery/utility.py b/dbt/adapters/bigquery/utility.py index 557986b38..5914280a3 100644 --- a/dbt/adapters/bigquery/utility.py +++ b/dbt/adapters/bigquery/utility.py @@ -1,7 +1,5 @@ -import base64 -import binascii import json -from typing import Any, Optional, Union +from typing import Any, Optional import dbt_common.exceptions @@ -45,39 +43,3 @@ def sql_escape(string): if not isinstance(string, str): raise dbt_common.exceptions.CompilationError(f"cannot escape a non-string: {string}") return json.dumps(string)[1:-1] - - -def is_base64(s: Union[str, bytes]) -> bool: - """ - Checks if the given string or bytes object is valid Base64 encoded. - - Args: - s: The string or bytes object to check. - - Returns: - True if the input is valid Base64, False otherwise. 
- """ - - if isinstance(s, str): - # For strings, ensure they consist only of valid Base64 characters - if not s.isascii(): - return False - # Convert to bytes for decoding - s = s.encode("ascii") - - try: - # Use the 'validate' parameter to enforce strict Base64 decoding rules - base64.b64decode(s, validate=True) - return True - except TypeError: - return False - except binascii.Error: # Catch specific errors from the base64 module - return False - - -def base64_to_string(b): - return base64.b64decode(b).decode("utf-8") - - -def string_to_base64(s): - return base64.b64encode(s.encode("utf-8")) diff --git a/tests/conftest.py b/tests/conftest.py index 6dc9e6443..33f7f9d17 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,9 +1,9 @@ import pytest import os import json -from dbt.adapters.bigquery.utility import is_base64, base64_to_string +from dbt.adapters.bigquery.credentials import _is_base64, _base64_to_string -# Import the fuctional fixtures as a plugin +# Import the functional fixtures as a plugin # Note: fixtures with session scope need to be local pytest_plugins = ["dbt.tests.fixtures.project"] @@ -39,8 +39,8 @@ def oauth_target(): def service_account_target(): credentials_json_str = os.getenv("BIGQUERY_TEST_SERVICE_ACCOUNT_JSON").replace("'", "") - if is_base64(credentials_json_str): - credentials_json_str = base64_to_string(credentials_json_str) + if _is_base64(credentials_json_str): + credentials_json_str = _base64_to_string(credentials_json_str) credentials = json.loads(credentials_json_str) project_id = credentials.get("project_id") return { diff --git a/tests/functional/adapter/test_json_keyfile.py b/tests/functional/adapter/test_json_keyfile.py index 91e41a3f1..a5caaebdf 100644 --- a/tests/functional/adapter/test_json_keyfile.py +++ b/tests/functional/adapter/test_json_keyfile.py @@ -1,6 +1,11 @@ +import base64 import json import pytest -from dbt.adapters.bigquery.utility import string_to_base64, is_base64 +from dbt.adapters.bigquery.credentials import _is_base64 + + +def string_to_base64(s): + return base64.b64encode(s.encode("utf-8")) @pytest.fixture @@ -53,7 +58,7 @@ def test_valid_base64_strings(example_json_keyfile_b64): ] for s in valid_strings: - assert is_base64(s) is True + assert _is_base64(s) is True def test_valid_base64_bytes(example_json_keyfile_b64): @@ -65,7 +70,7 @@ def test_valid_base64_bytes(example_json_keyfile_b64): example_json_keyfile_b64, ] for s in valid_bytes: - assert is_base64(s) is True + assert _is_base64(s) is True def test_invalid_base64(example_json_keyfile): @@ -79,4 +84,4 @@ def test_invalid_base64(example_json_keyfile): example_json_keyfile, ] for s in invalid_inputs: - assert is_base64(s) is False + assert _is_base64(s) is False diff --git a/tests/unit/test_bigquery_adapter.py b/tests/unit/test_bigquery_adapter.py index ca3bfc24c..e57db9a62 100644 --- a/tests/unit/test_bigquery_adapter.py +++ b/tests/unit/test_bigquery_adapter.py @@ -203,7 +203,7 @@ def get_adapter(self, target) -> BigQueryAdapter: class TestBigQueryAdapterAcquire(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) @patch("dbt.adapters.bigquery.BigQueryConnectionManager.open", return_value=_bq_conn()) @@ -244,10 +244,12 @@ def test_acquire_connection_oauth_validations(self, mock_open_connection): mock_open_connection.assert_called_once() @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + 
"dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) - @patch("dbt.adapters.bigquery.BigQueryConnectionManager.open", return_value=_bq_conn()) + @patch( + "dbt.adapters.bigquery.connections.BigQueryConnectionManager.open", return_value=_bq_conn() + ) def test_acquire_connection_dataproc_serverless( self, mock_open_connection, mock_get_bigquery_defaults ): @@ -386,21 +388,20 @@ def test_cancel_open_connections_single(self): adapter.connections.thread_connections.update({key: master, 1: model}) self.assertEqual(len(list(adapter.cancel_open_connections())), 1) - @patch("dbt.adapters.bigquery.impl.google.api_core.client_options.ClientOptions") - @patch("dbt.adapters.bigquery.impl.google.auth.default") - @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery") - def test_location_user_agent(self, mock_bq, mock_auth_default, MockClientOptions): + @patch("dbt.adapters.bigquery.clients.ClientOptions") + @patch("dbt.adapters.bigquery.credentials.default") + @patch("dbt.adapters.bigquery.clients.BigQueryClient") + def test_location_user_agent(self, MockClient, mock_auth_default, MockClientOptions): creds = MagicMock() mock_auth_default.return_value = (creds, MagicMock()) adapter = self.get_adapter("loc") connection = adapter.acquire_connection("dummy") - mock_client = mock_bq.Client mock_client_options = MockClientOptions.return_value - mock_client.assert_not_called() + MockClient.assert_not_called() connection.handle - mock_client.assert_called_once_with( + MockClient.assert_called_once_with( "dbt-unit-000000", creds, location="Luna Station", diff --git a/tests/unit/test_bigquery_connection_manager.py b/tests/unit/test_bigquery_connection_manager.py index 1c14100f6..d4c95792e 100644 --- a/tests/unit/test_bigquery_connection_manager.py +++ b/tests/unit/test_bigquery_connection_manager.py @@ -1,81 +1,59 @@ import json import unittest -from contextlib import contextmanager from requests.exceptions import ConnectionError from unittest.mock import patch, MagicMock, Mock, ANY import dbt.adapters +import google.cloud.bigquery from dbt.adapters.bigquery import BigQueryCredentials from dbt.adapters.bigquery import BigQueryRelation from dbt.adapters.bigquery.connections import BigQueryConnectionManager +from dbt.adapters.bigquery.retry import RetryFactory class TestBigQueryConnectionManager(unittest.TestCase): def setUp(self): - credentials = Mock(BigQueryCredentials) - profile = Mock(query_comment=None, credentials=credentials) - self.connections = BigQueryConnectionManager(profile=profile, mp_context=Mock()) + self.credentials = Mock(BigQueryCredentials) + self.credentials.method = "oauth" + self.credentials.job_retries = 1 + self.credentials.job_retry_deadline_seconds = 1 + self.credentials.scopes = tuple() - self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client) - self.mock_connection = MagicMock() + self.mock_client = Mock(google.cloud.bigquery.Client) + self.mock_connection = MagicMock() self.mock_connection.handle = self.mock_client + self.mock_connection.credentials = self.credentials + self.connections = BigQueryConnectionManager( + profile=Mock(credentials=self.credentials, query_comment=None), + mp_context=Mock(), + ) self.connections.get_thread_connection = lambda: self.mock_connection - self.connections.get_job_retry_deadline_seconds = lambda x: None - self.connections.get_job_retries = lambda x: 1 - - @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True) - def test_retry_and_handle(self, 
-        self.connections.DEFAULT_MAXIMUM_DELAY = 2.0
-
-        @contextmanager
-        def dummy_handler(msg):
-            yield
-
-        self.connections.exception_handler = dummy_handler
-
-        class DummyException(Exception):
-            """Count how many times this exception is raised"""
-
-            count = 0
-            def __init__(self):
-                DummyException.count += 1
+    @patch(
+        "dbt.adapters.bigquery.retry.create_bigquery_client",
+        return_value=Mock(google.cloud.bigquery.Client),
+    )
+    def test_retry_connection_reset(self, mock_client_factory):
+        new_mock_client = mock_client_factory.return_value

-        def raiseDummyException():
-            raise DummyException()
+        @self.connections._retry.create_reopen_with_deadline(self.mock_connection)
+        def generate_connection_reset_error():
+            raise ConnectionResetError

-        with self.assertRaises(DummyException):
-            self.connections._retry_and_handle(
-                "some sql", Mock(credentials=Mock(retries=8)), raiseDummyException
-            )
-        self.assertEqual(DummyException.count, 9)
+        assert self.mock_connection.handle is self.mock_client

-    @patch("dbt.adapters.bigquery.connections._is_retryable", return_value=True)
-    def test_retry_connection_reset(self, is_retryable):
-        self.connections.open = MagicMock()
-        self.connections.close = MagicMock()
-        self.connections.DEFAULT_MAXIMUM_DELAY = 2.0
-
-        @contextmanager
-        def dummy_handler(msg):
-            yield
-
-        self.connections.exception_handler = dummy_handler
-
-        def raiseConnectionResetError():
-            raise ConnectionResetError("Connection broke")
-
-        mock_conn = Mock(credentials=Mock(retries=1))
         with self.assertRaises(ConnectionResetError):
-            self.connections._retry_and_handle("some sql", mock_conn, raiseConnectionResetError)
-        self.connections.close.assert_called_once_with(mock_conn)
-        self.connections.open.assert_called_once_with(mock_conn)
+            # this will always raise the error; we just want to test that the connection was reopened in between
+            generate_connection_reset_error()
+
+        assert self.mock_connection.handle is new_mock_client
+        assert new_mock_client is not self.mock_client

     def test_is_retryable(self):
-        _is_retryable = dbt.adapters.bigquery.connections._is_retryable
+        _is_retryable = dbt.adapters.bigquery.retry._is_retryable
         exceptions = dbt.adapters.bigquery.impl.google.cloud.exceptions
         internal_server_error = exceptions.InternalServerError("code broke")
         bad_request_error = exceptions.BadRequest("code broke")
@@ -104,29 +82,30 @@ def test_drop_dataset(self):
         self.mock_client.delete_table.assert_not_called()
         self.mock_client.delete_dataset.assert_called_once()

-    @patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
-    def test_query_and_results(self, mock_bq):
+    @patch("dbt.adapters.bigquery.connections.QueryJobConfig")
+    def test_query_and_results(self, MockQueryJobConfig):
         self.connections._query_and_results(
-            self.mock_client,
+            self.mock_connection,
             "sql",
-            {"job_param_1": "blah"},
+            {"dry_run": True},
             job_id=1,
-            job_creation_timeout=15,
-            job_execution_timeout=100,
         )
-        mock_bq.QueryJobConfig.assert_called_once()
+        MockQueryJobConfig.assert_called_once()
         self.mock_client.query.assert_called_once_with(
-            query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
+            query="sql",
+            job_config=MockQueryJobConfig(),
+            job_id=1,
+            timeout=self.credentials.job_creation_timeout_seconds,
         )

     def test_copy_bq_table_appends(self):
         self._copy_table(write_disposition=dbt.adapters.bigquery.impl.WRITE_APPEND)
-        args, kwargs = self.mock_client.copy_table.call_args
         self.mock_client.copy_table.assert_called_once_with(
             [self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -140,6 +119,7 @@ def test_copy_bq_table_truncates(self): [self._table_ref("project", "dataset", "table1")], self._table_ref("project", "dataset", "table2"), job_config=ANY, + retry=ANY, ) args, kwargs = self.mock_client.copy_table.call_args self.assertEqual( @@ -161,7 +141,7 @@ def test_list_dataset_correctly_calls_lists_datasets(self): self.mock_client.list_datasets = mock_list_dataset result = self.connections.list_dataset("project") self.mock_client.list_datasets.assert_called_once_with( - project="project", max_results=10000 + project="project", max_results=10000, retry=ANY ) assert result == ["d1"] diff --git a/tests/unit/test_configure_dataproc_batch.py b/tests/unit/test_configure_dataproc_batch.py index f56aee129..6e5757589 100644 --- a/tests/unit/test_configure_dataproc_batch.py +++ b/tests/unit/test_configure_dataproc_batch.py @@ -1,6 +1,6 @@ from unittest.mock import patch -from dbt.adapters.bigquery.dataproc.batch import update_batch_from_config +from dbt.adapters.bigquery.python_submissions import _update_batch_from_config from google.cloud import dataproc_v1 from .test_bigquery_adapter import BaseTestBigQueryAdapter @@ -12,7 +12,7 @@ # parsed credentials class TestConfigureDataprocBatch(BaseTestBigQueryAdapter): @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): @@ -39,7 +39,7 @@ def test_update_dataproc_serverless_batch(self, mock_get_bigquery_defaults): batch = dataproc_v1.Batch() - batch = update_batch_from_config(raw_batch_config, batch) + batch = _update_batch_from_config(raw_batch_config, batch) def to_str_values(d): """google's protobuf types expose maps as dict[str, str]""" @@ -64,7 +64,7 @@ def to_str_values(d): ) @patch( - "dbt.adapters.bigquery.credentials.get_bigquery_defaults", + "dbt.adapters.bigquery.credentials._create_bigquery_defaults", return_value=("credentials", "project_id"), ) def test_default_dataproc_serverless_batch(self, mock_get_bigquery_defaults):