Skip to content

Commit

Permalink
fix windows tests and also more client tests
Browse files Browse the repository at this point in the history
  • Loading branch information
OliverSherouse committed Jan 22, 2024
1 parent c1d6746 commit 0993710
Show file tree
Hide file tree
Showing 5 changed files with 278 additions and 22 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,9 @@ jobs:
matrix:
python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"]
os: ["ubuntu-latest", "macos-latest", "windows-latest"]
defaults:
run:
shell: bash
steps:
- name: Checkout
uses: actions/checkout@v4
Expand Down
242 changes: 241 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import datetime as dt
import itertools
import re
from unittest import mock

import pandas as pd
import pytest

from wbdata import client
from wbdata import client, fetcher


@pytest.mark.parametrize(
Expand Down Expand Up @@ -220,3 +222,241 @@ def test_get_country(
mock_client.fetcher.fetch.assert_called_once_with(
url=f"{client.COUNTRIES_URL}{path}", args=args, skip_cache=skip_cache
)


@pytest.mark.parametrize(
["kwargs"],
(
[{"country_id": "foo", "incomelevel": "bar"}],
[{"country_id": "foo", "lendingtype": "bar"}],
),
)
def test_get_country_bad(mock_client, kwargs):
with pytest.raises(ValueError, match=r"country_id and aggregates"):
mock_client.get_country(**kwargs)


@pytest.mark.parametrize(
("indicator", "source", "topic", "skip_cache", "expected_url"),
(
("foo", None, None, True, f"{client.INDICATOR_URL}/foo"),
(["foo", "bar"], None, None, False, f"{client.INDICATOR_URL}/foo;bar"),
(None, "foo", None, False, f"{client.SOURCES_URL}/foo/indicators"),
(None, ["foo", "bar"], None, True, f"{client.SOURCES_URL}/foo;bar/indicators"),
(None, None, "foo", False, f"{client.TOPIC_URL}/foo/indicators"),
(None, None, ["foo", "bar"], True, f"{client.TOPIC_URL}/foo;bar/indicators"),
),
)
def test_get_indicator(mock_client, indicator, source, topic, skip_cache, expected_url):
mock_client.fetcher.fetch = mock.Mock(return_value=[["foo"]])
got = mock_client.get_indicator(
indicator=indicator,
source=source,
topic=topic,
skip_cache=skip_cache,
)
assert list(got) == [["foo"]]
mock_client.fetcher.fetch.assert_called_once_with(
url=expected_url,
skip_cache=skip_cache,
)


@pytest.mark.parametrize(
("indicator", "source", "topic"),
(
("foo", "bar", None),
("foo", None, "baz"),
("foo", "bar", "baz"),
(None, "foo", "bar"),
),
)
def test_get_indicator_bad(mock_client, indicator, source, topic):
with pytest.raises(ValueError, match=client.INDIC_ERROR):
mock_client.get_indicator(indicator=indicator, source=source, topic=topic)


@pytest.mark.parametrize(
("raw", "query", "expected"),
(
(
[{"name": "United States"}, {"name": "Great Britain"}],
"states",
[{"name": "United States"}],
),
(
[{"name": "United States"}, {"name": "Great Britain"}],
re.compile("states"),
[],
),
),
)
def test_search_countries(mock_client, raw, query, expected):
with mock.patch.object(mock_client, "get_country") as mock_get_countries:
mock_get_countries.return_value = raw
got = mock_client.search_countries(query)
assert list(got) == expected


@pytest.mark.parametrize(
("raw", "query", "expected"),
(
(
[{"name": "United States"}, {"name": "Great Britain"}],
"states",
[{"name": "United States"}],
),
(
[{"name": "United States"}, {"name": "Great Britain"}],
re.compile("states"),
[],
),
),
)
def test_search_indicators(mock_client, raw, query, expected):
with mock.patch.object(mock_client, "get_indicator") as mock_get_indicators:
mock_get_indicators.return_value = raw
got = mock_client.search_indicators(query)
assert list(got) == expected


def test_get_series_passthrough(mock_client):
with mock.patch.object(mock_client, "get_data") as mock_get_data:
mock_get_data.return_value = fetcher.Result(
[{"country": {"value": "usa"}, "date": "2023", "value": "5"}]
)
kwargs = dict(
indicator="foo",
country="usa",
date="2023",
freq="Q",
source="2",
parse_dates=True,
skip_cache=True,
)
mock_client.get_series(**kwargs)

mock_get_data.assert_called_once_with(**kwargs)


@pytest.mark.parametrize(
["response", "keep_levels", "expected"],
(
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
{"country": {"value": "usa"}, "date": "2024", "value": "6"},
{"country": {"value": "gbr"}, "date": "2023", "value": "7"},
{"country": {"value": "gbr"}, "date": "2024", "value": "8"},
]
),
True,
client.Series(
[5.0, 6.0, 7.0, 8.0],
index=pd.MultiIndex.from_tuples(
tuples=(
("usa", "2023"),
("usa", "2024"),
("gbr", "2023"),
("gbr", "2024"),
),
names=["country", "date"],
),
name="value",
),
id="multi-country, multi-date",
),
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
{"country": {"value": "usa"}, "date": "2024", "value": "6"},
]
),
True,
client.Series(
[5.0, 6.0],
index=pd.MultiIndex.from_tuples(
tuples=(
("usa", "2023"),
("usa", "2024"),
),
names=["country", "date"],
),
name="value",
),
id="one-country, multi-date, keep_levels",
),
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
{"country": {"value": "gbr"}, "date": "2023", "value": "7"},
]
),
True,
client.Series(
[5.0, 7.0],
index=pd.MultiIndex.from_tuples(
tuples=(
("usa", "2023"),
("gbr", "2023"),
),
names=["country", "date"],
),
name="value",
),
id="multi-country, one-date, keep_levels",
),
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
{"country": {"value": "usa"}, "date": "2024", "value": "6"},
]
),
False,
client.Series(
[5.0, 6.0],
index=pd.Index(("2023", "2024"), name="date"),
name="value",
),
id="one-country, multi-date, no keep_levels",
),
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
{"country": {"value": "gbr"}, "date": "2023", "value": "7"},
]
),
False,
client.Series(
[5.0, 7.0],
index=pd.Index(("usa", "gbr"), name="country"),
name="value",
),
id="multi-country, one-date, no keep_levels",
),
pytest.param(
fetcher.Result(
[
{"country": {"value": "usa"}, "date": "2023", "value": "5"},
]
),
False,
client.Series(
[5.0],
index=pd.Index(("2023",), name="date"),
name="value",
),
id="one-country, one-date, no keep_levels",
),
),
)
def test_get_series(mock_client, response, keep_levels, expected):
with mock.patch.object(mock_client, "get_data") as mock_get_data:
mock_get_data.return_value = response
got = mock_client.get_series("foo", keep_levels=keep_levels)
pd.testing.assert_series_equal(got, expected)
11 changes: 7 additions & 4 deletions tests/test_fetcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ def test_skip_cache(mock_fetcher):
[
[{"page": "1", "pages": "1"}, [{"hello": "there"}]],
],
([{"hello": "there"}], None),
fetcher.Result([{"hello": "there"}], last_updated=None),
id="No date",
),
pytest.param(
Expand All @@ -130,7 +130,7 @@ def test_skip_cache(mock_fetcher):
[{"hello": "there"}],
],
],
([{"hello": "there"}], dt.datetime(2023, 2, 1)),
fetcher.Result([{"hello": "there"}], last_updated=dt.datetime(2023, 2, 1)),
id="with date",
),
pytest.param(
Expand All @@ -146,7 +146,10 @@ def test_skip_cache(mock_fetcher):
[{"howare": "you"}],
],
],
([{"hello": "there"}, {"howare": "you"}], dt.datetime(2023, 2, 1)),
fetcher.Result(
[{"hello": "there"}, {"howare": "you"}],
last_updated=dt.datetime(2023, 2, 1),
),
id="paged with date",
),
pytest.param(
Expand All @@ -162,7 +165,7 @@ def test_skip_cache(mock_fetcher):
[{"howare": "you"}],
],
],
([{"hello": "there"}, {"howare": "you"}], None),
fetcher.Result([{"hello": "there"}, {"howare": "you"}], last_updated=None),
id="paged without date",
),
),
Expand Down
27 changes: 14 additions & 13 deletions wbdata/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -285,27 +285,27 @@ def get_indicator(self, indicator=None, source=None, topic=None, skip_cache=Fals
:returns: SearchResult containing dictionary objects representing
indicators
"""
if ((source or topic) and indicator) or (source and topic):
raise ValueError(INDIC_ERROR)
if indicator:
if source or topic:
raise ValueError(INDIC_ERROR)
url = "/".join((INDICATOR_URL, parse_value_or_iterable(indicator)))
elif source:
if topic:
raise ValueError(INDIC_ERROR)
url = "/".join((SOURCES_URL, parse_value_or_iterable(source), "indicators"))
elif topic:
url = "/".join((TOPIC_URL, parse_value_or_iterable(topic), "indicators"))
else:
url = INDICATOR_URL
return SearchResult(self.fetcher.fetch(url, skip_cache=skip_cache))
return SearchResult(self.fetcher.fetch(url=url, skip_cache=skip_cache))

def search_indicators(self, query, source=None, topic=None, skip_cache=False):
def search_indicators(
self, query: Union[str, re.Pattern], source=None, topic=None, skip_cache=False
):
"""
Search indicators for a certain regular expression. Only one of source or
topic can be specified. In interactive mode, will return None and print ids
and names unless suppress_printing is True.
topic can be specified.
:query: the term to match against indicator names
:query: string or pattern object to match. If a string is supplied,
search will be case-insensitive
:source: if present, id of desired source
:topic: if present, id of desired topic
:skip_cache: bypass cache when downloading
Expand All @@ -315,16 +315,17 @@ def search_indicators(self, query, source=None, topic=None, skip_cache=False):
indicators = self.get_indicator(
source=source, topic=topic, skip_cache=skip_cache
)
pattern = re.compile(query, re.IGNORECASE)
pattern = re.compile(query, re.IGNORECASE) if isinstance(query, str) else query
return SearchResult(i for i in indicators if pattern.search(i["name"]))

def search_countries(
self, query, incomelevel=None, lendingtype=None, skip_cache=False
):
"""
Search countries by name. Very simple search.
Search country names using a regular expression.
:query: the string to match against country names
:query: string or pattern object to match. If a string is supplied,
search will be case-insensitive
:incomelevel: if present, search only the matching incomelevel
:lendingtype: if present, search only the matching lendingtype
:skip_cache: bypass cache when downloading
Expand All @@ -334,7 +335,7 @@ def search_countries(
countries = self.get_country(
incomelevel=incomelevel, lendingtype=lendingtype, skip_cache=skip_cache
)
pattern = re.compile(query, re.IGNORECASE)
pattern = re.compile(query, re.IGNORECASE) if isinstance(query, str) else query
return SearchResult(i for i in countries if pattern.search(i["name"]))

@needs_pandas
Expand Down
Loading

0 comments on commit 0993710

Please sign in to comment.