From 49b406ef52e71835f3ec69b23aa43aabd18e415a Mon Sep 17 00:00:00 2001 From: Oliver Sherouse Date: Sun, 7 Jan 2024 23:43:11 -0500 Subject: [PATCH] working on it --- .github/workflows/publish.yaml | 7 +- .github/workflows/tests.yaml | 42 +- .tool-versions | 2 + pyproject.toml | 64 +-- setup.cfg | 7 - tests/test_api.py | 683 --------------------------------- tests/test_fetcher.py | 33 +- wbdata/__init__.py | 37 +- wbdata/api.py | 505 ------------------------ wbdata/cache.py | 50 +++ wbdata/client.py | 451 ++++++++++++++++++++++ wbdata/dates.py | 94 +++++ wbdata/fetcher.py | 245 ++++++------ wbdata/py.typed | 0 wbdata/types.py | 17 + 15 files changed, 860 insertions(+), 1377 deletions(-) create mode 100644 .tool-versions delete mode 100644 setup.cfg delete mode 100644 tests/test_api.py delete mode 100644 wbdata/api.py create mode 100644 wbdata/cache.py create mode 100644 wbdata/client.py create mode 100644 wbdata/dates.py create mode 100644 wbdata/py.typed create mode 100644 wbdata/types.py diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index afad197..ff53600 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -11,11 +11,12 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v1.1.1 + uses: actions/setup-python@v5 + - name: Checkout - name: Install Poetry - uses: dschep/install-poetry-action@v1.2 + uses: snok/install-poetry@v1.3.4 - name: Publish env: POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi_token }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 48a98d7..e690eee 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -5,23 +5,53 @@ on: schedule: - cron: "1 1 1 * *" jobs: + lint: + name: Lint + runs-on: "ubuntu-latest" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Install Ruff + run: pip install ruff + - name: Run Tests + run: ruff + types: + name: Types + runs-on: "ubuntu-latest" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Install Poetry Action + uses: snok/install-poetry@v1.3.4 + - name: Install Dependencies + run: poetry install + - name: Run mypy + run: poetry run pytest test: name: Tests runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.6", "3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] os: ["ubuntu-latest", "macos-latest", "windows-latest"] steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v1.1.1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install Poetry - uses: dschep/install-poetry-action@v1.2 + - name: Install Poetry Action + uses: snok/install-poetry@v1.3.4 - name: Install Dependencies - run: poetry install -E pandas + run: poetry install - name: Run Tests run: poetry run pytest diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..5b11157 --- /dev/null +++ b/.tool-versions @@ -0,0 +1,2 @@ +poetry 1.7.1 +python 3.8.13 diff --git a/pyproject.toml b/pyproject.toml index 3a74b47..d866d6a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "wbdata" -version = "0.3.0.post" +version = "1.0.0.dev" description = "A library to access World Bank data" authors = ["Oliver Sherouse "] license = "GPL-2.0+" @@ -11,36 +11,56 @@ classifiers = [ "Operating System :: OS Independent", "Topic :: Scientific/Engineering", ] - repository = "https://github.com/OliverSherouse/wbdata" documentation = "https://wbdata.readthedocs.io/" keywords = ["World Bank", "data", "economics"] [tool.poetry.dependencies] -python = ">=3.6" -decorator = ">=4.0" -requests = ">=2.0" -tabulate = ">=0.8.5" -appdirs = ">=1.4" +python = "^3.8" +requests = "^2.0" +tabulate = "^0.8.5" +appdirs = "^1.4" -pandas = {version = ">=0.17", optional=true} -sphinx = {version = "^3.0.3", optional=true} -recommonmark = {version = "^0.6.0", optional=true} -ipython = {version = "^7.16.1", optional=true} +pandas = {version = ">=1,<3", optional=true} +cachetools = "^5.3.2" +shelved-cache = "^0.3.1" +backoff = "^2.2.1" +types-cachetools = "^5.3.0.7" +dateparser = "^1.2.0" +decorator = "^5.1.1" [tool.poetry.extras] pandas = ["pandas"] -docs = ["sphinx", "recommonmark", "ipython"] -[tool.poetry.dev-dependencies] -pytest-flake8 = "^=1.0.6" -ipython = "^=7.16.1" -flake8-bugbear = "^20.1.4" -sphinx = "^3.0.3" -recommonmark = "^0.6.0" -flake8 = "^3.8.3" -pytest = "^5.4.3" +[tool.poetry.group.dev.dependencies] +ruff = "^0.1.11" +pytest = "^7.4.4" +mypy = "^1.8.0" +types-tabulate = "^0.9.0.20240106" +types-decorator = "^5.1.8.20240106" +types-appdirs = "^1.4.3.5" +types-requests = "^2.31.0.20240106" +ipython = "<8" +types-dateparser = "^1.1.4.20240106" +mkdocs = "^1.5.3" +mkdocstrings = "^0.24.0" + +[tool.ruff.lint] +select = [ + # pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", +] [build-system] -requires = ["poetry>=0.12"] -build-backend = "poetry.masonry.api" +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 34d7ae2..0000000 --- a/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[tool:pytest] -addopts = --flake8 - -[flake8] -max-line-length=80 -select = C,E,F,W,B,B9 -ignore = E501,W391,W503 diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index d8e31cf..0000000 --- a/tests/test_api.py +++ /dev/null @@ -1,683 +0,0 @@ -#!/usr/bin/env python3 - -import collections -import datetime as dt -import itertools - -import pytest -import wbdata as wbd - - -SimpleCallDefinition = collections.namedtuple( - "SimpleCallDefinition", ["function", "valid_id", "value"] -) - -SimpleCallSpec = collections.namedtuple( - "SimpleCallSpec", ["function", "result_all", "result_one", "id", "value"] -) - -SIMPLE_CALL_DEFINITIONS = [ - SimpleCallDefinition( - function=wbd.get_country, - valid_id="USA", # USA! USA! - value={ - "id": "USA", - "iso2Code": "US", - "name": "United States", - "region": {"id": "NAC", "iso2code": "XU", "value": "North America"}, - "adminregion": {"id": "", "iso2code": "", "value": ""}, - "incomeLevel": {"id": "HIC", "iso2code": "XD", "value": "High income"}, - "lendingType": {"id": "LNX", "iso2code": "XX", "value": "Not classified"}, - "capitalCity": "Washington D.C.", - "longitude": "-77.032", - "latitude": "38.8895", - }, - ), - SimpleCallDefinition( - function=wbd.get_incomelevel, - valid_id="HIC", - value={"id": "HIC", "iso2code": "XD", "value": "High income"}, - ), - SimpleCallDefinition( - function=wbd.get_lendingtype, - valid_id="IBD", - value={"id": "IBD", "iso2code": "XF", "value": "IBRD"}, - ), - SimpleCallDefinition( - function=wbd.get_source, - valid_id="2", - value={ - "id": "2", - "name": "World Development Indicators", - "code": "WDI", - "description": "", - "url": "", - "dataavailability": "Y", - "metadataavailability": "Y", - "concepts": "3", - }, - ), - SimpleCallDefinition( - function=wbd.get_topic, - valid_id="3", - value={ - "id": "3", - "value": "Economy & Growth", - "sourceNote": ( - "Economic growth is central to economic development. When " - "national income grows, real people benefit. While there is " - "no known formula for stimulating economic growth, data can " - "help policy-makers better understand their countries' " - "economic situations and guide any work toward improvement. " - "Data here covers measures of economic growth, such as gross " - "domestic product (GDP) and gross national income (GNI). It " - "also includes indicators representing factors known to be " - "relevant to economic growth, such as capital stock, " - "employment, investment, savings, consumption, government " - "spending, imports, and exports." - ), - }, - ), - SimpleCallDefinition( - function=wbd.get_indicator, - valid_id="SP.POP.TOTL", - value={ - "id": "SP.POP.TOTL", - "name": "Population, total", - "unit": "", - "source": {"id": "2", "value": "World Development Indicators"}, - "sourceNote": ( - "Total population is based on the de facto definition of " - "population, which counts all residents regardless of legal " - "status or citizenship. The values shown are midyear " - "estimates." - ), - "sourceOrganization": ( - "(1) United Nations Population Division. World Population " - "Prospects: 2019 Revision. (2) Census reports and other " - "statistical publications from national statistical offices, " - "(3) Eurostat: Demographic Statistics, (4) United Nations " - "Statistical Division. Population and Vital Statistics " - "Reprot (various years), (5) U.S. Census Bureau: " - "International Database, and (6) Secretariat of the Pacific " - "Community: Statistics and Demography Programme." - ), - "topics": [ - {"id": "19", "value": "Climate Change"}, - {"id": "8", "value": "Health "}, - ], - }, - ), -] - - -@pytest.fixture(params=SIMPLE_CALL_DEFINITIONS, scope="class") -def simple_call_spec(request): - return SimpleCallSpec( - function=request.param.function, - result_all=request.param.function(), - result_one=request.param.function(request.param.valid_id), - id=request.param.valid_id, - value=request.param.value, - ) - - -class TestSimpleQueries: - """ - Test that results of simple queries are close to what we expect - """ - - def test_simple_all_type(self, simple_call_spec): - assert isinstance(simple_call_spec.result_all, wbd.api.WBSearchResult) - - def test_simple_all_len(self, simple_call_spec): - assert len(simple_call_spec.result_all) > 1 - - def test_simple_all_content(self, simple_call_spec): - expected = [] - for val in simple_call_spec.result_all: - try: - del val["lastupdated"] - except KeyError: - pass - expected.append(val) - assert simple_call_spec.value in expected - - def test_simple_one_type(self, simple_call_spec): - assert isinstance(simple_call_spec.result_one, wbd.api.WBSearchResult) - - def test_simple_one_len(self, simple_call_spec): - assert len(simple_call_spec.result_one) == 1 - - def test_simple_one_content(self, simple_call_spec): - got = simple_call_spec.result_one[0] - try: - del got["lastupdated"] - except KeyError: - pass - assert simple_call_spec.result_one[0] == simple_call_spec.value - - def test_simple_bad_call(self, simple_call_spec): - with pytest.raises(RuntimeError): - simple_call_spec.function("Ain'tNotAThing") - - -class TestGetIndicator: - """Extra tests for Get Indicator""" - - def testGetIndicatorBySource(self): - indicators = wbd.get_indicator(source=1) - assert all(i["source"]["id"] == "1" for i in indicators) - - def testGetIndicatorByTopic(self): - indicators = wbd.get_indicator(topic=1) - assert all(any(t["id"] == "1" for t in i["topics"]) for i in indicators) - - def testGetIndicatorBySourceAndTopicFails(self): - with pytest.raises(ValueError): - wbd.get_indicator(source="1", topic=1) - - -SearchDefinition = collections.namedtuple( - "SearchDefinition", - ["function", "query", "value", "facets", "facet_matches", "facet_mismatches"], -) - -SearchData = collections.namedtuple( - "SearchData", - [ - "function", - "query", - "value", - "facets", - "results", - "results_facet_matches", - "results_facet_mismatches", - ], -) - -search_definitions = [ - SearchDefinition( - function=wbd.search_countries, - query="United", - value={ - "id": "USA", - "iso2Code": "US", - "name": "United States", - "region": {"id": "NAC", "iso2code": "XU", "value": "North America"}, - "adminregion": {"id": "", "iso2code": "", "value": ""}, - "incomeLevel": {"id": "HIC", "iso2code": "XD", "value": "High income"}, - "lendingType": {"id": "LNX", "iso2code": "XX", "value": "Not classified"}, - "capitalCity": "Washington D.C.", - "longitude": "-77.032", - "latitude": "38.8895", - }, - facets=["incomelevel", "lendingtype"], - facet_matches=["HIC", "LNX"], - facet_mismatches=["LIC", "IDX"], - ) -] - - -GetDataSpec = collections.namedtuple( - "GetDataSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "expected_country", - "expected_date", - "expected_value", - ], -) - - -@pytest.fixture(params=search_definitions, scope="class") -def search_data(request): - return SearchData( - function=request.param.function, - query=request.param.query, - value=request.param.value, - facets=request.param.facets, - results=request.param.function(request.param.query), - results_facet_matches=[ - request.param.function(request.param.query, **{facet: value}) - for facet, value in zip(request.param.facets, request.param.facet_matches) - ], - results_facet_mismatches=[ - request.param.function(request.param.query, **{facet: value}) - for facet, value in zip( - request.param.facets, request.param.facet_mismatches - ) - ], - ) - - -class TestSearchFunctions: - def test_search_return_type(self, search_data): - assert isinstance(search_data.results, wbd.api.WBSearchResult) - - def test_facet_return_type(self, search_data): - for results in ( - search_data.results_facet_matches + search_data.results_facet_mismatches - ): - assert isinstance(results, wbd.api.WBSearchResult) - - def test_plain_search(self, search_data): - assert search_data.value in search_data.results - - def test_matched_faceted_searches(self, search_data): - for results in search_data.results_facet_matches: - assert search_data.value in results - - def test_mismatched_faceted_searches(self, search_data): - for results in search_data.results_facet_mismatches: - assert search_data.value not in results - - -COUNTRY_NAMES = { - "ERI": "Eritrea", - "GNQ": "Equatorial Guinea", -} - -common_data_facets = [ - ["all", "ERI", ["ERI", "GNQ"]], - [ - None, - dt.datetime(2010, 1, 1), - [dt.datetime(2010, 1, 1), dt.datetime(2011, 1, 1)], - ], - [None, "2", "11"], - [False, True], -] -get_data_defs = itertools.product(*common_data_facets) - - -@pytest.fixture(params=get_data_defs, scope="class") -def get_data_spec(request): - country, data_date, source, convert_date = request.param - return GetDataSpec( - result=wbd.get_data( - "NY.GDP.MKTP.CD", - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - ) - - -class TestGetData: - def test_result_type(self, get_data_spec): - assert isinstance(get_data_spec.result, wbd.fetcher.WBResults) - - def test_country(self, get_data_spec): - if get_data_spec.country == "all": - return - expected = ( - {get_data_spec.country} - if isinstance(get_data_spec.country, str) - else set(get_data_spec.country) - ) - # This is a little complicated because the API returns the iso3 id - # in different places from different sources (which is insane) - got = set( - [ - i["countryiso3code"] if i["countryiso3code"] else i["country"]["id"] - for i in get_data_spec.result - ] - ) - try: - assert got == expected - except AssertionError: - raise - - # Tests both string and converted dates - def test_data_date(self, get_data_spec): - if get_data_spec.data_date is None: - return - expected = ( - set(get_data_spec.data_date) - if isinstance(get_data_spec.data_date, collections.Sequence) - else {get_data_spec.data_date} - ) - if not get_data_spec.convert_date: - expected = {date.strftime("%Y") for date in expected} - - got = {i["date"] for i in get_data_spec.result} - assert got == expected - - # Tests source and correct value - def test_content(self, get_data_spec): - got = next( - datum["value"] - for datum in get_data_spec.result - if datum["country"]["value"] == get_data_spec.expected_country - and datum["date"] == get_data_spec.expected_date - ) - assert got == get_data_spec.expected_value - - def testLastUpdated(self, get_data_spec): - assert isinstance(get_data_spec.result.last_updated, dt.datetime) - - def test_monthly_freq(self): - got = wbd.get_data( - "DPANUSSPB", country="bra", data_date=dt.datetime(2012, 1, 1), freq="M" - )[0]["value"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_data( - "DP.DOD.DECD.CR.BC.CD", - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )[0]["value"] - assert got == 31049138725.7794 - - -series_data_facets = tuple( - itertools.product(*(common_data_facets + [["value", "other"], [False, True]])) -) - - -GetSeriesSpec = collections.namedtuple( - "GetSeriesSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "column_name", - "keep_levels", - "expected_country", - "expected_date", - "expected_value", - "country_in_index", - "date_in_index", - ], -) - - -@pytest.fixture(params=series_data_facets, scope="class") -def get_series_spec(request): - ( - country, - data_date, - source, - convert_date, - column_name, - keep_levels, - ) = request.param - return GetSeriesSpec( - result=wbd.get_series( - "NY.GDP.MKTP.CD", - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - country_in_index=( - country == "all" or not isinstance(country, str) or keep_levels - ), - date_in_index=(not isinstance(data_date, dt.datetime) or keep_levels), - ) - - -class TestGetSeries: - def test_index_labels(self, get_series_spec): - index = get_series_spec.result.index - if get_series_spec.country_in_index: - if get_series_spec.date_in_index: - assert index.names == ["country", "date"] - else: - assert index.name == "country" - else: - assert index.name == "date" - - def test_country(self, get_series_spec): - if not get_series_spec.country_in_index: - return - got = sorted(get_series_spec.result.index.unique(level="country")) - - if get_series_spec.country == "all": - assert len(got) > 2 - elif isinstance(get_series_spec.country, str): - assert len(got) == 1 - assert got[0] == COUNTRY_NAMES[get_series_spec.country] - else: - assert got == sorted( - COUNTRY_NAMES[country] for country in get_series_spec.country - ) - - def test_date(self, get_series_spec): - if not get_series_spec.date_in_index: - return - got = sorted(get_series_spec.result.index.unique(level="date")) - if get_series_spec.data_date is None: - assert len(got) > 2 - elif isinstance(get_series_spec.data_date, collections.Sequence): - assert got == sorted( - date if get_series_spec.convert_date else date.strftime("%Y") - for date in get_series_spec.data_date - ) - else: - assert len(got) == 1 - assert got[0] == ( - get_series_spec.data_date - if get_series_spec.convert_date - else get_series_spec.data_date.strftime("%Y") - ) - - def test_column_name(self, get_series_spec): - assert get_series_spec.result.name == get_series_spec.column_name - - def test_value(self, get_series_spec): - if get_series_spec.country_in_index: - if get_series_spec.date_in_index: - index_loc = ( - get_series_spec.expected_country, - get_series_spec.expected_date, - ) - else: - index_loc = get_series_spec.expected_country - else: - index_loc = get_series_spec.expected_date - - assert get_series_spec.result[index_loc] == get_series_spec.expected_value - - def test_last_updated(self, get_series_spec): - assert isinstance(get_series_spec.result.last_updated, dt.datetime) - - def test_bad_value(self): - with pytest.raises(RuntimeError): - wbd.get_series("AintNotAThing") - - def test_monthly_freq(self): - got = wbd.get_series( - "DPANUSSPB", country="bra", data_date=dt.datetime(2012, 1, 1), freq="M" - )["2012M01"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_series( - "DP.DOD.DECD.CR.BC.CD", - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )["2013Q1"] - assert got == 31049138725.7794 - - -GetDataFrameSpec = collections.namedtuple( - "GetDataFrameSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "column_names", - "keep_levels", - "expected_country", - "expected_date", - "expected_column", - "expected_value", - "country_in_index", - "date_in_index", - ], -) - - -@pytest.fixture(params=series_data_facets, scope="class") -def get_dataframe_spec(request): - ( - country, - data_date, - source, - convert_date, - column_name, - keep_levels, - ) = request.param - return GetDataFrameSpec( - result=wbd.get_dataframe( - {"NY.GDP.MKTP.CD": column_name, "NY.GDP.MKTP.PP.CD": "ppp"}, - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - keep_levels=keep_levels, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_names=[column_name, "ppp"], - keep_levels=keep_levels, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_column=column_name, - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - country_in_index=( - country == "all" or not isinstance(country, str) or keep_levels - ), - date_in_index=(not isinstance(data_date, dt.datetime) or keep_levels), - ) - - -class TestGetDataFrame: - def test_index_labels(self, get_dataframe_spec): - index = get_dataframe_spec.result.index - if get_dataframe_spec.country_in_index: - if get_dataframe_spec.date_in_index: - assert index.names == ["country", "date"] - else: - assert index.name == "country" - else: - assert index.name == "date" - - def test_country(self, get_dataframe_spec): - if not get_dataframe_spec.country_in_index: - return - got = sorted(get_dataframe_spec.result.index.unique(level="country")) - - if get_dataframe_spec.country == "all": - assert len(got) > 2 - elif isinstance(get_dataframe_spec.country, str): - assert len(got) == 1 - assert got[0] == COUNTRY_NAMES[get_dataframe_spec.country] - else: - assert got == sorted( - COUNTRY_NAMES[country] for country in get_dataframe_spec.country - ) - - def test_date(self, get_dataframe_spec): - if not get_dataframe_spec.date_in_index: - return - got = sorted(get_dataframe_spec.result.index.unique(level="date")) - if get_dataframe_spec.data_date is None: - assert len(got) > 2 - elif isinstance(get_dataframe_spec.data_date, collections.Sequence): - assert got == sorted( - date if get_dataframe_spec.convert_date else date.strftime("%Y") - for date in get_dataframe_spec.data_date - ) - else: - assert len(got) == 1 - assert got[0] == ( - get_dataframe_spec.data_date - if get_dataframe_spec.convert_date - else get_dataframe_spec.data_date.strftime("%Y") - ) - - def test_column_name(self, get_dataframe_spec): - assert ( - get_dataframe_spec.result.columns.tolist() - == get_dataframe_spec.column_names - ) - - def test_value(self, get_dataframe_spec): - if get_dataframe_spec.country_in_index: - if get_dataframe_spec.date_in_index: - index_loc = ( - get_dataframe_spec.expected_country, - get_dataframe_spec.expected_date, - ) - else: - index_loc = get_dataframe_spec.expected_country - else: - index_loc = get_dataframe_spec.expected_date - - assert ( - get_dataframe_spec.result[get_dataframe_spec.expected_column][index_loc] - == get_dataframe_spec.expected_value - ) - - def test_last_updated(self, get_dataframe_spec): - assert all( - isinstance(value, dt.datetime) - for value in get_dataframe_spec.result.last_updated.values() - ) - - def test_bad_value(self): - with pytest.raises(RuntimeError): - wbd.get_dataframe({"AintNotAThing": "bad value"}) - - def test_monthly_freq(self): - got = wbd.get_dataframe( - {"DPANUSSPB": "var"}, - country="bra", - data_date=dt.datetime(2012, 1, 1), - freq="M", - )["var"]["2012M01"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_dataframe( - {"DP.DOD.DECD.CR.BC.CD": "var"}, - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )["var"]["2013Q1"] - assert got == 31049138725.7794 diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index 538d7bb..cc4edc0 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -1,15 +1,28 @@ +import json +from unittest import mock + import pytest -import wbdata.fetcher -import wbdata.api +from wbdata import fetcher + + +@pytest.fixture +def mocked_fetcher() -> fetcher.Fetcher: + return fetcher.Fetcher(cache={}, session=mock.Mock()) + + +class MockHTTPResponse: + def __init__(self, value): + self.text = json.dumps(value) -def test_bad_indicator_error(): - expected = ( - r"Got error 120 \(Invalid value\): The provided parameter value is " - r"not valid" +def test_get_request_content(mocked_fetcher): + url = "http://foo.bar" + params = {"baz": "bat"} + expected = {"hello": "there"} + mocked_fetcher.session.get = mock.Mock( + return_value=MockHTTPResponse(value=expected) ) - with pytest.raises(RuntimeError, match=expected): - wbdata.fetcher.fetch( - wbdata.api.COUNTRIES_URL + "/all/AINT.NOT.A.THING" - ) + result = mocked_fetcher._get_response_body(url=url, params=params) + assert mocked_fetcher.session.get.called_once_with(url=url, params=params) + assert json.loads(result) == expected diff --git a/wbdata/__init__.py b/wbdata/__init__.py index e93b56d..e5ab208 100644 --- a/wbdata/__init__.py +++ b/wbdata/__init__.py @@ -3,16 +3,27 @@ """ __version__ = "0.3.0.post" -from .api import ( # noqa: F401 - get_country, - get_data, - get_series, - get_dataframe, - get_indicator, - get_incomelevel, - get_lendingtype, - get_source, - get_topic, - search_countries, - search_indicators, -) +import functools + +from .client import Client + + +@functools.lru_cache +def get_default_client() -> Client: + """ + Get the default client + """ + return Client() + + +get_country = get_default_client().get_country +get_data = get_default_client().get_data +get_dataframe = get_default_client().get_dataframe +get_incomelevel = get_default_client().get_incomelevel +get_indicator = get_default_client().get_indicator +get_lendingtype = get_default_client().get_lendingtype +get_series = get_default_client().get_series +get_source = get_default_client().get_source +get_topic = get_default_client().get_topic +search_countries = get_default_client().search_countries +search_indicators = get_default_client().search_indicators diff --git a/wbdata/api.py b/wbdata/api.py deleted file mode 100644 index ba59e16..0000000 --- a/wbdata/api.py +++ /dev/null @@ -1,505 +0,0 @@ -""" -wbdata.api: Where all the functions go -""" - -import collections -import datetime -import re -import warnings - -import tabulate - -try: - import pandas as pd -except ImportError: - pd = None - -from decorator import decorator -from . import fetcher - -BASE_URL = "https://api.worldbank.org/v2" -COUNTRIES_URL = f"{BASE_URL}/countries" -ILEVEL_URL = f"{BASE_URL}/incomeLevels" -INDICATOR_URL = f"{BASE_URL}/indicators" -LTYPE_URL = f"{BASE_URL}/lendingTypes" -SOURCES_URL = f"{BASE_URL}/sources" -TOPIC_URL = f"{BASE_URL}/topics" -INDIC_ERROR = "Cannot specify more than one of indicator, source, and topic" - - -class WBSearchResult(list): - """ - A list that prints out a user-friendly table when printed or returned on the - command line - - - Items are expected to be dict-like and have an "id" key and a "name" or - "value" key - """ - - def __repr__(self): - try: - return tabulate.tabulate( - [[o["id"], o["name"]] for o in self], - headers=["id", "name"], - tablefmt="simple", - ) - except KeyError: - return tabulate.tabulate( - [[o["id"], o["value"]] for o in self], - headers=["id", "value"], - tablefmt="simple", - ) - - -if pd: - - class WBSeries(pd.Series): - """ - A pandas Series with a last_updated attribute - """ - - _metadata = ["last_updated"] - - @property - def _constructor(self): - return WBSeries - - class WBDataFrame(pd.DataFrame): - """ - A pandas DataFrame with a last_updated attribute - """ - - _metadata = ["last_updated"] - - @property - def _constructor(self): - return WBDataFrame - - -@decorator -def uses_pandas(f, *args, **kwargs): - """Raise ValueError if pandas is not loaded""" - if not pd: - raise ValueError("Pandas must be installed to be used") - return f(*args, **kwargs) - - -def parse_value_or_iterable(arg): - """ - If arg is a single value, return it as a string; if an iterable, return a - ;-joined string of all values - """ - if str(arg) == arg: - return arg - if type(arg) == int: - return str(arg) - return ";".join(arg) - - -def convert_year_to_datetime(yearstr): - """return datetime.datetime object from %Y formatted string""" - return datetime.datetime.strptime(yearstr, "%Y") - - -def convert_month_to_datetime(monthstr): - """return datetime.datetime object from %YM%m formatted string""" - split = monthstr.split("M") - return datetime.datetime(int(split[0]), int(split[1]), 1) - - -def convert_quarter_to_datetime(quarterstr): - """ - return datetime.datetime object from %YQ%# formatted string, where # is - the desired quarter - """ - split = quarterstr.split("Q") - quarter = int(split[1]) - month = quarter * 3 - 2 - return datetime.datetime(int(split[0]), month, 1) - - -def convert_dates_to_datetime(data): - """ - Return a datetime.datetime object from a date string as provided by the - World Bank - """ - first = data[0]["date"] - if isinstance(first, datetime.datetime): - return data - if "M" in first: - converter = convert_month_to_datetime - elif "Q" in first: - converter = convert_quarter_to_datetime - else: - converter = convert_year_to_datetime - for datum in data: - datum_date = datum["date"] - if "MRV" in datum_date: - continue - if "-" in datum_date: - continue - datum["date"] = converter(datum_date) - return data - - -def cast_float(value): - """ - Return a floated value or none - """ - try: - return float(value) - except (ValueError, TypeError): - return None - - -def get_series( - indicator, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - column_name="value", - keep_levels=False, - cache=True, -): - """ - Retrieve indicators for given countries and years - - :indicator: the desired indicator code - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-tuple with start - and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :column_name: the desired name for the pandas column - :keep_levels: if True and pandas is True, don't reduce the number of index - levels returned if only getting one date or country - :cache: use the cache - :returns: WBSeries - """ - raw_data = get_data( - indicator=indicator, - country=country, - data_date=data_date, - freq=freq, - source=source, - convert_date=convert_date, - cache=cache, - ) - df = pd.DataFrame( - [[i["country"]["value"], i["date"], i["value"]] for i in raw_data], - columns=["country", "date", column_name], - ) - df[column_name] = df[column_name].map(cast_float) - if not keep_levels and len(df["country"].unique()) == 1: - df = df.set_index("date") - elif not keep_levels and len(df["date"].unique()) == 1: - df = df.set_index("country") - else: - df = df.set_index(["country", "date"]) - series = WBSeries(df[column_name]) - series.last_updated = raw_data.last_updated - return series - - -def data_date_to_str(data_date, freq): - """ - Convert data_date to the appropriate representation base on freq - - - :data_date: A datetime.datetime object to be formatted - :freq: One of 'Y' (year), 'M' (month) or 'Q' (quarter) - - """ - if freq == "Y": - return data_date.strftime("%Y") - if freq == "M": - return data_date.strftime("%YM%m") - if freq == "Q": - return f"{data_date.year}Q{(data_date.month - 1) // 3 + 1}" - - -def get_data( - indicator, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - pandas=False, - column_name="value", - keep_levels=False, - cache=True, -): - """ - Retrieve indicators for given countries and years - - :indicator: the desired indicator code - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-tuple with start - and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :cache: use the cache - :returns: list of dictionaries - """ - if pandas: - warnings.warn( - ( - "Argument 'pandas' is deprecated and will be removed in a " - "future version. Use get_series or get_dataframe instead." - ), - PendingDeprecationWarning, - ) - return get_series( - indicator=indicator, - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - cache=cache, - ) - query_url = COUNTRIES_URL - try: - c_part = parse_value_or_iterable(country) - except TypeError: - raise TypeError("'country' must be a string or iterable'") - query_url = "/".join((query_url, c_part, "indicators", indicator)) - args = {} - if data_date: - args["date"] = ( - ":".join(data_date_to_str(dd, freq) for dd in data_date) - if isinstance(data_date, collections.Sequence) - else data_date_to_str(data_date, freq) - ) - if source: - args["source"] = source - data = fetcher.fetch(query_url, args, cache=cache) - if convert_date: - data = convert_dates_to_datetime(data) - return data - - -def id_only_query(query_url, query_id, cache): - """ - Retrieve information when ids are the only arguments - - :query_url: the base url to use for the query - :query_id: an id or sequence of ids - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing results - """ - if query_id: - query_url = "/".join((query_url, parse_value_or_iterable(query_id))) - return WBSearchResult(fetcher.fetch(query_url)) - - -def get_source(source_id=None, cache=True): - """ - Retrieve information on a source - - :source_id: a source id or sequence thereof. None returns all sources - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - sources - """ - return id_only_query(SOURCES_URL, source_id, cache=cache) - - -def get_incomelevel(level_id=None, cache=True): - """ - Retrieve information on an income level aggregate - - :level_id: a level id or sequence thereof. None returns all income level - aggregates - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - income level aggregates - """ - return id_only_query(ILEVEL_URL, level_id, cache=cache) - - -def get_topic(topic_id=None, cache=True): - """ - Retrieve information on a topic - - :topic_id: a topic id or sequence thereof. None returns all topics - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - topic aggregates - """ - return id_only_query(TOPIC_URL, topic_id, cache=cache) - - -def get_lendingtype(type_id=None, cache=True): - """ - Retrieve information on an income level aggregate - - :level_id: lending type id or sequence thereof. None returns all lending - type aggregates - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - topic aggregates - """ - return id_only_query(LTYPE_URL, type_id, cache=cache) - - -def get_country(country_id=None, incomelevel=None, lendingtype=None, cache=True): - """ - Retrieve information on a country or regional aggregate. Can specify - either country_id, or the aggregates, but not both - - :country_id: a country id or sequence thereof. None returns all countries - and aggregates. - :incomelevel: desired incomelevel id or ids. - :lendingtype: desired lendingtype id or ids. - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing each - country - """ - if country_id: - if incomelevel or lendingtype: - raise ValueError("Can't specify country_id and aggregates") - return id_only_query(COUNTRIES_URL, country_id, cache=cache) - args = {} - if incomelevel: - args["incomeLevel"] = parse_value_or_iterable(incomelevel) - if lendingtype: - args["lendingType"] = parse_value_or_iterable(lendingtype) - return WBSearchResult(fetcher.fetch(COUNTRIES_URL, args, cache=cache)) - - -def get_indicator(indicator=None, source=None, topic=None, cache=True): - """ - Retrieve information about an indicator or indicators. Only one of - indicator, source, and topic can be specified. Specifying none of the - three will return all indicators. - - :indicator: an indicator code or sequence thereof - :source: a source id or sequence thereof - :topic: a topic id or sequence thereof - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing - indicators - """ - if indicator: - if source or topic: - raise ValueError(INDIC_ERROR) - query_url = "/".join((INDICATOR_URL, parse_value_or_iterable(indicator))) - elif source: - if topic: - raise ValueError(INDIC_ERROR) - query_url = "/".join( - (SOURCES_URL, parse_value_or_iterable(source), "indicators") - ) - elif topic: - query_url = "/".join((TOPIC_URL, parse_value_or_iterable(topic), "indicators")) - else: - query_url = INDICATOR_URL - return WBSearchResult(fetcher.fetch(query_url, cache=cache)) - - -def search_indicators(query, source=None, topic=None, cache=True): - """ - Search indicators for a certain regular expression. Only one of source or - topic can be specified. In interactive mode, will return None and print ids - and names unless suppress_printing is True. - - :query: the term to match against indicator names - :source: if present, id of desired source - :topic: if present, id of desired topic - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing search - indicators - """ - indicators = get_indicator(source=source, topic=topic, cache=cache) - pattern = re.compile(query, re.IGNORECASE) - return WBSearchResult(i for i in indicators if pattern.search(i["name"])) - - -def search_countries(query, incomelevel=None, lendingtype=None, cache=True): - """ - Search countries by name. Very simple search. - - :query: the string to match against country names - :incomelevel: if present, search only the matching incomelevel - :lendingtype: if present, search only the matching lendingtype - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing - countries - """ - countries = get_country( - incomelevel=incomelevel, lendingtype=lendingtype, cache=cache - ) - pattern = re.compile(query, re.IGNORECASE) - return WBSearchResult(i for i in countries if pattern.search(i["name"])) - - -@uses_pandas -def get_dataframe( - indicators, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - keep_levels=False, - cache=True, -): - """ - Convenience function to download a set of indicators and merge them into a - pandas DataFrame. The index will be the same as if calls were made to - get_data separately. - - :indicators: An dictionary where the keys are desired indicators and the - values are the desired column names - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-sequence with - start and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :keep_levels: if True don't reduce the number of index levels returned if - only getting one date or country - :cache: use the cache - :returns: a WBDataFrame - """ - serieses = [ - ( - get_series( - indicator=indicator, - country=country, - data_date=data_date, - freq=freq, - source=source, - convert_date=convert_date, - keep_levels=keep_levels, - cache=cache, - ).rename(name) - ) - for indicator, name in indicators.items() - ] - result = None - for series in serieses: - if result is None: - result = series.to_frame() - else: - result = result.join(series.to_frame(), how="outer") - result = WBDataFrame(result) - result.last_updated = {i.name: i.last_updated for i in serieses} - return result diff --git a/wbdata/cache.py b/wbdata/cache.py new file mode 100644 index 0000000..cc81a0e --- /dev/null +++ b/wbdata/cache.py @@ -0,0 +1,50 @@ +import datetime as dt +import logging +import os +from pathlib import Path +from typing import Union + +import appdirs +import cachetools +import shelved_cache # type: ignore[import-untyped] + +from wbdata import __version__ + +log = logging.getLogger(__name__) + +try: + TTL_DAYS = int(os.getenv("WBDATA_CACHE_TTL_DAYS", "7")) +except ValueError: + logging.warning("Couldn't parse WBDATA_CACHE_TTL_DAYS value, defaulting to 7") + TTL_DAYS = 7 + +try: + MAX_SIZE = int(os.getenv("WBDATA_CACHE_MAX_SIZE", "100")) +except ValueError: + logging.warning("Couldn't parse WBDATA_CACHE_MAX_SIZE value, defaulting to 100") + MAX_SIZE = 7 + + +def get_cache( + path: Union[str, Path, None] = None, + ttl_days: Union[int, None] = None, + max_size: Union[int, None] = None, +) -> cachetools.Cache: + """ + Get the global cache + """ + path = path or Path( + appdirs.user_cache_dir(appname="wbdata", version=__version__) + ).joinpath("cache") + Path(path).parent.mkdir(parents=True, exist_ok=True) + ttl_days = ttl_days or TTL_DAYS + max_size = max_size or MAX_SIZE + cache: cachetools.TTLCache = shelved_cache.PersistentCache( + cachetools.TTLCache, + filename=str(path), + maxsize=max_size, + ttl=dt.timedelta(days=ttl_days), + timer=dt.datetime.now, + ) + cache.expire() + return cache diff --git a/wbdata/client.py b/wbdata/client.py new file mode 100644 index 0000000..e108593 --- /dev/null +++ b/wbdata/client.py @@ -0,0 +1,451 @@ +import contextlib +import dataclasses +import re +from pathlib import Path +from typing import Any, Dict, Iterable, List, Sequence, Union + +import decorator +import requests +import tabulate + +try: + import pandas as pd # type: ignore[import-untyped] +except ImportError: + pd = None + + +from . import cache, dates, fetcher +from .types import DateArg, IdArg, Row + +BASE_URL = "https://api.worldbank.org/v2" +COUNTRIES_URL = f"{BASE_URL}/countries" +ILEVEL_URL = f"{BASE_URL}/incomeLevels" +INDICATOR_URL = f"{BASE_URL}/indicators" +LTYPE_URL = f"{BASE_URL}/lendingTypes" +SOURCES_URL = f"{BASE_URL}/sources" +TOPIC_URL = f"{BASE_URL}/topics" +INDIC_ERROR = "Cannot specify more than one of indicator, source, and topic" + + +class SearchResult(List[Row]): + """ + A list that prints out a user-friendly table when printed or returned on the + command line + + + Items are expected to be dict-like and have an "id" key and a "name" or + "value" key + """ + + def __repr__(self) -> str: + try: + return tabulate.tabulate( + [[o["id"], o["name"]] for o in self], + headers=["id", "name"], + tablefmt="simple", + ) + except KeyError: + return tabulate.tabulate( + [[o["id"], o["value"]] for o in self], + headers=["id", "value"], + tablefmt="simple", + ) + + +if pd: + + class Series(pd.Series): + """ + A pandas Series with a last_updated attribute + """ + + _metadata = ["last_updated"] + + @property + def _constructor(self): + return Series + + class DataFrame(pd.DataFrame): + """ + A pandas DataFrame with a last_updated attribute + """ + + _metadata = ["last_updated"] + + @property + def _constructor(self): + return DataFrame +else: + Series = Any # type: ignore[misc, assignment] + DataFrame = Any # type: ignore[misc, assignment] + + +@decorator.decorator +def needs_pandas(f, *args, **kwargs): + if pd is None: + raise RuntimeError(f"{f.__name__} requires pandas") + return f(*args, **kwargs) + + +def parse_value_or_iterable(arg: Any) -> str: + """ + If arg is a single value, return it as a string; if an iterable, return a + ;-joined string of all values + """ + if isinstance(arg, str): + return arg + if isinstance(arg, Iterable): + return ";".join(str(i) for i in arg) + return str(arg) + + +def cast_float(value: str) -> Union[float, None]: + """ + Return a value coerced to float or None + """ + with contextlib.suppress(ValueError, TypeError): + return float(value) + return None + + +@dataclasses.dataclass +class Client: + cache_path: Union[str, Path, None] = None + cache_ttl_days: Union[int, None] = None + cache_max_size: Union[int, None] = None + session: Union[requests.Session, None] = None + + def __post_init__(self): + self.fetcher = fetcher.Fetcher( + cache=cache.get_cache( + path=self.cache_path, + ttl_days=self.cache_ttl_days, + max_size=self.cache_max_size, + ) + ) + self.has_pandas = pd is None + + def get_data( + self, + indicator: str, + country: Union[str, Sequence[str]] = "all", + data_date: DateArg = None, + freq: str = "Y", + source: IdArg = None, + convert_date: bool = False, + skip_cache: bool = False, + ): + """ + Retrieve indicators for given countries and years + + :indicator: the desired indicator code + :country: a country code, sequence of country codes, or "all" (default) + :data_date: the desired date as a datetime object or a 2-tuple with start + and end dates + :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not support the + specified frequency. + :source: the specific source to retrieve data from (defaults on API to 2, + World Development Indicators) + :convert_date: if True, convert date field to a datetime.datetime object. + :skip_cache: bypass the cache when downloading + :returns: list of dictionaries + """ + query_url = COUNTRIES_URL + try: + c_part = parse_value_or_iterable(country) + except TypeError as e: + raise TypeError("'country' must be a string or iterable'") from e + query_url = "/".join((query_url, c_part, "indicators", indicator)) + args: Dict[str, Any] = {} + if data_date: + args["date"] = dates.datespec_to_arg(data_date, freq) + if source: + args["source"] = source + data = self.fetcher.fetch(query_url, args, skip_cache=skip_cache) + if convert_date: + dates.convert_dates_to_datetime(data) + return data + + def _id_only_query( + self, + query_url: str, + query_id: Any, + skip_cache: bool, + ) -> SearchResult: + """ + Retrieve information when ids are the only arguments + + :query_url: the base url to use for the query + :query_id: an id or sequence of ids + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects describing results + """ + if query_id: + query_url = "/".join((query_url, parse_value_or_iterable(query_id))) + return SearchResult(self.fetcher.fetch(url=query_url, skip_cache=skip_cache)[0]) + + def get_source( + self, source_id: IdArg = None, skip_cache: bool = False + ) -> SearchResult: + """ + Retrieve information on a source + + :source_id: a source id or sequence thereof. None returns all sources + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects describing selected + sources + """ + return self._id_only_query( + query_url=SOURCES_URL, query_id=source_id, skip_cache=skip_cache + ) + + def get_incomelevel( + self, level_id: IdArg = None, skip_cache: bool = False + ) -> SearchResult: + """ + Retrieve information on an income level aggregate + + :level_id: a level id or sequence thereof. None returns all income level + aggregates + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects describing selected + income level aggregates + """ + return self._id_only_query(ILEVEL_URL, level_id, skip_cache=skip_cache) + + def get_topic( + self, topic_id: IdArg = None, skip_cache: bool = False + ) -> SearchResult: + """ + Retrieve information on a topic + + :topic_id: a topic id or sequence thereof. None returns all topics + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects describing selected + topic aggregates + """ + return self._id_only_query(TOPIC_URL, topic_id, skip_cache=skip_cache) + + def get_lendingtype(self, type_id=None, skip_cache=False): + """ + Retrieve information on an income level aggregate + + :level_id: lending type id or sequence thereof. None returns all lending + type aggregates + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects describing selected + topic aggregates + """ + return self._id_only_query(LTYPE_URL, type_id, skip_cache=skip_cache) + + def get_country( + self, country_id=None, incomelevel=None, lendingtype=None, skip_cache=False + ): + """ + Retrieve information on a country or regional aggregate. Can specify + either country_id, or the aggregates, but not both + + :country_id: a country id or sequence thereof. None returns all countries + and aggregates. + :incomelevel: desired incomelevel id or ids. + :lendingtype: desired lendingtype id or ids. + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects representing each + country + """ + if country_id: + if incomelevel or lendingtype: + raise ValueError("Can't specify country_id and aggregates") + return self._id_only_query(COUNTRIES_URL, country_id, skip_cache=skip_cache) + args = {} + if incomelevel: + args["incomeLevel"] = parse_value_or_iterable(incomelevel) + if lendingtype: + args["lendingType"] = parse_value_or_iterable(lendingtype) + return SearchResult( + self.fetcher.fetch(COUNTRIES_URL, args, skip_cache=skip_cache)[0] + ) + + def get_indicator(self, indicator=None, source=None, topic=None, skip_cache=False): + """ + Retrieve information about an indicator or indicators. Only one of + indicator, source, and topic can be specified. Specifying none of the + three will return all indicators. + + :indicator: an indicator code or sequence thereof + :source: a source id or sequence thereof + :topic: a topic id or sequence thereof + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects representing + indicators + """ + if indicator: + if source or topic: + raise ValueError(INDIC_ERROR) + query_url = "/".join((INDICATOR_URL, parse_value_or_iterable(indicator))) + elif source: + if topic: + raise ValueError(INDIC_ERROR) + query_url = "/".join( + (SOURCES_URL, parse_value_or_iterable(source), "indicators") + ) + elif topic: + query_url = "/".join( + (TOPIC_URL, parse_value_or_iterable(topic), "indicators") + ) + else: + query_url = INDICATOR_URL + return SearchResult(self.fetcher.fetch(query_url, skip_cache=skip_cache)) + + def search_indicators(self, query, source=None, topic=None, skip_cache=False): + """ + Search indicators for a certain regular expression. Only one of source or + topic can be specified. In interactive mode, will return None and print ids + and names unless suppress_printing is True. + + :query: the term to match against indicator names + :source: if present, id of desired source + :topic: if present, id of desired topic + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects representing search + indicators + """ + indicators = self.get_indicator( + source=source, topic=topic, skip_cache=skip_cache + ) + pattern = re.compile(query, re.IGNORECASE) + return SearchResult(i for i in indicators if pattern.search(i["name"])) + + def search_countries( + self, query, incomelevel=None, lendingtype=None, skip_cache=False + ): + """ + Search countries by name. Very simple search. + + :query: the string to match against country names + :incomelevel: if present, search only the matching incomelevel + :lendingtype: if present, search only the matching lendingtype + :skip_cache: bypass cache when downloading + :returns: SearchResult containing dictionary objects representing + countries + """ + countries = self.get_country( + incomelevel=incomelevel, lendingtype=lendingtype, skip_cache=skip_cache + ) + pattern = re.compile(query, re.IGNORECASE) + return SearchResult(i for i in countries if pattern.search(i["name"])) + + @needs_pandas + def get_series( + self, + indicator: str, + country: Union[str, Sequence[str]] = "all", + data_date: DateArg = None, + freq: str = "Y", + source: IdArg = None, + convert_date: bool = False, + column_name: str = "value", + keep_levels: bool = False, + skip_cache: bool = False, + ) -> Series: + """ + Retrieve indicators for given countries and years + + :indicator: the desired indicator code + :country: a country code, sequence of country codes, or "all" (default) + :data_date: the desired date as a datetime object or a 2-tuple with start + and end dates + :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not support the + specified frequency. + :source: the specific source to retrieve data from (defaults on API to 2, + World Development Indicators) + :convert_date: if True, convert date field to a datetime.datetime object. + :column_name: the desired name for the pandas column + :keep_levels: if True don't reduce the number of index + levels returned if only getting one date or country + :skip_cache: bypass the cache when downloading + :returns: Series + """ + raw_data = self.get_data( + indicator=indicator, + country=country, + data_date=data_date, + freq=freq, + source=source, + convert_date=convert_date, + skip_cache=skip_cache, + ) + df = pd.DataFrame( + [[i["country"]["value"], i["date"], i["value"]] for i in raw_data], + columns=["country", "date", column_name], + ) + df[column_name] = df[column_name].map(cast_float) + if not keep_levels and len(df["country"].unique()) == 1: + df = df.set_index("date") + elif not keep_levels and len(df["date"].unique()) == 1: + df = df.set_index("country") + else: + df = df.set_index(["country", "date"]) + series = Series(df[column_name]) + series.last_updated = raw_data.last_updated + return series + + @needs_pandas + def get_dataframe( + self, + indicators: Dict[str, str], + country="all", + data_date=None, + freq="Y", + source=None, + convert_date=False, + keep_levels=False, + skip_cache: bool = False, + ) -> DataFrame: + """ + Convenience function to download a set of indicators and merge them into a + pandas DataFrame. The index will be the same as if calls were made to + get_data separately. + + :indicators: An dictionary where the keys are desired indicators and the + values are the desired column names + :country: a country code, sequence of country codes, or "all" (default) + :data_date: the desired date as a datetime object or a 2-sequence with + start and end dates + :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not support the + specified frequency. + :source: the specific source to retrieve data from (defaults on API to 2, + World Development Indicators) + :convert_date: if True, convert date field to a datetime.datetime object. + :keep_levels: if True don't reduce the number of index levels returned if + only getting one date or country + :skip_cache: bypass the cache when downloading + :returns: a DataFrame + """ + serieses = [ + ( + self.get_series( + indicator=indicator, + country=country, + data_date=data_date, + freq=freq, + source=source, + convert_date=convert_date, + keep_levels=keep_levels, + skip_cache=skip_cache, + ).rename(name) + ) + for indicator, name in indicators.items() + ] + result = None + for series in serieses: + if result is None: + result = series.to_frame() + else: + result = result.join(series.to_frame(), how="outer") + result = DataFrame(result) + result.last_updated = {i.name: i.last_updated for i in serieses} + return result diff --git a/wbdata/dates.py b/wbdata/dates.py new file mode 100644 index 0000000..3713591 --- /dev/null +++ b/wbdata/dates.py @@ -0,0 +1,94 @@ +import datetime as dt +import re +from typing import Any, Dict, List, Union + +import dateparser + +from .types import DateSpec + +PATTERN_YEAR = re.compile("\d{4}") +PATTERN_MONTH = re.compile("\d{4}M\d{1,2}") +PATTERN_QUARTER = re.compile("\d{4}Q\d{1,2}") + + +def convert_year_to_datetime(datestr: str) -> dt.datetime: + """return datetime.datetime object from %Y formatted string""" + return dt.datetime.strptime(datestr, "%Y") + + +def convert_month_to_datetime(datestr: str) -> dt.datetime: + """return datetime.datetime object from %YM%m formatted string""" + split = datestr.split("M") + return dt.datetime(int(split[0]), int(split[1]), 1) + + +def convert_quarter_to_datetime(datestr: str) -> dt.datetime: + """ + return datetime.datetime object from %YQ%# formatted string, where # is + the desired quarter + """ + split = datestr.split("Q") + quarter = int(split[1]) + month = quarter * 3 - 2 + return dt.datetime(int(split[0]), month, 1) + + +def convert_dates_to_datetime(data: List[Dict[str, Any]]) -> None: + """Replace date strings in raw response with datetime objects.""" + first = data[0]["date"] + if isinstance(first, dt.datetime): + return + if PATTERN_MONTH.match(first): + converter = convert_month_to_datetime + elif PATTERN_QUARTER.match(first): + converter = convert_quarter_to_datetime + else: + converter = convert_year_to_datetime + for datum in data: + datum_date = datum["date"] + if "MRV" in datum_date or "-" in datum_date: + continue + datum["date"] = converter(datum_date) + + +def data_date_to_str(data_date: dt.datetime, freq: str) -> str: + """ + Convert data_date to the appropriate representation base on freq + + + :data_date: A datetime.datetime object to be formatted + :freq: One of 'Y' (year), 'M' (month) or 'Q' (quarter) + + """ + if freq == "Y": + return data_date.strftime("%Y") + if freq == "M": + return data_date.strftime("%YM%m") + if freq == "Q": + return f"{data_date.year}Q{(data_date.month - 1) // 3 + 1}" + raise ValueError(f"Unknown Frequency type: {freq}") + + +def parse_single_date(date: Union[str, dt.datetime]) -> dt.datetime: + if isinstance(date, dt.datetime): + return date + if PATTERN_YEAR.match(date): + return convert_year_to_datetime(date) + if PATTERN_MONTH.match(date): + return convert_month_to_datetime(date) + if PATTERN_QUARTER.match(date): + return convert_quarter_to_datetime(date) + last_chance = dateparser.parse(date) + if last_chance: + return last_chance + raise ValueError(f"Unable to parse date string {date}") + + +def datespec_to_arg(spec: DateSpec, freq) -> str: + if isinstance(spec, tuple): + return ( + f"{data_date_to_str(parse_single_date(spec[0]), freq)}" + ":" + f"{data_date_to_str(parse_single_date(spec[1]), freq)}" + ) + return data_date_to_str(parse_single_date(spec), freq) diff --git a/wbdata/fetcher.py b/wbdata/fetcher.py index 932ced7..e65cd73 100644 --- a/wbdata/fetcher.py +++ b/wbdata/fetcher.py @@ -2,149 +2,138 @@ wbdata.fetcher: retrieve and cache queries """ -import datetime +import contextlib +import dataclasses +import datetime as dt import json import logging -import pickle import pprint +from typing import Any, Dict, List, MutableMapping, NamedTuple, Tuple, Union -import appdirs +import backoff import requests -import wbdata +from .types import Row -from pathlib import Path - -EXP = 7 PER_PAGE = 1000 -TODAY = datetime.date.today() -TRIES = 5 - - -class WBResults(list): - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - self.last_updated = None - - -class Cache(object): - """Docstring for Cache """ - - def __init__(self): - self.path = Path( - appdirs.user_cache_dir(appname="wbdata", version=wbdata.__version__) - ) - self.path.parent.mkdir(parents=True, exist_ok=True) - try: - with self.path.open("rb") as cachefile: - self.cache = { - i: (date, json) - for i, (date, json) in pickle.load(cachefile).items() - if (TODAY - datetime.date.fromordinal(date)).days < EXP - } - except (IOError, EOFError): - self.cache = {} - - def __getitem__(self, key): - return self.cache[key][1] +TRIES = 3 - def __setitem__(self, key, value): - self.cache[key] = TODAY.toordinal(), value - self.sync() - def __contains__(self, item): - return item in self.cache +def _strip_id(row: Row) -> None: + with contextlib.suppress(KeyError): + row["id"] = row["id"].strip() # type: ignore[union-attr] - def sync(self): - """Sync cache to disk""" - with self.path.open("wb") as cachefile: - pickle.dump(self.cache, cachefile) +Response = Tuple[Dict[str, Any], List[Dict[str, Any]]] -CACHE = Cache() +class ParsedResponse(NamedTuple): + rows: List[Row] + page: int + pages: int + last_updated: Union[str, None] -def get_json_from_url(url, args): - """ - Fetch a url directly from the World Bank, up to TRIES tries - :url: the url to retrieve - :args: a dictionary of GET arguments - :returns: a string with the url contents - """ - for _ in range(TRIES): - try: - return requests.get(url, args).text - except requests.ConnectionError: - continue - logging.error(f"Error connecting to {url}") - raise RuntimeError("Couldn't connect to API") - - -def get_response(url, args, cache=True): - """ - Get single page response from World Bank API or from cache - : query_url: the base url to be queried - : args: a dictionary of GET arguments - : cache: use the cache - : returns: a dictionary with the response from the API - """ - logging.debug(f"fetching {url}") - key = (url, tuple(sorted(args.items()))) - if cache and key in CACHE: - response = CACHE[key] - else: - response = get_json_from_url(url, args) - if cache: - CACHE[key] = response - return json.loads(response) - - -def fetch(url, args=None, cache=True): - """Fetch data from the World Bank API or from cache. - - Given the base url, keep fetching results until there are no more pages. - - : query_url: the base url to be queried - : args: a dictionary of GET arguments - : cache: use the cache - : returns: a list of dictionaries containing the response to the query - """ - if args is None: - args = {} - else: - args = dict(args) - args["format"] = "json" - args["per_page"] = PER_PAGE - results = [] - pages, this_page = 0, 1 - while pages != this_page: - response = get_response(url, args, cache=cache) - try: - results.extend(response[1]) - this_page = response[0]["page"] - pages = response[0]["pages"] - except (IndexError, KeyError): - try: - message = response[0]["message"][0] - raise RuntimeError( - f"Got error {message['id']} ({message['key']}): " - f"{message['value']}" - ) - except (IndexError, KeyError): - raise RuntimeError( - f"Got unexpected response:\n{pprint.pformat(response)}" - ) - logging.debug(f"Processed page {this_page} of {pages}") - args["page"] = int(this_page) + 1 - for i in results: - if "id" in i: - i["id"] = i["id"].strip() - results = WBResults(results) +def _parse_response(response: Response) -> ParsedResponse: try: - results.last_updated = datetime.datetime.strptime( - response[0]["lastupdated"], "%Y-%m-%d" + return ParsedResponse( + rows=response[1], + page=int(response[0]["page"]), + pages=int(response[0]["pages"]), + last_updated=response[0].get("lastupdated"), ) - except KeyError: - pass - return results + except (IndexError, KeyError) as e: + try: + message = response[0]["message"][0] + raise RuntimeError( + f"Got error {message['id']} ({message['key']}): " f"{message['value']}" + ) from e + except (IndexError, KeyError) as e: + raise RuntimeError( + f"Got unexpected response:\n{pprint.pformat(response)}" + ) from e + + +CacheKey = Tuple[str, Tuple[Tuple[str, Any], ...]] + + +@dataclasses.dataclass +class Fetcher: + cache: MutableMapping[CacheKey, str] + session: requests.Session = dataclasses.field(default_factory=requests.Session) + + @backoff.on_exception( + wait_gen=backoff.expo, + exception=requests.exceptions.ConnectTimeout, + max_tries=TRIES, + ) + def _get_response_body( + self, + url: str, + params: Dict[str, Any], + ) -> str: + """ + Fetch a url directly from the World Bank + + :url: the url to retrieve + :params: a dictionary of GET parameters + :returns: a string with the response content + """ + return self.session.get(url=url, params=params).text + + def _get_response( + self, + url: str, + params: Dict[str, Any], + skip_cache=False, + ) -> ParsedResponse: + """ + Get single page response from World Bank API or from cache + : query_url: the base url to be queried + : params: a dictionary of GET arguments + : skip_cache: bypass the cache + : returns: a dictionary with the response from the API + """ + key = (url, tuple(sorted(params.items()))) + if not skip_cache and key in self.cache: + body = self.cache[key] + else: + body = self._get_response_body(url, params) + self.cache[key] = body + return _parse_response(tuple(json.loads(body))) + + def fetch( + self, + url: str, + params=None, + skip_cache=False, + ) -> Tuple[List[Row], Union[dt.datetime, None]]: + """Fetch data from the World Bank API or from cache. + + Given the base url, keep fetching results until there are no more pages. + + : query_url: the base url to be queried + : params: a dictionary of GET arguments + : skip_cache: use the cache + : returns: a list of dictionaries containing the response to the query + """ + params = params or {} + params["format"] = "json" + params["per_page"] = PER_PAGE + page, pages = -1, -2 + rows: List[Row] = [] + while pages != page: + response = self._get_response( + url=url, + params=params, + skip_cache=skip_cache, + ) + rows.extend(response.rows) + page, pages = response.page, response.pages + logging.debug(f"Processed page {page} of {pages}") + params["page"] = page + 1 + for row in rows: + _strip_id(row) + if response.last_updated is None: + return rows, None + return rows, dt.datetime.strptime(response.last_updated, "%Y-%m-%d") diff --git a/wbdata/py.typed b/wbdata/py.typed new file mode 100644 index 0000000..e69de29 diff --git a/wbdata/types.py b/wbdata/types.py new file mode 100644 index 0000000..24472a1 --- /dev/null +++ b/wbdata/types.py @@ -0,0 +1,17 @@ +import datetime as dt +from typing import Dict, Sequence, Tuple, Union + +Value = Union[str, int, float, dt.datetime, None] +Row = Dict[str, Value] + +IdArg = Union[int, str, Sequence[Union[int, str]], None] + +DateSpec = Union[ + str, + dt.datetime, + Tuple[ + Union[str, dt.datetime], + Union[str, dt.datetime], + ], +] +DateArg = Union[DateSpec, None]