Skip to content

Commit

Permalink
Improve and fix tests (#210)
Browse files Browse the repository at this point in the history
Fixes a few failing tests, which were not working properly after other refactoring. Also improves code formatting.
  • Loading branch information
jm-rivera authored Jan 22, 2025
1 parent 57ea77c commit 91bfe23
Show file tree
Hide file tree
Showing 5 changed files with 141 additions and 68 deletions.
157 changes: 115 additions & 42 deletions datacommons_client/tests/endpoints/test_base.py
Original file line number Diff line number Diff line change
@@ -1,47 +1,64 @@
from unittest.mock import patch

import pytest

from datacommons_client.endpoints.base import API
from datacommons_client.endpoints.base import Endpoint
import pytest


@patch("datacommons_client.endpoints.base.build_headers")
@patch("datacommons_client.endpoints.base.resolve_instance_url")
@patch(
"datacommons_client.endpoints.base.build_headers",
return_value={"Content-Type": "application/json"},
)
@patch(
"datacommons_client.endpoints.base.resolve_instance_url",
return_value="https://api.datacommons.org/v2",
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://api.datacommons.org/v2",
)
def test_api_initialization_default(
mock_resolve_instance_url, mock_build_headers
mock_check_instance, mock_resolve_instance, mock_build_headers
):
"""Tests default API initialization with `datacommons.org` instance."""
mock_resolve_instance_url.return_value = "https://api.datacommons.org/v2"
mock_build_headers.return_value = {"Content-Type": "application/json"}

api = API()

assert api.base_url == "https://api.datacommons.org/v2"
assert api.headers == {"Content-Type": "application/json"}
mock_resolve_instance_url.assert_called_once_with("datacommons.org")
mock_resolve_instance.assert_called_once_with("datacommons.org")
mock_build_headers.assert_called_once_with(None)


@patch("datacommons_client.endpoints.base.build_headers")
def test_api_initialization_with_url(mock_build_headers):
@patch(
"datacommons_client.endpoints.base.build_headers",
return_value={"Content-Type": "application/json"},
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom_instance.api/v2",
)
def test_api_initialization_with_url(mock_check_instance, mock_build_headers):
"""Tests API initialization with a fully qualified URL."""
mock_build_headers.return_value = {"Content-Type": "application/json"}

api = API(url="https://custom_instance.api/v2")
assert api.base_url == "https://custom_instance.api/v2"
assert api.headers == {"Content-Type": "application/json"}
mock_check_instance.assert_called_once_with(
"https://custom_instance.api/v2"
)


@patch("datacommons_client.endpoints.base.build_headers")
@patch("datacommons_client.endpoints.base.resolve_instance_url")
@patch(
"datacommons_client.endpoints.base.resolve_instance_url",
return_value="https://custom-instance/api/v2",
)
@patch(
"datacommons_client.endpoints.base.build_headers",
return_value={"Content-Type": "application/json"},
)
def test_api_initialization_with_dc_instance(
mock_resolve_instance_url, mock_build_headers
mock_build_headers, mock_resolve_instance_url
):
"""Tests API initialization with a custom Data Commons instance."""
mock_resolve_instance_url.return_value = "https://custom-instance/api/v2"
mock_build_headers.return_value = {"Content-Type": "application/json"}

api = API(dc_instance="custom-instance")

assert api.base_url == "https://custom-instance/api/v2"
Expand All @@ -55,22 +72,16 @@ def test_api_initialization_invalid_args():
API(dc_instance="custom-instance", url="https://custom.api/v2")


def test_api_repr():
"""Tests the string representation of the API object."""
api = API(url="https://custom_instance.api/v2", api_key="test-key")
assert (
repr(api) == "<API at https://custom_instance.api/v2 (Authenticated)>"
)

api = API(url="https://custom_instance.api/v2")
assert repr(api) == "<API at https://custom_instance.api/v2>"


@patch("datacommons_client.endpoints.base.post_request")
def test_api_post_request(mock_post_request):
@patch(
"datacommons_client.endpoints.base.post_request",
return_value={"success": True},
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom_instance.api/v2",
)
def test_api_post_request(mock_check_instance, mock_post_request):
"""Tests making a POST request using the API object."""
mock_post_request.return_value = {"success": True}

api = API(url="https://custom_instance.api/v2")
payload = {"key": "value"}

Expand All @@ -83,15 +94,23 @@ def test_api_post_request(mock_post_request):
)


