Skip to content

Commit

Permalink
fix mocks in unit tests
Browse files Browse the repository at this point in the history
  • Loading branch information
mikealfare committed Nov 5, 2024
1 parent 81bfa0c commit 3e32872
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 31 deletions.
15 changes: 8 additions & 7 deletions dbt/adapters/bigquery/connections.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,6 @@
import google.auth
from google.auth import impersonated_credentials
import google.auth.exceptions
import google.cloud.bigquery
import google.cloud.exceptions
from google.oauth2 import (
credentials as GoogleCredentials,
service_account as GoogleServiceAccountCredentials,
)
from google.cloud.bigquery import (
Client,
CopyJobConfig,
Expand All @@ -27,6 +21,12 @@
TableReference,
WriteDisposition,
)
import google.cloud.exceptions
from google.oauth2 import (
credentials as GoogleCredentials,
service_account as GoogleServiceAccountCredentials,
)
from requests.exceptions import ConnectionError

from dbt_common.events.contextvars import get_node_info
from dbt_common.events.functions import fire_event
Expand Down Expand Up @@ -622,7 +622,8 @@ def _query_and_results(
client: Client = conn.handle
"""Query the client and wait for results."""
# Cannot reuse job_config if destination is set and ddl is used
job_config = QueryJobConfig(**job_params)
job_factory = QueryJobConfig
job_config = job_factory(**job_params)
query_job = client.query(
query=sql,
job_config=job_config,
Expand Down
4 changes: 2 additions & 2 deletions dbt/adapters/bigquery/retry.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,10 +3,10 @@
from google.api_core import retry
from google.api_core.exceptions import Forbidden
from google.cloud.exceptions import BadGateway, BadRequest, ServerError
from requests.exceptions import ConnectionError

from dbt.adapters.events.logging import AdapterLogger

from dbt.adapters.bigquery.connections import logger
from dbt.adapters.bigquery.credentials import BigQueryCredentials


Expand Down Expand Up @@ -124,7 +124,7 @@ def count_error(self, error):
return False # Don't log
self._error_count += 1
if _is_retryable(error) and self._error_count <= self._retries:
logger.debug(
_logger.debug(
"Retry attempt {} of {} after error: {}".format(
self._error_count, self._retries, repr(error)
)
Expand Down
13 changes: 6 additions & 7 deletions tests/unit/test_bigquery_adapter.py
Original file line number Diff line number Diff line change
Expand Up @@ -386,21 +386,20 @@ def test_cancel_open_connections_single(self):
adapter.connections.thread_connections.update({key: master, 1: model})
self.assertEqual(len(list(adapter.cancel_open_connections())), 1)

@patch("dbt.adapters.bigquery.impl.google.api_core.client_options.ClientOptions")
@patch("dbt.adapters.bigquery.impl.google.auth.default")
@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_location_user_agent(self, mock_bq, mock_auth_default, MockClientOptions):
@patch("dbt.adapters.bigquery.connections.client_options.ClientOptions")
@patch("dbt.adapters.bigquery.credentials.google.auth.default")
@patch("dbt.adapters.bigquery.connections.Client")
def test_location_user_agent(self, MockClient, mock_auth_default, MockClientOptions):
creds = MagicMock()
mock_auth_default.return_value = (creds, MagicMock())
adapter = self.get_adapter("loc")

connection = adapter.acquire_connection("dummy")
mock_client = mock_bq.Client
mock_client_options = MockClientOptions.return_value

mock_client.assert_not_called()
MockClient.assert_not_called()
connection.handle
mock_client.assert_called_once_with(
MockClient.assert_called_once_with(
"dbt-unit-000000",
creds,
location="Luna Station",
Expand Down
45 changes: 30 additions & 15 deletions tests/unit/test_bigquery_connection_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,27 @@
from unittest.mock import patch, MagicMock, Mock, ANY

import dbt.adapters
import google.cloud.bigquery

from dbt.adapters.bigquery import BigQueryCredentials
from dbt.adapters.bigquery import BigQueryRelation
from dbt.adapters.bigquery.connections import BigQueryConnectionManager
from dbt.adapters.bigquery.retry import RetryFactory


class TestBigQueryConnectionManager(unittest.TestCase):
def setUp(self):
credentials = Mock(BigQueryCredentials)
profile = Mock(query_comment=None, credentials=credentials)
self.credentials = Mock(BigQueryCredentials)
self.credentials.job_retries = 1
profile = Mock(query_comment=None, credentials=self.credentials)
self.connections = BigQueryConnectionManager(profile=profile, mp_context=Mock())

self.mock_client = Mock(dbt.adapters.bigquery.impl.google.cloud.bigquery.Client)
self.mock_client = Mock(google.cloud.bigquery.Client)
self.mock_connection = MagicMock()

self.mock_connection.handle = self.mock_client

self.connections.get_thread_connection = lambda: self.mock_connection
self.connections.get_job_retries = lambda x: 1

@patch("dbt.adapters.bigquery.retry._is_retryable", return_value=True)
def test_retry_connection_reset(self, is_retryable):
Expand All @@ -37,8 +39,18 @@ def dummy_handler(msg):

self.connections.exception_handler = dummy_handler

mock_conn = Mock(credentials=Mock(retries=1))
# do something that will raise a ConnectionResetError
retry = RetryFactory(Mock(job_retries=1, job_execution_timeout_seconds=60))
mock_conn = Mock()

on_error = self.connections._reopen_on_error(mock_conn)

@retry.job_execution(on_error)
def generate_connection_reset_error():
raise ConnectionResetError

with self.assertRaises(ConnectionResetError):
# this will always raise the error, we just want to test that the connection was reopening in between
generate_connection_reset_error()
self.connections.close.assert_called_once_with(mock_conn)
self.connections.open.assert_called_once_with(mock_conn)

Expand Down Expand Up @@ -72,20 +84,21 @@ def test_drop_dataset(self):
self.mock_client.delete_table.assert_not_called()
self.mock_client.delete_dataset.assert_called_once()

@patch("dbt.adapters.bigquery.impl.google.cloud.bigquery")
def test_query_and_results(self, mock_bq):
@patch("dbt.adapters.bigquery.connections.QueryJobConfig")
def test_query_and_results(self, MockQueryJobConfig):
self.connections._query_and_results(
self.mock_client,
self.mock_connection,
"sql",
{"job_param_1": "blah"},
{"dry_run": True},
job_id=1,
job_creation_timeout=15,
job_execution_timeout=100,
)

mock_bq.QueryJobConfig.assert_called_once()
MockQueryJobConfig.assert_called_once()
self.mock_client.query.assert_called_once_with(
query="sql", job_config=mock_bq.QueryJobConfig(), job_id=1, timeout=15
query="sql",
job_config=MockQueryJobConfig(),
job_id=1,
timeout=self.credentials.job_creation_timeout_seconds,
)

def test_copy_bq_table_appends(self):
Expand All @@ -95,6 +108,7 @@ def test_copy_bq_table_appends(self):
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
retry=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
Expand All @@ -108,6 +122,7 @@ def test_copy_bq_table_truncates(self):
[self._table_ref("project", "dataset", "table1")],
self._table_ref("project", "dataset", "table2"),
job_config=ANY,
retry=ANY,
)
args, kwargs = self.mock_client.copy_table.call_args
self.assertEqual(
Expand All @@ -129,7 +144,7 @@ def test_list_dataset_correctly_calls_lists_datasets(self):
self.mock_client.list_datasets = mock_list_dataset
result = self.connections.list_dataset("project")
self.mock_client.list_datasets.assert_called_once_with(
project="project", max_results=10000
project="project", max_results=10000, retry=ANY
)
assert result == ["d1"]

Expand Down

0 comments on commit 3e32872

Please sign in to comment.