def test_api_post_request_invalid_payload():
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom_instance.api/v2",
)
def test_api_post_request_invalid_payload(mock_check_instance):
"""Tests that an invalid payload raises a ValueError."""
api = API(url="https://custom_instance.api/v2")

with pytest.raises(ValueError):
api.post(payload=["invalid", "payload"], endpoint="test-endpoint")


def test_endpoint_initialization():
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom_instance.api/v2",
)
def test_endpoint_initialization(mock_check_instance):
"""Tests initializing an Endpoint with a valid API instance."""
api = API(url="https://custom_instance.api/v2")
endpoint = Endpoint(endpoint="node", api=api)
Expand All @@ -100,7 +119,11 @@ def test_endpoint_initialization():
assert endpoint.api is api


def test_endpoint_repr():
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom.api/v2",
)
def test_endpoint_repr(mock_check_instance):
"""Tests the string representation of the Endpoint object."""
api = API(url="https://custom.api/v2")
endpoint = Endpoint(endpoint="node", api=api)
Expand All @@ -110,11 +133,16 @@ def test_endpoint_repr():
)


@patch("datacommons_client.endpoints.base.post_request")
def test_endpoint_post_request(mock_post_request):
@patch(
"datacommons_client.endpoints.base.post_request",
return_value={"success": True},
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom.api/v2",
)
def test_endpoint_post_request(mock_check_instance, mock_post_request):
"""Tests making a POST request using the Endpoint object."""
mock_post_request.return_value = {"success": True}

api = API(url="https://custom.api/v2")
endpoint = Endpoint(endpoint="node", api=api)
payload = {"key": "value"}
Expand All @@ -128,10 +156,55 @@ def test_endpoint_post_request(mock_post_request):
)


def test_endpoint_post_request_invalid_payload():
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom.api/v2",
)
def test_endpoint_post_request_invalid_payload(mock_check_instance):
"""Tests that an invalid payload raises a ValueError in the Endpoint post method."""
api = API(url="https://custom.api/v2")
endpoint = Endpoint(endpoint="node", api=api)

with pytest.raises(ValueError):
endpoint.post(payload=["invalid", "payload"])


@patch(
"datacommons_client.endpoints.base.build_headers",
side_effect=lambda api_key: {"X-API-Key": api_key} if api_key else {},
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
side_effect=lambda url: url.rstrip("/"),
)
def test_api_repr(mock_check_instance, mock_build_headers):
"""Tests the __repr__ method of the API class."""
# Without API key
api = API(url="https://custom.api/v2")
assert repr(api) == "<API at https://custom.api/v2>"

# With API key
api_with_key = API(url="https://custom.api/v2", api_key="test_key")
assert (
repr(api_with_key) == "<API at https://custom.api/v2 (Authenticated)>"
)

mock_build_headers.assert_any_call(None)
mock_build_headers.assert_any_call("test_key")


@patch(
"datacommons_client.endpoints.base.build_headers",
return_value={"Content-Type": "application/json"},
)
@patch(
"datacommons_client.endpoints.base.check_instance_is_valid",
return_value="https://custom.api/v2",
)
def test_endpoint_repr(mock_check_instance, mock_build_headers):
"""Tests the __repr__ method of the Endpoint class."""
api = API(url="https://custom.api/v2")
endpoint = Endpoint(endpoint="node", api=api)

expected_repr = "<Node Endpoint using <API at https://custom.api/v2>>"
assert repr(endpoint) == expected_repr
13 changes: 6 additions & 7 deletions datacommons_client/tests/endpoints/test_error_handling.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,12 @@
from datacommons_client.utils.error_handling import APIError
from datacommons_client.utils.error_handling import DataCommonsError
from datacommons_client.utils.error_handling import DCAuthenticationError
from datacommons_client.utils.error_handling import DCConnectionError
from datacommons_client.utils.error_handling import DCStatusError
from datacommons_client.utils.error_handling import InvalidDCInstanceError
from requests import Request
from requests import Response

from datacommons_client.utils.error_hanlding import APIError
from datacommons_client.utils.error_hanlding import DataCommonsError
from datacommons_client.utils.error_hanlding import DCAuthenticationError
from datacommons_client.utils.error_hanlding import DCConnectionError
from datacommons_client.utils.error_hanlding import DCStatusError
from datacommons_client.utils.error_hanlding import InvalidDCInstanceError


def test_data_commons_error_default_message():
"""Tests that DataCommonsError uses the default message."""
Expand Down
4 changes: 3 additions & 1 deletion datacommons_client/tests/endpoints/test_payloads.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,9 @@ def test_node_payload_normalize():
payload = NodeRequestPayload(node_dcids="node1", expression="prop1")
assert payload.node_dcids == ["node1"]

payload = NodeRequestPayload(node_dcids=["node1", "node2"], expression="prop1")
payload = NodeRequestPayload(
node_dcids=["node1", "node2"], expression="prop1"
)
assert payload.node_dcids == ["node1", "node2"]


Expand Down
19 changes: 9 additions & 10 deletions datacommons_client/tests/endpoints/test_request_handling.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,21 @@
from unittest.mock import MagicMock
from unittest.mock import patch

import pytest
import requests

from datacommons_client.utils.error_hanlding import APIError
from datacommons_client.utils.error_hanlding import DCAuthenticationError
from datacommons_client.utils.error_hanlding import DCConnectionError
from datacommons_client.utils.error_hanlding import DCStatusError
from datacommons_client.utils.error_hanlding import InvalidDCInstanceError
from datacommons_client.utils.request_handling import check_instance_is_valid
from datacommons_client.utils.error_handling import APIError
from datacommons_client.utils.error_handling import DCAuthenticationError
from datacommons_client.utils.error_handling import DCConnectionError
from datacommons_client.utils.error_handling import DCStatusError
from datacommons_client.utils.error_handling import InvalidDCInstanceError
from datacommons_client.utils.request_handling import _fetch_with_pagination
from datacommons_client.utils.request_handling import _merge_values
from datacommons_client.utils.request_handling import _recursively_merge_dicts
from datacommons_client.utils.request_handling import _send_post_request
from datacommons_client.utils.request_handling import build_headers
from datacommons_client.utils.request_handling import check_instance_is_valid
from datacommons_client.utils.request_handling import post_request
from datacommons_client.utils.request_handling import resolve_instance_url
import pytest
import requests


def test_resolve_instance_url_default():
Expand Down Expand Up @@ -47,7 +46,7 @@ def test_send_post_request_connection_error(mock_post):
_send_post_request("https://api.test.com", {}, {})


@patch("datacommons_client.utils.request_handling._check_instance_is_valid")
@patch("datacommons_client.utils.request_handling.check_instance_is_valid")
def test_resolve_instance_url_custom(mock_check_instance_is_valid):
"""Tests resolving a custom Data Commons instance."""
mock_check_instance_is_valid.return_value = (
Expand Down
16 changes: 8 additions & 8 deletions datacommons_client/tests/endpoints/test_response.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,15 @@
from datacommons_client.endpoints.response import _unpack_arcs
from datacommons_client.endpoints.response import DCResponse
from datacommons_client.endpoints.response import extract_observations
from datacommons_client.endpoints.response import flatten_properties
from datacommons_client.endpoints.response import NodeResponse
from datacommons_client.endpoints.response import ObservationResponse
from datacommons_client.endpoints.response import ResolveResponse
from datacommons_client.endpoints.response import SparqlResponse
from datacommons_client.models.observation import Facet
from datacommons_client.models.observation import Observation
from datacommons_client.models.observation import OrderedFacets
from datacommons_client.models.observation import Variable
from datacommons_client.utils.response import _unpack_arcs
from datacommons_client.utils.response import DCResponse
from datacommons_client.utils.response import extract_observations
from datacommons_client.utils.response import flatten_properties
from datacommons_client.utils.response import NodeResponse
from datacommons_client.utils.response import ObservationResponse
from datacommons_client.utils.response import ResolveResponse
from datacommons_client.utils.response import SparqlResponse

### ----- Test DCResponse ----- ###

Expand Down

0 comments on commit 91bfe23

Please sign in to comment.