diff --git a/.github/workflows/publish.yaml b/.github/workflows/publish.yaml index afad197..61819cf 100644 --- a/.github/workflows/publish.yaml +++ b/.github/workflows/publish.yaml @@ -11,11 +11,11 @@ jobs: runs-on: ubuntu-latest steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v1.1.1 + uses: actions/setup-python@v5 - name: Install Poetry - uses: dschep/install-poetry-action@v1.2 + uses: snok/install-poetry@v1.3.4 - name: Publish env: POETRY_PYPI_TOKEN_PYPI: ${{ secrets.pypi_token }} diff --git a/.github/workflows/tests.yaml b/.github/workflows/tests.yaml index 48a98d7..06ea2a2 100644 --- a/.github/workflows/tests.yaml +++ b/.github/workflows/tests.yaml @@ -1,27 +1,60 @@ name: Tests on: push: + branches: + - 'master' pull_request: - schedule: - - cron: "1 1 1 * *" jobs: + lint: + name: Lint + runs-on: "ubuntu-latest" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Install Ruff + run: pip install ruff + - name: Run Ruff + run: ruff check wbdata + types: + name: Types + runs-on: "ubuntu-latest" + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.8" + - name: Install Poetry Action + uses: snok/install-poetry@v1.3.4 + - name: Install Dependencies + run: poetry install --only main,types + - name: Run mypy + run: poetry run mypy wbdata test: name: Tests runs-on: ${{ matrix.os }} strategy: matrix: - python-version: ["3.6", "3.7", "3.8"] + python-version: ["3.8", "3.9", "3.10", "3.11", "3.12"] os: ["ubuntu-latest", "macos-latest", "windows-latest"] + defaults: + run: + shell: bash steps: - name: Checkout - uses: actions/checkout@v1 + uses: actions/checkout@v4 - name: Setup Python - uses: actions/setup-python@v1.1.1 + uses: actions/setup-python@v5 with: python-version: ${{ matrix.python-version }} - - name: Install Poetry - uses: dschep/install-poetry-action@v1.2 + - name: Install Poetry Action + uses: snok/install-poetry@v1.3.4 - name: Install Dependencies - run: poetry install -E pandas + run: poetry install --only main,tests -E pandas - name: Run Tests run: poetry run pytest diff --git a/.gitignore b/.gitignore index ace17f8..7f97897 100644 --- a/.gitignore +++ b/.gitignore @@ -36,3 +36,5 @@ nosetests.xml # Docs docs/_build/ + +.mypy_cache diff --git a/.readthedocs.yml b/.readthedocs.yml index 923b15b..855bddf 100644 --- a/.readthedocs.yml +++ b/.readthedocs.yml @@ -1,8 +1,15 @@ version: 2 formats: all +build: + os: ubuntu-22.04 + tools: + python: "3" python: install: - method: pip path: . extra_requirements: + - pandas - docs +mkdocs: + configuration: mkdocs.yml diff --git a/.tool-versions b/.tool-versions new file mode 100644 index 0000000..5b11157 --- /dev/null +++ b/.tool-versions @@ -0,0 +1,2 @@ +poetry 1.7.1 +python 3.8.13 diff --git a/CHANGES.txt b/CHANGES.txt deleted file mode 100644 index 61e9a74..0000000 --- a/CHANGES.txt +++ /dev/null @@ -1,7 +0,0 @@ -0.1.0, 28 Dec 2012 -- Initial Release -0.2.0, 10 Mar 2013 -- Added get_panel, cleaned up dataset, and changed to use -index by default -0.2.3, 12 Apr 2014 -- fixed index reducer, cache expiration -0.2.5, 21 Apr 2014 -- fixed cache compatibility -0.2.7, 10 May 2014 -- fixed install dependencies - (thanks to Ramiro Gómez for the patch!)
diff --git a/README.md b/README.md index 2c637cd..6ff22d3 100644 --- a/README.md +++ b/README.md @@ -1,15 +1,15 @@ # wbdata -| Branch | Status | -|--------|------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------| -| master | [![master branch status](https://github.com/OliverSherouse/wbdata/workflows/Tests/badge.svg?branch=master)](https://github.com/OliverSherouse/wbdata/actions?query=workflow%3A%22Tests%22+branch%3Amaster) | -| dev | [![dev branch status](https://github.com/OliverSherouse/wbdata/workflows/Tests/badge.svg?branch=dev)](https://github.com/OliverSherouse/wbdata/actions?query=workflow%3A%22Tests%22+branch%3Adev) | +[![Tests](https://github.com/OliverSherouse/wbdata/actions/workflows/tests.yaml/badge.svg?branch=master)](https://github.com/OliverSherouse/wbdata/actions/workflows/tests.yaml) +[![Documentation Status](https://readthedocs.org/projects/wbdata/badge/?version=stable)](https://wbdata.readthedocs.io/en/stable/?badge=stable) +[![PyPI version](https://badge.fury.io/py/wbdata.svg)](https://badge.fury.io/py/wbdata) +[![Downloads](https://static.pepy.tech/badge/wbdata/month)](https://pepy.tech/project/wbdata) Wbdata is a simple python interface to find and request information from the World Bank's various databases, either as a dictionary containing full metadata or as a [pandas](http://pandas.pydata.org) DataFrame or series. Currently, wbdata wraps most of the [World Bank API](http://data.worldbank.org/developers/api-overview), and also adds some -convenience functions for searching and retrieving information. +convenience functionality for searching and retrieving information. Documentation is available at <https://wbdata.readthedocs.io/>. diff --git a/docs/Makefile b/docs/Makefile deleted file mode 100644 index d0c3cbf..0000000 --- a/docs/Makefile +++ /dev/null @@ -1,20 +0,0 @@ -# Minimal makefile for Sphinx documentation -# - -# You can set these variables from the command line, and also -# from the environment for the first two. -SPHINXOPTS ?= -SPHINXBUILD ?= sphinx-build -SOURCEDIR = source -BUILDDIR = build - -# Put it first so that "make" without argument is like "make help". -help: - @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) - -.PHONY: help Makefile - -# Catch-all target: route all unknown targets to Sphinx using the new -# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). -%: Makefile - @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) diff --git a/docs/basic_functionality.md b/docs/basic_functionality.md new file mode 100644 index 0000000..f8db291 --- /dev/null +++ b/docs/basic_functionality.md @@ -0,0 +1,90 @@ +# Basic functionality + +The basic functionality for `wbdata` users is provided by a set of functions in +the top-level package namespace. + +## Data Retrieval + +These are the functions for actually getting data values from the World Bank +API. + +### Raw Data Retrieval + +::: wbdata.client.Client.get_data + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + + +### Pandas Data Retrieval + +These functions require Pandas to be installed to work.
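+A minimal sketch of the pandas workflow (an illustration, not part of the API
+reference; it assumes pandas is installed and uses indicator IDs that appear
+elsewhere in these docs):
+
+``` python
+import wbdata
+
+# One indicator as a pandas Series, indexed by country and date
+series = wbdata.get_series("NY.GDP.PCAP.PP.KD", country=["USA", "GBR"])
+
+# Several indicators as a DataFrame, with dates parsed to datetimes
+indicators = {"NY.GDP.PCAP.PP.KD": "gdppc", "IC.BUS.EASE.XQ": "doing_business"}
+df = wbdata.get_dataframe(indicators, country=["USA", "GBR"], parse_dates=True)
+print(df.describe())
+```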
+ +::: wbdata.client.Client.get_series + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +::: wbdata.client.Client.get_dataframe + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +## Metadata Retrieval + +These functions, for the most part, are for finding the parameters you want to +put into the data retrieval functions. These all return +[SearchResult][wbdata.client.SearchResult], which are lists that pretty-print +as a table in an interactive environment, and which contain dictionary +representations of the requested resource. + +### Searchable Metadata + +There are enough indicators and countries that it's annoying to look through +them, so the functions for retrieving information about them can be narrowed +with additional facets and filtered with a search term or regular expression +supplied to the `query` parameter. + +::: wbdata.client.Client.get_countries + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + + +::: wbdata.client.Client.get_indicators + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +### Indicator Facets + +::: wbdata.client.Client.get_sources + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +::: wbdata.client.Client.get_topics + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +### Country Facets + +::: wbdata.client.Client.get_incomelevels + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 + +::: wbdata.client.Client.get_lendingtypes + options: + show_root_heading: true + show_root_full_path: false + heading_level: 4 diff --git a/docs/source/index.md b/docs/index.md similarity index 66% rename from docs/source/index.md rename to docs/index.md index d479e72..0dc6978 100644 --- a/docs/source/index.md +++ b/docs/index.md @@ -29,17 +29,13 @@ dealt with. Wbdata is available on [PyPI](https://pypi.python.org/pypi/wbdata), which means you can install using pip: -> pip install -U wbdata +``` bash +pip3 install -U wbdata +``` You can also download or get the source from [GitHub](http://github.com/OliverSherouse/wbdata). -## Detailed Documentation - - - [Wbdata library reference](wbdata_library.md) - - [api module](api_module.md) - - [fetcher module](fetcher_module.md) - ## A Typical User Session Let's say we want to find some data for the ease of doing business in some
I might start off by seeing what sources are available and look promising: ``` ipython -In [1]: import wbdata +In [1]: import wbdata + +In [2]: wbdata.get_sources() +Out[2]: -In [2]: wbdata.get_source() -Out[2]: id name ---- -------------------------------------------------------------------- 1 Doing Business @@ -69,7 +66,6 @@ Out[2]: 20 Quarterly Public Sector Debt 22 Quarterly External Debt Statistics SDDS 23 Quarterly External Debt Statistics GDDS - 24 Poverty and Equity 25 Jobs 27 Global Economic Prospects 28 Global Financial Inclusion @@ -80,14 +76,12 @@ Out[2]: 33 G20 Financial Inclusion Indicators 34 Global Partnership for Education 35 Sustainable Energy for All - 36 Statistical Capacity Indicators 37 LAC Equity Lab 38 Subnational Poverty 39 Health Nutrition and Population Statistics by Wealth Quintile 40 Population estimates and projections 41 Country Partnership Strategy for India (FY2013 - 17) 43 Adjusted Net Savings - 44 Readiness for Investment in Sustainable Energy 45 Indonesia Database for Policy and Economic Research 46 Sustainable Development Goals 50 Subnational Population @@ -107,13 +101,22 @@ Out[2]: 69 Global Financial Inclusion and Consumer Protection Survey 70 Economic Fitness 2 71 International Comparison Program (ICP) 2005 - 72 PEFA_Test 73 Global Financial Inclusion and Consumer Protection Survey (Internal) 75 Environment, Social and Governance (ESG) Data 76 Remittance Prices Worldwide (Sending Countries) 77 Remittance Prices Worldwide (Receiving Countries) 78 ICP 2017 79 PEFA_GRPFM + 80 Gender Disaggregated Labor Database (GDLD) + 81 International Debt Statistics: DSSI + 82 Global Public Procurement + 83 Statistical Performance Indicators (SPI) + 84 Education Policy + 85 PEFA_2021_SNG + 86 Global Jobs Indicators Database (JOIN) + 87 Country Climate and Development Report (CCDR) + 88 Food Prices for Nutrition + 89 Identification for Development (ID4D) Data ``` @@ -122,8 +125,8 @@ we've got available to us there. ``` ipython -In [3]: wbdata.get_indicator(source=1) -Out[3]: +In [3]: wbdata.get_indicators(source=1) +Out[3]: id name ------------------------------------------------- --------------------------------------------------------------------------------------------------------------- ENF.CONT.COEN.ATDR Enforcing contracts: Alternative dispute resolution (0-3) (DB16-20 methodology) @@ -168,13 +171,13 @@ developing a question and go for the most general measure, which is the "Ease of Doing Business Index" with the id "IC.BUS.EASE.XQ". Now remember, we're only interested in high-income countries right now, because -we're elitist. So let's use one of the convenience search functions to figure -out the code for the United States so we don't have to wait for data from a -bunch of other countries: +we're elitist. 
So let's use the query parameter of the `get_countries` +function to figure out the code for the United States so we don't have to wait +for data from a bunch of other countries: ``` ipython -In [4]: wbdata.search_countries('united') -Out[4]: +In [4]: wbdata.get_countries(query='united') +Out[4]: id name ---- -------------------- ARE United Arab Emirates @@ -187,34 +190,52 @@ data: ``` ipython In [5]: wbdata.get_data("IC.BUS.EASE.XQ", country="USA") -Out[5]: +Out[5]: [{'indicator': {'id': 'IC.BUS.EASE.XQ', - 'value': 'Ease of doing business index (1=most business-friendly regulations)'}, + 'value': 'Ease of doing business rank (1=most business-friendly regulations)'}, 'country': {'id': 'US', 'value': 'United States'}, 'countryiso3code': 'USA', - 'date': '2019', - 'value': 6, + 'date': '2022', + 'value': None, 'unit': '', 'obs_status': '', 'decimal': 0}, {'indicator': {'id': 'IC.BUS.EASE.XQ', - 'value': 'Ease of doing business index (1=most business-friendly regulations)'}, + 'value': 'Ease of doing business rank (1=most business-friendly regulations)'}, 'country': {'id': 'US', 'value': 'United States'}, 'countryiso3code': 'USA', - 'date': '2018', + 'date': '2021', 'value': None, 'unit': '', 'obs_status': '', 'decimal': 0}, {'indicator': {'id': 'IC.BUS.EASE.XQ', - 'value': 'Ease of doing business index (1=most business-friendly regulations)'}, + 'value': 'Ease of doing business rank (1=most business-friendly regulations)'}, 'country': {'id': 'US', 'value': 'United States'}, 'countryiso3code': 'USA', - 'date': '2017', + 'date': '2020', 'value': None, 'unit': '', 'obs_status': '', 'decimal': 0}, + {'indicator': {'id': 'IC.BUS.EASE.XQ', + 'value': 'Ease of doing business rank (1=most business-friendly regulations)'}, + 'country': {'id': 'US', 'value': 'United States'}, + 'countryiso3code': 'USA', + 'date': '2019', + 'value': 6, + 'unit': '', + 'obs_status': '', + 'decimal': 0}, + {'indicator': {'id': 'IC.BUS.EASE.XQ', + 'value': 'Ease of doing business rank (1=most business-friendly regulations)'}, + 'country': {'id': 'US', 'value': 'United States'}, + 'countryiso3code': 'USA', + 'date': '2018', + 'value': None, + 'unit': '', + 'obs_status': '', + 'decimal': 0}] [And so on] ``` @@ -226,12 +247,8 @@ can actually search using multiple countries and restrict the dates using datetime objects. Here's what that would look like: ``` ipython -In [6]: import datetime - -In [7]: data_date = datetime.datetime(2010, 1, 1), datetime.datetime(2011, 1, 1) - -In [8]: wbdata.get_data("IC.BUS.EASE.XQ", country=["USA", "GBR"], data_date=data_date) -Out[8]: +In [6]: wbdata.get_data("IC.BUS.EASE.XQ", country=["USA", "GBR"], date=("2010", "2011")) +Out[6]: [{'indicator': {'id': 'IC.BUS.EASE.XQ', 'value': 'Ease of doing business index (1=most business-friendly regulations)'}, 'country': {'id': 'GB', 'value': 'United Kingdom'}, @@ -275,47 +292,34 @@ And we get another list of dictionaries, which we can parse any which way we please. So let's get a little bit more analytic. Let's say we want to fetch this same -indicator, but also GDP per capita and for all high-income countries. Let's find -the other indicator we want using another convenience search function: +indicator, but also GDP per capita and for all high-income countries. Let's +find the other indicator using the query parameter to search, and limiting +ourselves to indicators from source 2, the World Development Indicators. 
``` ipython -In [9]: wbdata.search_indicators("gdp per capita") -Out[9]: -id name --------------------------- ---------------------------------------------------------------------------------------- -6.0.GDPpc_constant GDP per capita, PPP (constant 2011 international $) -FB.DPT.INSU.PC.ZS Deposit insurance coverage (% of GDP per capita) -NV.AGR.PCAP.KD.ZG Real agricultural GDP per capita growth rate (%) -NY.GDP.PCAP.CD GDP per capita (current US$) -NY.GDP.PCAP.CN GDP per capita (current LCU) -NY.GDP.PCAP.KD GDP per capita (constant 2010 US$) -NY.GDP.PCAP.KD.ZG GDP per capita growth (annual %) -NY.GDP.PCAP.KN GDP per capita (constant LCU) -NY.GDP.PCAP.PP.CD GDP per capita, PPP (current international $) -NY.GDP.PCAP.PP.KD GDP per capita, PPP (constant 2017 international $) -NY.GDP.PCAP.PP.KD.87 GDP per capita, PPP (constant 1987 international $) -NY.GDP.PCAP.PP.KD.ZG GDP per capita, PPP annual growth (%) -SE.XPD.PRIM.PC.ZS Government expenditure per student, primary (% of GDP per capita) -SE.XPD.SECO.PC.ZS Government expenditure per student, secondary (% of GDP per capita) -SE.XPD.TERT.PC.ZS Government expenditure per student, tertiary (% of GDP per capita) -UIS.XUNIT.GDPCAP.02.FSGOV Initial government funding per pre-primary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.1.FSGOV Initial government funding per primary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.1.FSHH Initial household funding per primary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.2.FSGOV Initial government funding per lower secondary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.23.FSGOV Initial government funding per secondary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.23.FSHH Initial household funding per secondary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.3.FSGOV Initial government funding per upper secondary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.5T8.FSGOV Initial government funding per tertiary student as a percentage of GDP per capita -UIS.XUNIT.GDPCAP.5T8.FSHH Initial household funding per tertiary student as a percentage of GDP per capita +In [7]: wbdata.get_indicators(query="gdp per capita", source=2) +Out[7]: +id name +----------------- ------------------------------------------------------------------- +NY.GDP.PCAP.CD GDP per capita (current US$) +NY.GDP.PCAP.CN GDP per capita (current LCU) +NY.GDP.PCAP.KD GDP per capita (constant 2015 US$) +NY.GDP.PCAP.KD.ZG GDP per capita growth (annual %) +NY.GDP.PCAP.KN GDP per capita (constant LCU) +NY.GDP.PCAP.PP.CD GDP per capita, PPP (current international $) +NY.GDP.PCAP.PP.KD GDP per capita, PPP (constant 2017 international $) +SE.XPD.PRIM.PC.ZS Government expenditure per student, primary (% of GDP per capita) +SE.XPD.SECO.PC.ZS Government expenditure per student, secondary (% of GDP per capita) +SE.XPD.TERT.PC.ZS Government expenditure per student, tertiary (% of GDP per capita) ``` Like good economists, we'll use the one that seems most impressive: GDP per -capita at PPP in constant 2005 dollars, which has the id "NY.GDP.PCAP.PP.KD". +capita at PPP in constant 2017 dollars, which has the id "NY.GDP.PCAP.PP.KD". But what about using high-income countries? ``` ipython -In [10]: wbdata.get_incomelevel() -Out[10]: +In [8]: wbdata.get_incomelevels() +Out[8]: id value ---- ------------------- HIC High income @@ -332,40 +336,37 @@ DataFrame, suitable for analysis with that library, statsmodels, or whatever else we'd like. 
``` ipython -In [11]: countries = [i['id'] for i in wbdata.get_country(incomelevel='HIC')] +In [9]: countries = [i['id'] for i in wbdata.get_countries(incomelevel='HIC')] -In [12]: indicators = {"IC.BUS.EASE.XQ": "doing_business", "NY.GDP.PCAP.PP.KD": "gdppc"} +In [10]: indicators = {"IC.BUS.EASE.XQ": "doing_business", "NY.GDP.PCAP.PP.KD": "gdppc"} -In [13]: df = wbdata.get_dataframe(indicators, country=countries, convert_date=True) +In [11]: df = wbdata.get_dataframe(indicators, country=countries, parse_dates=True) -In [14]: df.describe() -Out[14]: +In [12]: df.describe() +Out[12]: doing_business gdppc -count 57.000000 1713.000000 -mean 49.561404 39660.815372 -std 37.568042 21052.599082 -min 1.000000 9492.153507 -25% 20.000000 25522.078578 -50% 41.000000 35889.316248 -75% 72.000000 48233.136498 -max 145.000000 161938.749262 +count 58.000000 2040.000000 +mean 49.534483 40524.953859 +std 36.754384 21654.925324 +min 1.000000 4217.814643 +25% 20.500000 25921.663744 +50% 41.500000 37138.923235 +75% 71.000000 49218.423295 +max 139.000000 157600.647353 + ``` -The `doing_business` variable is only available for 2018, and `gdppc` is only -available for prior years, so let's take the latest observation of each to get -the correlation. +Now we can look at the correlation: ``` ipython -In [15]: df = wbdata.get_dataframe(indicators, country=countries, convert_date=True) +In [13]: df.sort_index().groupby('country').last().corr() +Out[13]: -In [16]: df.sort_index().groupby('country').last().corr() -Out[16]: doing_business gdppc -doing_business 1.000000 -0.393077 -gdppc -0.393077 1.000000 +doing_business 1.000000 -0.407761 +gdppc -0.407761 1.000000 ``` And, since lower scores on that indicator mean more business-friendly -regulations, that's exactly what we would expect. It goes without saying that we -can use our data now to do any other analysis required. +regulations, that's exactly what we would expect. Hooray! diff --git a/docs/make.bat b/docs/make.bat deleted file mode 100644 index 6247f7e..0000000 --- a/docs/make.bat +++ /dev/null @@ -1,35 +0,0 @@ -@ECHO OFF - -pushd %~dp0 - -REM Command file for Sphinx documentation - -if "%SPHINXBUILD%" == "" ( - set SPHINXBUILD=sphinx-build -) -set SOURCEDIR=source -set BUILDDIR=build - -if "%1" == "" goto help - -%SPHINXBUILD% >NUL 2>NUL -if errorlevel 9009 ( - echo. - echo.The 'sphinx-build' command was not found. Make sure you have Sphinx - echo.installed, then set the SPHINXBUILD environment variable to point - echo.to the full path of the 'sphinx-build' executable. Alternatively you - echo.may add the Sphinx directory to PATH. - echo. 
- echo.If you don't have Sphinx installed, grab it from - echo.http://sphinx-doc.org/ - exit /b 1 -) - -%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% -goto end - -:help -%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O% - -:end -popd diff --git a/docs/reference/cache.md b/docs/reference/cache.md new file mode 100644 index 0000000..131c50e --- /dev/null +++ b/docs/reference/cache.md @@ -0,0 +1,2 @@ +# Cache Module +::: wbdata.cache diff --git a/docs/reference/client.md b/docs/reference/client.md new file mode 100644 index 0000000..392f565 --- /dev/null +++ b/docs/reference/client.md @@ -0,0 +1,2 @@ +# Client Module +::: wbdata.client diff --git a/docs/reference/dates.md b/docs/reference/dates.md new file mode 100644 index 0000000..c7d20e8 --- /dev/null +++ b/docs/reference/dates.md @@ -0,0 +1,3 @@ +# Dates module + +::: wbdata.dates diff --git a/docs/reference/fetcher.md b/docs/reference/fetcher.md new file mode 100644 index 0000000..98d4fcd --- /dev/null +++ b/docs/reference/fetcher.md @@ -0,0 +1,3 @@ +# Fetcher Module + +::: wbdata.fetcher diff --git a/docs/source/api_module.md b/docs/source/api_module.md deleted file mode 100644 index ef82567..0000000 --- a/docs/source/api_module.md +++ /dev/null @@ -1,11 +0,0 @@ -# wbdata.api - -The api module is where all the nuts-and-bolts functions are defined, as well as -those which are imported into the main namespace. - -## Reference - -``` eval_rst -.. automodule:: wbdata.api - :members: -``` diff --git a/docs/source/conf.py b/docs/source/conf.py deleted file mode 100644 index 9917d85..0000000 --- a/docs/source/conf.py +++ /dev/null @@ -1,82 +0,0 @@ -# Configuration file for the Sphinx documentation builder. -# -# This file only contains a selection of the most common options. For a full -# list see the documentation: -# https://www.sphinx-doc.org/en/master/usage/configuration.html - -# -- Path setup -------------------------------------------------------------- - -# If extensions (or modules to document with autodoc) are in another directory, -# add these directories to sys.path here. If the directory is relative to the -# documentation root, use os.path.abspath to make it absolute, like shown here. -# -# import os -# import sys -# sys.path.insert(0, os.path.abspath('.')) - -from recommonmark.transform import AutoStructify - - -# -- Project information ----------------------------------------------------- - -project = "wbdata" -copyright = "2012-2020, Oliver Sherouse" -author = "Oliver Sherouse" - -# The full version, including alpha/beta/rc tags -release = "0.3.0" - - -# -- General configuration --------------------------------------------------- - -# Add any Sphinx extension module names here, as strings. They can be -# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom -# ones. -extensions = [ - "sphinx.ext.autodoc", - "recommonmark", - "IPython.sphinxext.ipython_console_highlighting", -] - -# Add any paths that contain templates here, relative to this directory. -templates_path = ["_templates"] - - -source_suffix = [".rst", ".md"] - -# List of patterns, relative to source directory, that match files and -# directories to ignore when looking for source files. -# This pattern also affects html_static_path and html_extra_path. -exclude_patterns = [] - - -# -- Options for HTML output ------------------------------------------------- - -# The theme to use for HTML and HTML Help pages. See the documentation for -# a list of builtin themes. 
-# -html_theme = "alabaster" - -html_sidebars = { - "*": ["about.html", "navigation.html", "relations.html", "searchbox.html"], -} - -# Add any paths that contain custom static files (such as style sheets) here, -# relative to this directory. They are copied after the builtin static files, -# so a file named "default.css" will overwrite the builtin "default.css". -html_static_path = ["_static"] - - -github_doc_root = "https://github.com/OliverSherouse/wbdata/tree/master/docs/" - - -def setup(app): - app.add_config_value( - "recommonmark_config", - { - # "url_resolver": lambda url: github_doc_root + url, - "auto_toc_tree_section": "Detailed Documentation", - }, - True, - ) - app.add_transform(AutoStructify) diff --git a/docs/source/fetcher_module.md b/docs/source/fetcher_module.md deleted file mode 100644 index e54fc90..0000000 --- a/docs/source/fetcher_module.md +++ /dev/null @@ -1,12 +0,0 @@ -# wbdata.fetcher - -Fetcher is the mechanism used by wbdata for reading, paging, caching, and -converting responses from the World Bank. Luckily, most users will have no need -to deal with this module, which is why it is separate. - -## Reference - -``` eval_rst -.. automodule:: wbdata.fetcher - :members: -``` diff --git a/docs/source/wbdata_library.md b/docs/source/wbdata_library.md deleted file mode 100644 index e3441fd..0000000 --- a/docs/source/wbdata_library.md +++ /dev/null @@ -1,36 +0,0 @@ -# wbdata library reference - -Wbdata provides a set of functions that are used to interface with the World -Bank's databases. For any function involving pandas capabilities, pandas must -(obviously) be installed. - -## Finding the data you want - -``` eval_rst - -.. autofunction:: wbdata.search_indicators -.. autofunction:: wbdata.search_countries -.. autofunction:: wbdata.get_source -.. autofunction:: wbdata.get_topic -.. autofunction:: wbdata.get_lendingtype -.. autofunction:: wbdata.get_incomelevel -.. autofunction:: wbdata.get_country -.. autofunction:: wbdata.get_indicator -``` - -## Retrieving your data - -### With JSON - -``` eval_rst -.. autofunction:: wbdata.get_data -``` - -### With Pandas - -``` eval_rst - -.. autofunction:: wbdata.get_series -.. autofunction:: wbdata.get_dataframe - -``` diff --git a/docs/whats_new.md b/docs/whats_new.md new file mode 100644 index 0000000..3dbd64a --- /dev/null +++ b/docs/whats_new.md @@ -0,0 +1,34 @@ +# What's New + +## What's new in wbdata 1.0 + +The 1.0 release of `wbdata` is not *quite* a full rewrite, but is pretty much the next best thing. The architecture has been reworked, function and argument names have been changed to be more consistent and clear, and a few dependencies have been added for better and more reliable functionality. + + +### Features + +* Date arguments can now be strings, not just `datetime.datetime` objects. Strings can be in the year, month, or quarter formats used by the World Bank API or in any other format that can be handled by [dateparser][https://dateparser.readthedocs.io/en/latest/]. +* Default cache behavior can be configured with environment variables, including the path, TTL, and max number of items to cache. See [Cache Module documentation](reference/cache.md) for details. +* Users can now create `Client` objects if they want to set cache behavior programmatically have multiple caches, or supply their own requests Session. +* Caching is now provided using the [shelved_cache](https://github.com/mariushelf/shelved_cache) and [cachetools](https://github.com/tkem/cachetools/) libraries. 
+ +### Breaking API Changes + +* Supported versions of Python are now 3.8+. +* All of the metadata retrieval functions have been renamed to their plural forms to reflect the fact that they always return a sequence: + + | Old Name | New Name | + |-------------------|--------------------| + | `get_country` | `get_countries` | + | `get_indicator` | `get_indicators` | + | `get_incomelevel` | `get_incomelevels` | + | `get_lendingtype` | `get_lendingtypes` | + | `get_topic` | `get_topics` | + | `get_source` | `get_sources` | + +* The functions `search_countries` and `search_indicators` have been removed. Searching by name is now available using the `query` parameter of the `get_countries` and `get_indicators` functions. +* The parameter `data_date` has been renamed `date`. +* The parameter `convert_dates` has been renamed `parse_dates`. +* The parameter `cache` with a default value of `True` has been replaced with a parameter `skip_cache` with a default value of `False`. + diff --git a/mkdocs.yml b/mkdocs.yml new file mode 100644 index 0000000..5832bc0 --- /dev/null +++ b/mkdocs.yml @@ -0,0 +1,15 @@ +site_name: WBData +theme: readthedocs +plugins: + - mkdocstrings: + default_handler: python + handlers: + python: + options: + show_source: false + members_order: source + docstrings_options: + returns_named_value: false + returns_multiple_items: false +watch: + - wbdata diff --git a/pyproject.toml b/pyproject.toml index 3a74b47..82bb214 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,46 +1,79 @@ [tool.poetry] name = "wbdata" -version = "0.3.0.post" +version = "1.0.0" description = "A library to access World Bank data" authors = ["Oliver Sherouse "] license = "GPL-2.0+" readme = "README.md" classifiers = [ - "Development Status :: 4 - Beta", + "Development Status :: 5 - Production/Stable", "Intended Audience :: Science/Research", "Operating System :: OS Independent", "Topic :: Scientific/Engineering", + "Typing :: Typed" ] - repository = "https://github.com/OliverSherouse/wbdata" documentation = "https://wbdata.readthedocs.io/" keywords = ["World Bank", "data", "economics"] [tool.poetry.dependencies] -python = ">=3.6" -decorator = ">=4.0" -requests = ">=2.0" -tabulate = ">=0.8.5" -appdirs = ">=1.4" +python = "^3.8" +requests = "^2.0" +tabulate = "^0.8.5" +appdirs = "^1.4" + +pandas = {version = ">=1,<3", optional=true} +cachetools = "^5.3.2" +shelved-cache = "^0.3.1" +backoff = "^2.2.1" +dateparser = "^1.2.0" +decorator = "^5.1.1" + +mkdocs = {version = "^1.5.3", optional=true} +mkdocstrings = {extras = ["python"], version = "^0.24.0", optional=true} + +[tool.poetry.group.dev.dependencies] +ruff = "^0.1.11" +ipython = "<8" -pandas = {version = ">=0.17", optional=true} -sphinx = {version = "^3.0.3", optional=true} -recommonmark = {version = "^0.6.0", optional=true} -ipython = {version = "^7.16.1", optional=true} +[tool.poetry.group.tests.dependencies] +pytest = "^7.4.4" +pytest-cov = "^4.1.0" + +[tool.poetry.group.types.dependencies] +mypy = "^1.8.0" +types-cachetools = "^5.3.0.7" +types-tabulate = "^0.9.0.20240106" +types-decorator = "^5.1.8.20240106" +types-appdirs = "^1.4.3.5" +types-requests = "^2.31.0.20240106" +types-dateparser = "^1.1.4.20240106" [tool.poetry.extras] pandas = ["pandas"] -docs = ["sphinx", "recommonmark", "ipython"] +docs = ["mkdocs", "mkdocstrings"] + +[tool.ruff.lint] +select = [ +
# pycodestyle + "E", + # Pyflakes + "F", + # pyupgrade + "UP", + # flake8-bugbear + "B", + # flake8-simplify + "SIM", + # isort + "I", +] +ignore-init-module-imports = true -[tool.poetry.dev-dependencies] -pytest-flake8 = "^=1.0.6" -ipython = "^=7.16.1" -flake8-bugbear = "^20.1.4" -sphinx = "^3.0.3" -recommonmark = "^0.6.0" -flake8 = "^3.8.3" -pytest = "^5.4.3" +[tool.pytest.ini_options] +minversion = "6.0" +addopts = "--cov=wbdata" [build-system] -requires = ["poetry>=0.12"] -build-backend = "poetry.masonry.api" +requires = ["poetry-core"] +build-backend = "poetry.core.masonry.api" diff --git a/setup.cfg b/setup.cfg deleted file mode 100644 index 34d7ae2..0000000 --- a/setup.cfg +++ /dev/null @@ -1,7 +0,0 @@ -[tool:pytest] -addopts = --flake8 - -[flake8] -max-line-length=80 -select = C,E,F,W,B,B9 -ignore = E501,W391,W503 diff --git a/tests/test_api.py b/tests/test_api.py deleted file mode 100644 index d8e31cf..0000000 --- a/tests/test_api.py +++ /dev/null @@ -1,683 +0,0 @@ -#!/usr/bin/env python3 - -import collections -import datetime as dt -import itertools - -import pytest -import wbdata as wbd - - -SimpleCallDefinition = collections.namedtuple( - "SimpleCallDefinition", ["function", "valid_id", "value"] -) - -SimpleCallSpec = collections.namedtuple( - "SimpleCallSpec", ["function", "result_all", "result_one", "id", "value"] -) - -SIMPLE_CALL_DEFINITIONS = [ - SimpleCallDefinition( - function=wbd.get_country, - valid_id="USA", # USA! USA! - value={ - "id": "USA", - "iso2Code": "US", - "name": "United States", - "region": {"id": "NAC", "iso2code": "XU", "value": "North America"}, - "adminregion": {"id": "", "iso2code": "", "value": ""}, - "incomeLevel": {"id": "HIC", "iso2code": "XD", "value": "High income"}, - "lendingType": {"id": "LNX", "iso2code": "XX", "value": "Not classified"}, - "capitalCity": "Washington D.C.", - "longitude": "-77.032", - "latitude": "38.8895", - }, - ), - SimpleCallDefinition( - function=wbd.get_incomelevel, - valid_id="HIC", - value={"id": "HIC", "iso2code": "XD", "value": "High income"}, - ), - SimpleCallDefinition( - function=wbd.get_lendingtype, - valid_id="IBD", - value={"id": "IBD", "iso2code": "XF", "value": "IBRD"}, - ), - SimpleCallDefinition( - function=wbd.get_source, - valid_id="2", - value={ - "id": "2", - "name": "World Development Indicators", - "code": "WDI", - "description": "", - "url": "", - "dataavailability": "Y", - "metadataavailability": "Y", - "concepts": "3", - }, - ), - SimpleCallDefinition( - function=wbd.get_topic, - valid_id="3", - value={ - "id": "3", - "value": "Economy & Growth", - "sourceNote": ( - "Economic growth is central to economic development. When " - "national income grows, real people benefit. While there is " - "no known formula for stimulating economic growth, data can " - "help policy-makers better understand their countries' " - "economic situations and guide any work toward improvement. " - "Data here covers measures of economic growth, such as gross " - "domestic product (GDP) and gross national income (GNI). It " - "also includes indicators representing factors known to be " - "relevant to economic growth, such as capital stock, " - "employment, investment, savings, consumption, government " - "spending, imports, and exports." 
- ), - }, - ), - SimpleCallDefinition( - function=wbd.get_indicator, - valid_id="SP.POP.TOTL", - value={ - "id": "SP.POP.TOTL", - "name": "Population, total", - "unit": "", - "source": {"id": "2", "value": "World Development Indicators"}, - "sourceNote": ( - "Total population is based on the de facto definition of " - "population, which counts all residents regardless of legal " - "status or citizenship. The values shown are midyear " - "estimates." - ), - "sourceOrganization": ( - "(1) United Nations Population Division. World Population " - "Prospects: 2019 Revision. (2) Census reports and other " - "statistical publications from national statistical offices, " - "(3) Eurostat: Demographic Statistics, (4) United Nations " - "Statistical Division. Population and Vital Statistics " - "Reprot (various years), (5) U.S. Census Bureau: " - "International Database, and (6) Secretariat of the Pacific " - "Community: Statistics and Demography Programme." - ), - "topics": [ - {"id": "19", "value": "Climate Change"}, - {"id": "8", "value": "Health "}, - ], - }, - ), -] - - -@pytest.fixture(params=SIMPLE_CALL_DEFINITIONS, scope="class") -def simple_call_spec(request): - return SimpleCallSpec( - function=request.param.function, - result_all=request.param.function(), - result_one=request.param.function(request.param.valid_id), - id=request.param.valid_id, - value=request.param.value, - ) - - -class TestSimpleQueries: - """ - Test that results of simple queries are close to what we expect - """ - - def test_simple_all_type(self, simple_call_spec): - assert isinstance(simple_call_spec.result_all, wbd.api.WBSearchResult) - - def test_simple_all_len(self, simple_call_spec): - assert len(simple_call_spec.result_all) > 1 - - def test_simple_all_content(self, simple_call_spec): - expected = [] - for val in simple_call_spec.result_all: - try: - del val["lastupdated"] - except KeyError: - pass - expected.append(val) - assert simple_call_spec.value in expected - - def test_simple_one_type(self, simple_call_spec): - assert isinstance(simple_call_spec.result_one, wbd.api.WBSearchResult) - - def test_simple_one_len(self, simple_call_spec): - assert len(simple_call_spec.result_one) == 1 - - def test_simple_one_content(self, simple_call_spec): - got = simple_call_spec.result_one[0] - try: - del got["lastupdated"] - except KeyError: - pass - assert simple_call_spec.result_one[0] == simple_call_spec.value - - def test_simple_bad_call(self, simple_call_spec): - with pytest.raises(RuntimeError): - simple_call_spec.function("Ain'tNotAThing") - - -class TestGetIndicator: - """Extra tests for Get Indicator""" - - def testGetIndicatorBySource(self): - indicators = wbd.get_indicator(source=1) - assert all(i["source"]["id"] == "1" for i in indicators) - - def testGetIndicatorByTopic(self): - indicators = wbd.get_indicator(topic=1) - assert all(any(t["id"] == "1" for t in i["topics"]) for i in indicators) - - def testGetIndicatorBySourceAndTopicFails(self): - with pytest.raises(ValueError): - wbd.get_indicator(source="1", topic=1) - - -SearchDefinition = collections.namedtuple( - "SearchDefinition", - ["function", "query", "value", "facets", "facet_matches", "facet_mismatches"], -) - -SearchData = collections.namedtuple( - "SearchData", - [ - "function", - "query", - "value", - "facets", - "results", - "results_facet_matches", - "results_facet_mismatches", - ], -) - -search_definitions = [ - SearchDefinition( - function=wbd.search_countries, - query="United", - value={ - "id": "USA", - "iso2Code": "US", - "name": "United 
States", - "region": {"id": "NAC", "iso2code": "XU", "value": "North America"}, - "adminregion": {"id": "", "iso2code": "", "value": ""}, - "incomeLevel": {"id": "HIC", "iso2code": "XD", "value": "High income"}, - "lendingType": {"id": "LNX", "iso2code": "XX", "value": "Not classified"}, - "capitalCity": "Washington D.C.", - "longitude": "-77.032", - "latitude": "38.8895", - }, - facets=["incomelevel", "lendingtype"], - facet_matches=["HIC", "LNX"], - facet_mismatches=["LIC", "IDX"], - ) -] - - -GetDataSpec = collections.namedtuple( - "GetDataSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "expected_country", - "expected_date", - "expected_value", - ], -) - - -@pytest.fixture(params=search_definitions, scope="class") -def search_data(request): - return SearchData( - function=request.param.function, - query=request.param.query, - value=request.param.value, - facets=request.param.facets, - results=request.param.function(request.param.query), - results_facet_matches=[ - request.param.function(request.param.query, **{facet: value}) - for facet, value in zip(request.param.facets, request.param.facet_matches) - ], - results_facet_mismatches=[ - request.param.function(request.param.query, **{facet: value}) - for facet, value in zip( - request.param.facets, request.param.facet_mismatches - ) - ], - ) - - -class TestSearchFunctions: - def test_search_return_type(self, search_data): - assert isinstance(search_data.results, wbd.api.WBSearchResult) - - def test_facet_return_type(self, search_data): - for results in ( - search_data.results_facet_matches + search_data.results_facet_mismatches - ): - assert isinstance(results, wbd.api.WBSearchResult) - - def test_plain_search(self, search_data): - assert search_data.value in search_data.results - - def test_matched_faceted_searches(self, search_data): - for results in search_data.results_facet_matches: - assert search_data.value in results - - def test_mismatched_faceted_searches(self, search_data): - for results in search_data.results_facet_mismatches: - assert search_data.value not in results - - -COUNTRY_NAMES = { - "ERI": "Eritrea", - "GNQ": "Equatorial Guinea", -} - -common_data_facets = [ - ["all", "ERI", ["ERI", "GNQ"]], - [ - None, - dt.datetime(2010, 1, 1), - [dt.datetime(2010, 1, 1), dt.datetime(2011, 1, 1)], - ], - [None, "2", "11"], - [False, True], -] -get_data_defs = itertools.product(*common_data_facets) - - -@pytest.fixture(params=get_data_defs, scope="class") -def get_data_spec(request): - country, data_date, source, convert_date = request.param - return GetDataSpec( - result=wbd.get_data( - "NY.GDP.MKTP.CD", - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - ) - - -class TestGetData: - def test_result_type(self, get_data_spec): - assert isinstance(get_data_spec.result, wbd.fetcher.WBResults) - - def test_country(self, get_data_spec): - if get_data_spec.country == "all": - return - expected = ( - {get_data_spec.country} - if isinstance(get_data_spec.country, str) - else set(get_data_spec.country) - ) - # This is a little complicated because the API returns the iso3 id - # in different places from different sources (which is insane) - got = set( - [ - i["countryiso3code"] if i["countryiso3code"] 
else i["country"]["id"] - for i in get_data_spec.result - ] - ) - try: - assert got == expected - except AssertionError: - raise - - # Tests both string and converted dates - def test_data_date(self, get_data_spec): - if get_data_spec.data_date is None: - return - expected = ( - set(get_data_spec.data_date) - if isinstance(get_data_spec.data_date, collections.Sequence) - else {get_data_spec.data_date} - ) - if not get_data_spec.convert_date: - expected = {date.strftime("%Y") for date in expected} - - got = {i["date"] for i in get_data_spec.result} - assert got == expected - - # Tests source and correct value - def test_content(self, get_data_spec): - got = next( - datum["value"] - for datum in get_data_spec.result - if datum["country"]["value"] == get_data_spec.expected_country - and datum["date"] == get_data_spec.expected_date - ) - assert got == get_data_spec.expected_value - - def testLastUpdated(self, get_data_spec): - assert isinstance(get_data_spec.result.last_updated, dt.datetime) - - def test_monthly_freq(self): - got = wbd.get_data( - "DPANUSSPB", country="bra", data_date=dt.datetime(2012, 1, 1), freq="M" - )[0]["value"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_data( - "DP.DOD.DECD.CR.BC.CD", - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )[0]["value"] - assert got == 31049138725.7794 - - -series_data_facets = tuple( - itertools.product(*(common_data_facets + [["value", "other"], [False, True]])) -) - - -GetSeriesSpec = collections.namedtuple( - "GetSeriesSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "column_name", - "keep_levels", - "expected_country", - "expected_date", - "expected_value", - "country_in_index", - "date_in_index", - ], -) - - -@pytest.fixture(params=series_data_facets, scope="class") -def get_series_spec(request): - ( - country, - data_date, - source, - convert_date, - column_name, - keep_levels, - ) = request.param - return GetSeriesSpec( - result=wbd.get_series( - "NY.GDP.MKTP.CD", - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - country_in_index=( - country == "all" or not isinstance(country, str) or keep_levels - ), - date_in_index=(not isinstance(data_date, dt.datetime) or keep_levels), - ) - - -class TestGetSeries: - def test_index_labels(self, get_series_spec): - index = get_series_spec.result.index - if get_series_spec.country_in_index: - if get_series_spec.date_in_index: - assert index.names == ["country", "date"] - else: - assert index.name == "country" - else: - assert index.name == "date" - - def test_country(self, get_series_spec): - if not get_series_spec.country_in_index: - return - got = sorted(get_series_spec.result.index.unique(level="country")) - - if get_series_spec.country == "all": - assert len(got) > 2 - elif isinstance(get_series_spec.country, str): - assert len(got) == 1 - assert got[0] == COUNTRY_NAMES[get_series_spec.country] - else: - assert got == sorted( - COUNTRY_NAMES[country] for country in get_series_spec.country - ) - - def test_date(self, get_series_spec): - if not get_series_spec.date_in_index: - return - got = 
sorted(get_series_spec.result.index.unique(level="date")) - if get_series_spec.data_date is None: - assert len(got) > 2 - elif isinstance(get_series_spec.data_date, collections.Sequence): - assert got == sorted( - date if get_series_spec.convert_date else date.strftime("%Y") - for date in get_series_spec.data_date - ) - else: - assert len(got) == 1 - assert got[0] == ( - get_series_spec.data_date - if get_series_spec.convert_date - else get_series_spec.data_date.strftime("%Y") - ) - - def test_column_name(self, get_series_spec): - assert get_series_spec.result.name == get_series_spec.column_name - - def test_value(self, get_series_spec): - if get_series_spec.country_in_index: - if get_series_spec.date_in_index: - index_loc = ( - get_series_spec.expected_country, - get_series_spec.expected_date, - ) - else: - index_loc = get_series_spec.expected_country - else: - index_loc = get_series_spec.expected_date - - assert get_series_spec.result[index_loc] == get_series_spec.expected_value - - def test_last_updated(self, get_series_spec): - assert isinstance(get_series_spec.result.last_updated, dt.datetime) - - def test_bad_value(self): - with pytest.raises(RuntimeError): - wbd.get_series("AintNotAThing") - - def test_monthly_freq(self): - got = wbd.get_series( - "DPANUSSPB", country="bra", data_date=dt.datetime(2012, 1, 1), freq="M" - )["2012M01"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_series( - "DP.DOD.DECD.CR.BC.CD", - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )["2013Q1"] - assert got == 31049138725.7794 - - -GetDataFrameSpec = collections.namedtuple( - "GetDataFrameSpec", - [ - "result", - "country", - "data_date", - "source", - "convert_date", - "column_names", - "keep_levels", - "expected_country", - "expected_date", - "expected_column", - "expected_value", - "country_in_index", - "date_in_index", - ], -) - - -@pytest.fixture(params=series_data_facets, scope="class") -def get_dataframe_spec(request): - ( - country, - data_date, - source, - convert_date, - column_name, - keep_levels, - ) = request.param - return GetDataFrameSpec( - result=wbd.get_dataframe( - {"NY.GDP.MKTP.CD": column_name, "NY.GDP.MKTP.PP.CD": "ppp"}, - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - keep_levels=keep_levels, - ), - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_names=[column_name, "ppp"], - keep_levels=keep_levels, - expected_country="Eritrea", - expected_date=dt.datetime(2010, 1, 1) if convert_date else "2010", - expected_column=column_name, - expected_value={"2": 2117039512.19512, "11": 2117008130.0813}[source or "2"], - country_in_index=( - country == "all" or not isinstance(country, str) or keep_levels - ), - date_in_index=(not isinstance(data_date, dt.datetime) or keep_levels), - ) - - -class TestGetDataFrame: - def test_index_labels(self, get_dataframe_spec): - index = get_dataframe_spec.result.index - if get_dataframe_spec.country_in_index: - if get_dataframe_spec.date_in_index: - assert index.names == ["country", "date"] - else: - assert index.name == "country" - else: - assert index.name == "date" - - def test_country(self, get_dataframe_spec): - if not get_dataframe_spec.country_in_index: - return - got = sorted(get_dataframe_spec.result.index.unique(level="country")) - - if get_dataframe_spec.country == "all": - assert len(got) > 2 - elif isinstance(get_dataframe_spec.country, str): - assert len(got) == 1 - assert got[0] == 
COUNTRY_NAMES[get_dataframe_spec.country] - else: - assert got == sorted( - COUNTRY_NAMES[country] for country in get_dataframe_spec.country - ) - - def test_date(self, get_dataframe_spec): - if not get_dataframe_spec.date_in_index: - return - got = sorted(get_dataframe_spec.result.index.unique(level="date")) - if get_dataframe_spec.data_date is None: - assert len(got) > 2 - elif isinstance(get_dataframe_spec.data_date, collections.Sequence): - assert got == sorted( - date if get_dataframe_spec.convert_date else date.strftime("%Y") - for date in get_dataframe_spec.data_date - ) - else: - assert len(got) == 1 - assert got[0] == ( - get_dataframe_spec.data_date - if get_dataframe_spec.convert_date - else get_dataframe_spec.data_date.strftime("%Y") - ) - - def test_column_name(self, get_dataframe_spec): - assert ( - get_dataframe_spec.result.columns.tolist() - == get_dataframe_spec.column_names - ) - - def test_value(self, get_dataframe_spec): - if get_dataframe_spec.country_in_index: - if get_dataframe_spec.date_in_index: - index_loc = ( - get_dataframe_spec.expected_country, - get_dataframe_spec.expected_date, - ) - else: - index_loc = get_dataframe_spec.expected_country - else: - index_loc = get_dataframe_spec.expected_date - - assert ( - get_dataframe_spec.result[get_dataframe_spec.expected_column][index_loc] - == get_dataframe_spec.expected_value - ) - - def test_last_updated(self, get_dataframe_spec): - assert all( - isinstance(value, dt.datetime) - for value in get_dataframe_spec.result.last_updated.values() - ) - - def test_bad_value(self): - with pytest.raises(RuntimeError): - wbd.get_dataframe({"AintNotAThing": "bad value"}) - - def test_monthly_freq(self): - got = wbd.get_dataframe( - {"DPANUSSPB": "var"}, - country="bra", - data_date=dt.datetime(2012, 1, 1), - freq="M", - )["var"]["2012M01"] - assert got == 1.78886363636 - - def test_quarterly_freq(self): - got = wbd.get_dataframe( - {"DP.DOD.DECD.CR.BC.CD": "var"}, - country="chl", - data_date=dt.datetime(2013, 1, 1), - freq="Q", - )["var"]["2013Q1"] - assert got == 31049138725.7794 diff --git a/tests/test_client.py b/tests/test_client.py new file mode 100644 index 0000000..9bb278a --- /dev/null +++ b/tests/test_client.py @@ -0,0 +1,809 @@ +import datetime as dt +import itertools +import re +from unittest import mock + +import pandas as pd # type: ignore[import-untyped] +import pytest + +from wbdata import client, fetcher + + +@pytest.mark.parametrize( + ["data", "expected"], + [ + pytest.param( + [{"id": "USA", "name": "United States"}], + "id name\n---- -------------\nUSA United States", + ), + pytest.param( + [{"id": "WB", "value": "World Bank"}], + "id value\n---- ----------\nWB World Bank", + ), + ], +) +def test_search_results_repr(data, expected): + assert repr(client.SearchResult(data)) == expected + + +@pytest.mark.parametrize( + ["arg", "expected"], + [ + pytest.param("foo", "foo", id="string"), + pytest.param(["foo", "bar", "baz"], "foo;bar;baz", id="list of strings"), + pytest.param({1: "a", 2: "b", 3: "c"}, "1;2;3", id="dict of ints"), + pytest.param(5.356, "5.356", id="float"), + ], +) +def test_parse_value_or_iterable(arg, expected): + assert client._parse_value_or_iterable(arg) == expected + + +@pytest.mark.parametrize( + ["value", "expected"], + [ + pytest.param("5.1", 5.1, id="float"), + pytest.param("3", 3.0, id="int"), + pytest.param("heloooo", None, id="non-numeric"), + ], +) +def test_cast_float(value, expected): + assert client._cast_float(value) == expected + + +@pytest.fixture +def mock_client(): + with 
mock.patch("wbdata.client.fetcher.Fetcher", mock.Mock): + yield client.Client() + + +@pytest.mark.parametrize( + [ + "kwargs", + "expected_url", + "expected_args", + ], + [ + pytest.param( + {"indicator": "FOO"}, + "https://api.worldbank.org/v2/countries/all/indicators/FOO", + {}, + id="simple", + ), + pytest.param( + {"indicator": "FOO", "country": "USA"}, + "https://api.worldbank.org/v2/countries/USA/indicators/FOO", + {}, + id="one country", + ), + pytest.param( + {"indicator": "FOO", "country": ["USA", "GBR"]}, + "https://api.worldbank.org/v2/countries/USA;GBR/indicators/FOO", + {}, + id="two countries", + ), + pytest.param( + {"indicator": "FOO", "date": "2005M02"}, + "https://api.worldbank.org/v2/countries/all/indicators/FOO", + {"date": "2005"}, + id="date", + ), + pytest.param( + {"indicator": "FOO", "date": ("2006M02", "2008M10"), "freq": "Q"}, + "https://api.worldbank.org/v2/countries/all/indicators/FOO", + {"date": "2006Q1:2008Q4"}, + id="date and freq", + ), + pytest.param( + {"indicator": "FOO", "source": "1"}, + "https://api.worldbank.org/v2/countries/all/indicators/FOO", + {"source": "1"}, + id="one source", + ), + pytest.param( + {"indicator": "FOO", "skip_cache": True}, + "https://api.worldbank.org/v2/countries/all/indicators/FOO", + {}, + id="skip cache true", + ), + ], +) +def test_get_data_args(mock_client, kwargs, expected_url, expected_args): + mock_client.fetcher.fetch = mock.Mock(return_value="Foo") + mock_client.get_data(**kwargs) + mock_client.fetcher.fetch.assert_called_once_with( + url=expected_url, + params=expected_args, + skip_cache=kwargs.get("skip_cache", False), + ) + + +def test_parse_dates(mock_client): + expected = [{"date": dt.datetime(2023, 4, 1)}] + mock_client.fetcher.fetch = mock.Mock(return_value=[{"date": "2023Q2"}]) + got = mock_client.get_data("foo", parse_dates=True) + assert got == expected + + +@pytest.mark.parametrize( + ["url", "id_", "skip_cache", "expected_url"], + [ + pytest.param("https://foo.bar", None, False, "https://foo.bar", id="no id"), + pytest.param( + "https://foo.bar", "baz", False, "https://foo.bar/baz", id="one id" + ), + pytest.param( + "https://foo.bar", + ["baz", "bat"], + False, + "https://foo.bar/baz;bat", + id="two ids", + ), + pytest.param("https://foo.bar", None, True, "https://foo.bar", id="nocache"), + ], +) +def test_id_only_query(mock_client, url, id_, skip_cache, expected_url): + mock_client.fetcher.fetch = mock.Mock(return_value=["foo"]) + got = mock_client._id_only_query(url, id_, skip_cache=skip_cache) + assert list(got) == ["foo"] + mock_client.fetcher.fetch.assert_called_once_with( + url=expected_url, skip_cache=skip_cache + ) + + +@pytest.mark.parametrize( + ["function", "id_", "skip_cache", "expected_url"], + [ + (function, id_, skip_cache, f"{host}{path}") + for ((function, host), (id_, path), skip_cache) in itertools.product( + ( + ("get_sources", client.SOURCE_URL), + ("get_incomelevels", client.ILEVEL_URL), + ("get_topics", client.TOPIC_URL), + ("get_lendingtypes", client.LTYPE_URL), + ), + ( + (None, ""), + ("foo", "/foo"), + (["foo", "bar"], "/foo;bar"), + ), + (True, False), + ) + ], +) +def test_id_only_functions(mock_client, function, id_, skip_cache, expected_url): + mock_client.fetcher.fetch = mock.Mock(return_value=["foo"]) + got = getattr(mock_client, function)(id_, skip_cache=skip_cache) + assert list(got) == ["foo"] + mock_client.fetcher.fetch.assert_called_once_with( + url=expected_url, skip_cache=skip_cache + ) + + +@pytest.mark.parametrize( + ["country_id", "incomelevel", "lendingtype", 
"path", "args", "skip_cache"], + ( + (cid, il, ltype, path, {**il_args, **ltype_args}, skip_cache) # type: ignore[dict-item] + for ( + (cid, path), + (il, il_args), + (ltype, ltype_args), + skip_cache, + ) in itertools.product( + ( + (None, ""), + ("foo", "/foo"), + (["foo", "bar"], "/foo;bar"), + ), + ( + (None, {}), + (2, {"incomeLevel": "2"}), + ([2, 3], {"incomeLevel": "2;3"}), + ), + ( + (None, {}), + (4, {"lendingType": "4"}), + ([4, 5], {"lendingType": "4;5"}), + ), + (True, False), + ) + if cid is None or (il is None and ltype is None) + ), +) +def test_get_countries( + mock_client, country_id, incomelevel, lendingtype, path, args, skip_cache +): + mock_client.fetcher.fetch = mock.Mock(return_value=["foo"]) + got = mock_client.get_countries( + country_id=country_id, + incomelevel=incomelevel, + lendingtype=lendingtype, + skip_cache=skip_cache, + ) + assert list(got) == ["foo"] + if country_id: + mock_client.fetcher.fetch.assert_called_once_with( + url=f"{client.COUNTRIES_URL}{path}", skip_cache=skip_cache + ) + else: + mock_client.fetcher.fetch.assert_called_once_with( + url=f"{client.COUNTRIES_URL}{path}", params=args, skip_cache=skip_cache + ) + + +@pytest.mark.parametrize( + ["kwargs"], + ( + [{"country_id": "foo", "incomelevel": "bar"}], + [{"country_id": "foo", "lendingtype": "bar"}], + ), +) +def test_get_countries_bad(mock_client, kwargs): + with pytest.raises(ValueError, match=r"country_id and aggregates"): + mock_client.get_countries(**kwargs) + + +@pytest.mark.parametrize( + ("indicator", "source", "topic", "skip_cache", "expected_url"), + ( + ("foo", None, None, True, f"{client.INDICATOR_URL}/foo"), + (["foo", "bar"], None, None, False, f"{client.INDICATOR_URL}/foo;bar"), + (None, "foo", None, False, f"{client.SOURCE_URL}/foo/indicators"), + (None, ["foo", "bar"], None, True, f"{client.SOURCE_URL}/foo;bar/indicators"), + (None, None, "foo", False, f"{client.TOPIC_URL}/foo/indicators"), + (None, None, ["foo", "bar"], True, f"{client.TOPIC_URL}/foo;bar/indicators"), + (None, None, None, True, client.INDICATOR_URL), + ), +) +def test_get_indicators( + mock_client, indicator, source, topic, skip_cache, expected_url +): + mock_client.fetcher.fetch = mock.Mock(return_value=[["foo"]]) + got = mock_client.get_indicators( + indicator=indicator, + source=source, + topic=topic, + skip_cache=skip_cache, + ) + assert list(got) == [["foo"]] + mock_client.fetcher.fetch.assert_called_once_with( + url=expected_url, + skip_cache=skip_cache, + ) + + +@pytest.mark.parametrize( + ("indicator", "source", "topic"), + ( + ("foo", "bar", None), + ("foo", None, "baz"), + ("foo", "bar", "baz"), + (None, "foo", "bar"), + ), +) +def test_get_indicator_indicator_and_facet(mock_client, indicator, source, topic): + with pytest.raises(ValueError, match="Cannot specify"): + mock_client.get_indicators(indicator=indicator, source=source, topic=topic) + + +@pytest.mark.parametrize( + ("raw", "query", "expected"), + ( + ( + [{"name": "United States"}, {"name": "Great Britain"}], + "states", + [{"name": "United States"}], + ), + ( + [{"name": "United States"}, {"name": "Great Britain"}], + re.compile("states"), + [], + ), + ), +) +def test_search_countries(mock_client, raw, query, expected): + mock_client.fetcher.fetch = mock.Mock(return_value=raw) + got = mock_client.get_countries(query=query) + assert list(got) == expected + + +@pytest.mark.parametrize( + ("raw", "query", "expected"), + ( + ( + [{"name": "United States"}, {"name": "Great Britain"}], + "states", + [{"name": "United States"}], + ), + ( + 
[{"name": "United States"}, {"name": "Great Britain"}], + re.compile("states"), + [], + ), + ), +) +def test_search_indicators(mock_client, raw, query, expected): + mock_client.fetcher.fetch = mock.Mock(return_value=raw) + got = mock_client.get_indicators(query=query) + assert list(got) == expected + + +def test_get_series_passthrough(mock_client): + with mock.patch.object( + mock_client, + "get_data", + mock.Mock( + return_value=fetcher.Result( + [{"country": {"value": "usa"}, "date": "2023", "value": "5"}] + ) + ), + ) as mock_get_data: + kwargs = dict( + indicator="foo", + country="usa", + date="2023", + freq="Q", + source="2", + parse_dates=True, + skip_cache=True, + ) + mock_client.get_series(**kwargs) + + mock_get_data.assert_called_once_with(**kwargs) + + +@pytest.mark.parametrize( + ["response", "keep_levels", "expected"], + ( + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + {"country": {"value": "usa"}, "date": "2024", "value": "6"}, + {"country": {"value": "gbr"}, "date": "2023", "value": "7"}, + {"country": {"value": "gbr"}, "date": "2024", "value": "8"}, + ] + ), + True, + client.Series( + [5.0, 6.0, 7.0, 8.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + name="value", + ), + id="multi-country, multi-date", + ), + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + {"country": {"value": "usa"}, "date": "2024", "value": "6"}, + ] + ), + True, + client.Series( + [5.0, 6.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + name="value", + ), + id="one-country, multi-date, keep_levels", + ), + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + {"country": {"value": "gbr"}, "date": "2023", "value": "7"}, + ] + ), + True, + client.Series( + [5.0, 7.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + name="value", + ), + id="multi-country, one-date, keep_levels", + ), + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + {"country": {"value": "usa"}, "date": "2024", "value": "6"}, + ] + ), + False, + client.Series( + [5.0, 6.0], + index=pd.Index(("2023", "2024"), name="date"), + name="value", + ), + id="one-country, multi-date, no keep_levels", + ), + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + {"country": {"value": "gbr"}, "date": "2023", "value": "7"}, + ] + ), + False, + client.Series( + [5.0, 7.0], + index=pd.Index(("usa", "gbr"), name="country"), + name="value", + ), + id="multi-country, one-date, no keep_levels", + ), + pytest.param( + fetcher.Result( + [ + {"country": {"value": "usa"}, "date": "2023", "value": "5"}, + ] + ), + False, + client.Series( + [5.0], + index=pd.Index(("2023",), name="date"), + name="value", + ), + id="one-country, one-date, no keep_levels", + ), + ), +) +def test_get_series(mock_client, response, keep_levels, expected): + with mock.patch.object(mock_client, "get_data", mock.Mock(return_value=response)): + got = mock_client.get_series("foo", keep_levels=keep_levels) + pd.testing.assert_series_equal(got, expected) + + +def test_get_dataframe_passthrough(mock_client): + with mock.patch.object( + mock_client, + "get_series", + 
mock.Mock( + side_effect=[ + client.Series( + [5.0, 6.0, 7.0, 8.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 10.0, 11.0, 12.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + ] + ), + ) as mock_get_series: + kwargs = dict( + country="usa", + date="2023", + freq="Q", + source="2", + parse_dates=True, + keep_levels=True, + skip_cache=True, + ) + + indicators = {"foo": "bar", "baz": "bat"} + mock_client.get_dataframe(indicators=indicators, **kwargs) + for i, indicator in enumerate(indicators): + assert mock_get_series.mock_calls[i].kwargs == { + "indicator": indicator, + **kwargs, + } + + +@pytest.mark.parametrize( + ("results", "indicators", "keep_levels", "expected"), + ( + pytest.param( + [ + client.Series( + [5.0, 6.0, 7.0, 8.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 10.0, 11.0, 12.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + False, + client.DataFrame( + { + "bar": [5.0, 6.0, 7.0, 8.0], + "bat": [9.0, 10.0, 11.0, 12.0], + }, + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + id="matching index, multi-country, multi-year", + ), + pytest.param( + [ + client.Series( + [5.0, 6.0, 7.0, 8.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 10.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + False, + client.DataFrame( + { + "bar": [5.0, 6.0, 7.0, 8.0], + "bat": [9.0, 10.0, None, None], + }, + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ("gbr", "2023"), + ("gbr", "2024"), + ), + names=["country", "date"], + ), + ), + id="overlapping index, multi-country, multi-year", + ), + pytest.param( + [ + client.Series( + [7.0], + index=pd.MultiIndex.from_tuples( + tuples=(("gbr", "2023"),), + names=["country", "date"], + ), + ), + client.Series( + [10.0], + index=pd.MultiIndex.from_tuples( + tuples=(("usa", "2024"),), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + False, + client.DataFrame( + { + "bar": [None, 7.0], + "bat": [10.0, None], + }, + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2024"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + id="disjoint index, multi-country, multi-year", + ), + pytest.param( + [ + client.Series( + [5.0, 6.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 10.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + True, + client.DataFrame( + { + "bar": [5.0, 6.0], + "bat": [9.0, 10.0], + }, + 
index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + id="One country, keep levels", + ), + pytest.param( + [ + client.Series( + [5.0, 7.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 11.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + True, + client.DataFrame( + { + "bar": [5.0, 7.0], + "bat": [9.0, 11.0], + }, + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + id="multi-country, one date, keep levels", + ), + pytest.param( + [ + client.Series( + [5.0, 6.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 10.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("usa", "2024"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + False, + client.DataFrame( + { + "bar": [5.0, 6.0], + "bat": [9.0, 10.0], + }, + index=pd.Index(["2023", "2024"], name="date"), + ), + id="One country, no keep levels", + ), + pytest.param( + [ + client.Series( + [5.0, 7.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + client.Series( + [9.0, 11.0], + index=pd.MultiIndex.from_tuples( + tuples=( + ("usa", "2023"), + ("gbr", "2023"), + ), + names=["country", "date"], + ), + ), + ], + {"foo": "bar", "baz": "bat"}, + False, + client.DataFrame( + { + "bar": [5.0, 7.0], + "bat": [9.0, 11.0], + }, + index=pd.Index(["usa", "gbr"], name="country"), + ), + id="multi-country, one date, no keep levels", + ), + ), +) +def test_get_dataframe(mock_client, results, indicators, keep_levels, expected): + with mock.patch.object(mock_client, "get_series", mock.Mock(side_effect=results)): + got = mock_client.get_dataframe(indicators=indicators, keep_levels=keep_levels) + pd.testing.assert_frame_equal(got.loc[expected.index], expected) diff --git a/tests/test_dates.py b/tests/test_dates.py new file mode 100644 index 0000000..086b56b --- /dev/null +++ b/tests/test_dates.py @@ -0,0 +1,128 @@ +import datetime as dt + +import pytest + +from wbdata import dates + + +def test_parse_year(): + assert dates._parse_year("2003") == dt.datetime(2003, 1, 1) + + +def test_parse_month(): + assert dates._parse_month("2003M12") == dt.datetime(2003, 12, 1) + + +def test_parse_quarter(): + assert dates._parse_quarter("2003Q2") == dt.datetime(2003, 4, 1) + + +@pytest.mark.parametrize( + ["rows", "expected"], + [ + pytest.param( + [{"date": dt.datetime(2003, 1, 1)}, {"date": 5.26}], + [{"date": dt.datetime(2003, 1, 1)}, {"date": 5.26}], + id="not strings", + ), + pytest.param( + [{"date": "2003"}, {"date": "2004"}], + [{"date": dt.datetime(2003, 1, 1)}, {"date": dt.datetime(2004, 1, 1)}], + id="years", + ), + pytest.param( + [{"date": "2003M01"}, {"date": "2003M02"}], + [{"date": dt.datetime(2003, 1, 1)}, {"date": dt.datetime(2003, 2, 1)}], + id="months", + ), + pytest.param( + [{"date": "2003Q1"}, {"date": "2003Q02"}], + [{"date": dt.datetime(2003, 1, 1)}, {"date": dt.datetime(2003, 4, 1)}], + id="quarters", + ), + pytest.param( + [{"date": "MRV"}, {"date": "-"}], + [{"date": "MRV"}, {"date": "-"}], + id="MRV and dash", + ), + pytest.param( + 
[{"date": "2003M1"}, {"date": "MRV"}, {"date": "-"}, {"date": 2003}], + [ + {"date": dt.datetime(2003, 1, 1)}, + {"date": "MRV"}, + {"date": "-"}, + {"date": 2003}, + ], + id="sneaky values", + ), + ], +) +def test_parse_row_dates(rows, expected): + dates.parse_row_dates(rows) + assert rows == expected + + +@pytest.mark.parametrize( + ["date", "freq", "expected"], + [ + pytest.param(dt.datetime(2003, 5, 1), "Y", "2003", id="year"), + pytest.param(dt.datetime(2003, 5, 1), "M", "2003M05", id="month"), + pytest.param(dt.datetime(2003, 5, 1), "Q", "2003Q2", id="quarter"), + ], +) +def test_format_date(date, freq, expected): + assert dates._format_date(date=date, freq=freq) == expected + + +def test_bad_format_date(): + with pytest.raises(ValueError, match=r"Unknown Frequency type"): + dates._format_date(date=dt.datetime(2000, 1, 1), freq="Foobar") + + +@pytest.mark.parametrize( + ["date", "expected"], + [ + pytest.param(dt.datetime(2003, 4, 5), dt.datetime(2003, 4, 5), id="datetime"), + pytest.param("2003", dt.datetime(2003, 1, 1), id="year"), + pytest.param("2003M05", dt.datetime(2003, 5, 1), id="month"), + pytest.param("2003Q2", dt.datetime(2003, 4, 1), id="quarter"), + pytest.param("Feb 3, 2025", dt.datetime(2025, 2, 3), id="dateparser"), + ], +) +def test_parse_date(date, expected): + assert dates._parse_date(date=date) == expected + + +def test_parse_bad_date(): + with pytest.raises(ValueError): + dates._parse_date("Gobbledygook") + + +def test_parse_and_format_date(): + assert dates._parse_and_format_date("Nov 3, 2025", "Q") == "2025Q4" + + +@pytest.mark.parametrize( + ["dates_", "freq", "expected"], + [ + pytest.param(dt.datetime(2003, 4, 5), "Y", "2003", id="single datetime"), + pytest.param("May 2018", "M", "2018M05", id="single string"), + pytest.param( + ( + dt.datetime(2004, 5, 6), + dt.datetime(2005, 6, 7), + ), + "M", + "2004M05:2005M06", + id="two datetimes", + ), + pytest.param( + ("2020M01", "2021M12"), + "Q", + "2020Q1:2021Q4", + id="two strings", + ), + ], +) +def test_format_dates(dates_, freq, expected): + assert dates.format_dates(dates_, freq) == expected diff --git a/tests/test_fetcher.py b/tests/test_fetcher.py index 538d7bb..b19300f 100644 --- a/tests/test_fetcher.py +++ b/tests/test_fetcher.py @@ -1,15 +1,247 @@ +import datetime as dt +import json +from unittest import mock + import pytest -import wbdata.fetcher -import wbdata.api +from wbdata import fetcher + + +@pytest.fixture +def mock_fetcher() -> fetcher.Fetcher: + return fetcher.Fetcher(cache={}, session=mock.Mock()) + + +class MockHTTPResponse: + def __init__(self, value): + self.text = json.dumps(value) + +def test_get_request_content(mock_fetcher): + url = "http://foo.bar" + params = {"baz": "bat"} + expected = {"hello": "there"} + mock_fetcher.session.get = mock.Mock(return_value=MockHTTPResponse(value=expected)) + result = mock_fetcher._get_response_body(url=url, params=params) + mock_fetcher.session.get.assert_called_once_with(url=url, params=params) + assert json.loads(result) == expected -def test_bad_indicator_error(): - expected = ( - r"Got error 120 \(Invalid value\): The provided parameter value is " - r"not valid" + +@pytest.mark.parametrize( + ["url", "params", "response", "expected"], + ( + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [{"page": "1", "pages": "1"}, [{"hello": "there"}]], + fetcher.ParsedResponse( + rows=[{"hello": "there"}], + page=1, + pages=1, + last_updated=None, + ), + id="No date", + ), + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [ + {"page": "1", 
"pages": "1", "lastupdated": "2023-02-01"}, + [{"hello": "there"}], + ], + fetcher.ParsedResponse( + rows=[{"hello": "there"}], + page=1, + pages=1, + last_updated="2023-02-01", + ), + id="with date", + ), + ), +) +def test_get_response(url, params, response, expected, mock_fetcher): + mock_fetcher.session.get = mock.Mock(return_value=MockHTTPResponse(value=response)) + got = mock_fetcher._get_response(url=url, params=params) + mock_fetcher.session.get.assert_called_once_with(url=url, params=params) + assert got == expected + assert mock_fetcher.cache[(url), (("baz", "bat"),)] == json.dumps(response) + + +def test_cache_used(mock_fetcher): + url = "http://foo.bar" + response = [ + {"page": "1", "pages": "1"}, + [{"hello": "there"}], + ] + params = {"baz": "bat"} + expected = fetcher.ParsedResponse( + rows=[{"hello": "there"}], + page=1, + pages=1, + last_updated=None, ) - with pytest.raises(RuntimeError, match=expected): - wbdata.fetcher.fetch( - wbdata.api.COUNTRIES_URL + "/all/AINT.NOT.A.THING" + + mock_fetcher.cache[(url), (("baz", "bat"),)] = json.dumps(response) + mock_fetcher._get_response(url=url, params=params) + got = mock_fetcher._get_response(url=url, params=params) + assert got == expected + assert mock_fetcher.cache[(url), (("baz", "bat"),)] == json.dumps(response) + + +def test_skip_cache(mock_fetcher): + url = "http://foo.bar" + response = [ + {"page": "1", "pages": "1"}, + [{"hello": "there"}], + ] + params = {"baz": "bat"} + expected = fetcher.ParsedResponse( + rows=[{"hello": "there"}], + page=1, + pages=1, + last_updated=None, + ) + mock_fetcher.session.get = mock.Mock(return_value=MockHTTPResponse(value=response)) + mock_fetcher.cache[(url), (("baz", "bat"),)] = json.dumps({"old": "garbage"}) + got = mock_fetcher._get_response(url=url, params=params, skip_cache=True) + mock_fetcher.session.get.assert_called_once_with(url=url, params=params) + assert got == expected + assert mock_fetcher.cache[(url), (("baz", "bat"),)] == json.dumps(response) + + +@pytest.mark.parametrize( + ["url", "params", "responses", "expected"], + ( + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [ + [{"page": "1", "pages": "1"}, [{"hello": "there"}]], + ], + fetcher.Result([{"hello": "there"}], last_updated=None), + id="No date", + ), + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [ + [ + {"page": "1", "pages": "1", "lastupdated": "2023-02-01"}, + [{"hello": "there"}], + ], + ], + fetcher.Result([{"hello": "there"}], last_updated=dt.datetime(2023, 2, 1)), + id="with date", + ), + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [ + [ + {"page": "1", "pages": "2", "lastupdated": "2023-02-01"}, + [{"hello": "there"}], + ], + [ + {"page": "2", "pages": "2", "lastupdated": "2023-02-01"}, + [{"howare": "you"}], + ], + ], + fetcher.Result( + [{"hello": "there"}, {"howare": "you"}], + last_updated=dt.datetime(2023, 2, 1), + ), + id="paged with date", + ), + pytest.param( + "http://foo.bar", + {"baz": "bat"}, + [ + [ + {"page": "1", "pages": "2", "lastupdated": "2023-02-01"}, + [{"hello": "there"}], + ], + [ + {"page": "2", "pages": "2"}, + [{"howare": "you"}], + ], + ], + fetcher.Result([{"hello": "there"}, {"howare": "you"}], last_updated=None), + id="paged without date", + ), + ), +) +def test_fetch(url, params, responses, expected, mock_fetcher): + mock_fetcher.session.get = mock.Mock( + side_effect=[MockHTTPResponse(value=response) for response in responses] + ) + + got = mock_fetcher.fetch(url=url, params=params) + expected_params = [ + { + "per_page": fetcher.PER_PAGE, + 
"format": "json", + **({"page": i + 1} if i else {}), + **params, + } + for i in range(len(responses)) + ] + got_params = [i.kwargs["params"] for i in mock_fetcher.session.get.mock_calls] + + assert got == expected + assert expected_params == got_params + for response, rparams in zip(responses, expected_params): + assert mock_fetcher.cache[url, tuple(sorted(rparams.items()))] == json.dumps( + response ) + + +@pytest.mark.parametrize( + ["response", "expected"], + [ + pytest.param( + [{"message": [{"id": "baderror", "key": "nogood", "value": "dontlikeit"}]}], + r"Got error baderror \(nogood\): dontlikeit", + id="no rows", + ), + pytest.param( + [ + { + "pages": 2, + "message": [ + {"id": "baderror", "key": "nogood", "value": "dontlikeit"} + ], + }, + [], + ], + r"Got error baderror \(nogood\): dontlikeit", + id="no page", + ), + pytest.param( + [ + { + "page": 1, + "message": [ + {"id": "baderror", "key": "nogood", "value": "dontlikeit"} + ], + }, + [], + ], + r"Got error baderror \(nogood\): dontlikeit", + id="no pages", + ), + pytest.param( + [ + { + "page": 1, + "message": [{"key": "nogood", "value": "dontlikeit"}], + }, + [], + ], + r"Got unexpected response", + id="improper error", + ), + ], +) +def test_parse_response_errors(response, expected): + with pytest.raises(RuntimeError, match=expected): + fetcher.ParsedResponse.from_response(response) diff --git a/wbdata/__init__.py b/wbdata/__init__.py index e93b56d..fa669d2 100644 --- a/wbdata/__init__.py +++ b/wbdata/__init__.py @@ -1,18 +1,25 @@ """ wbdata: A wrapper for the World Bank API """ -__version__ = "0.3.0.post" -from .api import ( # noqa: F401 - get_country, - get_data, - get_series, - get_dataframe, - get_indicator, - get_incomelevel, - get_lendingtype, - get_source, - get_topic, - search_countries, - search_indicators, -) +__version__ = "1.0.0" + +from .client import Client + + +def get_default_client() -> Client: + """ + Get the default client + """ + return Client() + + +get_data = get_default_client().get_data +get_series = get_default_client().get_series +get_dataframe = get_default_client().get_dataframe +get_countries = get_default_client().get_countries +get_indicators = get_default_client().get_indicators +get_incomelevels = get_default_client().get_incomelevels +get_lendingtypes = get_default_client().get_lendingtypes +get_sources = get_default_client().get_sources +get_topics = get_default_client().get_topics diff --git a/wbdata/api.py b/wbdata/api.py deleted file mode 100644 index ba59e16..0000000 --- a/wbdata/api.py +++ /dev/null @@ -1,505 +0,0 @@ -""" -wbdata.api: Where all the functions go -""" - -import collections -import datetime -import re -import warnings - -import tabulate - -try: - import pandas as pd -except ImportError: - pd = None - -from decorator import decorator -from . 
import fetcher - -BASE_URL = "https://api.worldbank.org/v2" -COUNTRIES_URL = f"{BASE_URL}/countries" -ILEVEL_URL = f"{BASE_URL}/incomeLevels" -INDICATOR_URL = f"{BASE_URL}/indicators" -LTYPE_URL = f"{BASE_URL}/lendingTypes" -SOURCES_URL = f"{BASE_URL}/sources" -TOPIC_URL = f"{BASE_URL}/topics" -INDIC_ERROR = "Cannot specify more than one of indicator, source, and topic" - - -class WBSearchResult(list): - """ - A list that prints out a user-friendly table when printed or returned on the - command line - - - Items are expected to be dict-like and have an "id" key and a "name" or - "value" key - """ - - def __repr__(self): - try: - return tabulate.tabulate( - [[o["id"], o["name"]] for o in self], - headers=["id", "name"], - tablefmt="simple", - ) - except KeyError: - return tabulate.tabulate( - [[o["id"], o["value"]] for o in self], - headers=["id", "value"], - tablefmt="simple", - ) - - -if pd: - - class WBSeries(pd.Series): - """ - A pandas Series with a last_updated attribute - """ - - _metadata = ["last_updated"] - - @property - def _constructor(self): - return WBSeries - - class WBDataFrame(pd.DataFrame): - """ - A pandas DataFrame with a last_updated attribute - """ - - _metadata = ["last_updated"] - - @property - def _constructor(self): - return WBDataFrame - - -@decorator -def uses_pandas(f, *args, **kwargs): - """Raise ValueError if pandas is not loaded""" - if not pd: - raise ValueError("Pandas must be installed to be used") - return f(*args, **kwargs) - - -def parse_value_or_iterable(arg): - """ - If arg is a single value, return it as a string; if an iterable, return a - ;-joined string of all values - """ - if str(arg) == arg: - return arg - if type(arg) == int: - return str(arg) - return ";".join(arg) - - -def convert_year_to_datetime(yearstr): - """return datetime.datetime object from %Y formatted string""" - return datetime.datetime.strptime(yearstr, "%Y") - - -def convert_month_to_datetime(monthstr): - """return datetime.datetime object from %YM%m formatted string""" - split = monthstr.split("M") - return datetime.datetime(int(split[0]), int(split[1]), 1) - - -def convert_quarter_to_datetime(quarterstr): - """ - return datetime.datetime object from %YQ%# formatted string, where # is - the desired quarter - """ - split = quarterstr.split("Q") - quarter = int(split[1]) - month = quarter * 3 - 2 - return datetime.datetime(int(split[0]), month, 1) - - -def convert_dates_to_datetime(data): - """ - Return a datetime.datetime object from a date string as provided by the - World Bank - """ - first = data[0]["date"] - if isinstance(first, datetime.datetime): - return data - if "M" in first: - converter = convert_month_to_datetime - elif "Q" in first: - converter = convert_quarter_to_datetime - else: - converter = convert_year_to_datetime - for datum in data: - datum_date = datum["date"] - if "MRV" in datum_date: - continue - if "-" in datum_date: - continue - datum["date"] = converter(datum_date) - return data - - -def cast_float(value): - """ - Return a floated value or none - """ - try: - return float(value) - except (ValueError, TypeError): - return None - - -def get_series( - indicator, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - column_name="value", - keep_levels=False, - cache=True, -): - """ - Retrieve indicators for given countries and years - - :indicator: the desired indicator code - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-tuple with start - 
and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :column_name: the desired name for the pandas column - :keep_levels: if True and pandas is True, don't reduce the number of index - levels returned if only getting one date or country - :cache: use the cache - :returns: WBSeries - """ - raw_data = get_data( - indicator=indicator, - country=country, - data_date=data_date, - freq=freq, - source=source, - convert_date=convert_date, - cache=cache, - ) - df = pd.DataFrame( - [[i["country"]["value"], i["date"], i["value"]] for i in raw_data], - columns=["country", "date", column_name], - ) - df[column_name] = df[column_name].map(cast_float) - if not keep_levels and len(df["country"].unique()) == 1: - df = df.set_index("date") - elif not keep_levels and len(df["date"].unique()) == 1: - df = df.set_index("country") - else: - df = df.set_index(["country", "date"]) - series = WBSeries(df[column_name]) - series.last_updated = raw_data.last_updated - return series - - -def data_date_to_str(data_date, freq): - """ - Convert data_date to the appropriate representation base on freq - - - :data_date: A datetime.datetime object to be formatted - :freq: One of 'Y' (year), 'M' (month) or 'Q' (quarter) - - """ - if freq == "Y": - return data_date.strftime("%Y") - if freq == "M": - return data_date.strftime("%YM%m") - if freq == "Q": - return f"{data_date.year}Q{(data_date.month - 1) // 3 + 1}" - - -def get_data( - indicator, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - pandas=False, - column_name="value", - keep_levels=False, - cache=True, -): - """ - Retrieve indicators for given countries and years - - :indicator: the desired indicator code - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-tuple with start - and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :cache: use the cache - :returns: list of dictionaries - """ - if pandas: - warnings.warn( - ( - "Argument 'pandas' is deprecated and will be removed in a " - "future version. Use get_series or get_dataframe instead." 
- ), - PendingDeprecationWarning, - ) - return get_series( - indicator=indicator, - country=country, - data_date=data_date, - source=source, - convert_date=convert_date, - column_name=column_name, - keep_levels=keep_levels, - cache=cache, - ) - query_url = COUNTRIES_URL - try: - c_part = parse_value_or_iterable(country) - except TypeError: - raise TypeError("'country' must be a string or iterable'") - query_url = "/".join((query_url, c_part, "indicators", indicator)) - args = {} - if data_date: - args["date"] = ( - ":".join(data_date_to_str(dd, freq) for dd in data_date) - if isinstance(data_date, collections.Sequence) - else data_date_to_str(data_date, freq) - ) - if source: - args["source"] = source - data = fetcher.fetch(query_url, args, cache=cache) - if convert_date: - data = convert_dates_to_datetime(data) - return data - - -def id_only_query(query_url, query_id, cache): - """ - Retrieve information when ids are the only arguments - - :query_url: the base url to use for the query - :query_id: an id or sequence of ids - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing results - """ - if query_id: - query_url = "/".join((query_url, parse_value_or_iterable(query_id))) - return WBSearchResult(fetcher.fetch(query_url)) - - -def get_source(source_id=None, cache=True): - """ - Retrieve information on a source - - :source_id: a source id or sequence thereof. None returns all sources - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - sources - """ - return id_only_query(SOURCES_URL, source_id, cache=cache) - - -def get_incomelevel(level_id=None, cache=True): - """ - Retrieve information on an income level aggregate - - :level_id: a level id or sequence thereof. None returns all income level - aggregates - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - income level aggregates - """ - return id_only_query(ILEVEL_URL, level_id, cache=cache) - - -def get_topic(topic_id=None, cache=True): - """ - Retrieve information on a topic - - :topic_id: a topic id or sequence thereof. None returns all topics - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - topic aggregates - """ - return id_only_query(TOPIC_URL, topic_id, cache=cache) - - -def get_lendingtype(type_id=None, cache=True): - """ - Retrieve information on an income level aggregate - - :level_id: lending type id or sequence thereof. None returns all lending - type aggregates - :cache: use the cache - :returns: WBSearchResult containing dictionary objects describing selected - topic aggregates - """ - return id_only_query(LTYPE_URL, type_id, cache=cache) - - -def get_country(country_id=None, incomelevel=None, lendingtype=None, cache=True): - """ - Retrieve information on a country or regional aggregate. Can specify - either country_id, or the aggregates, but not both - - :country_id: a country id or sequence thereof. None returns all countries - and aggregates. - :incomelevel: desired incomelevel id or ids. - :lendingtype: desired lendingtype id or ids. 
- :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing each - country - """ - if country_id: - if incomelevel or lendingtype: - raise ValueError("Can't specify country_id and aggregates") - return id_only_query(COUNTRIES_URL, country_id, cache=cache) - args = {} - if incomelevel: - args["incomeLevel"] = parse_value_or_iterable(incomelevel) - if lendingtype: - args["lendingType"] = parse_value_or_iterable(lendingtype) - return WBSearchResult(fetcher.fetch(COUNTRIES_URL, args, cache=cache)) - - -def get_indicator(indicator=None, source=None, topic=None, cache=True): - """ - Retrieve information about an indicator or indicators. Only one of - indicator, source, and topic can be specified. Specifying none of the - three will return all indicators. - - :indicator: an indicator code or sequence thereof - :source: a source id or sequence thereof - :topic: a topic id or sequence thereof - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing - indicators - """ - if indicator: - if source or topic: - raise ValueError(INDIC_ERROR) - query_url = "/".join((INDICATOR_URL, parse_value_or_iterable(indicator))) - elif source: - if topic: - raise ValueError(INDIC_ERROR) - query_url = "/".join( - (SOURCES_URL, parse_value_or_iterable(source), "indicators") - ) - elif topic: - query_url = "/".join((TOPIC_URL, parse_value_or_iterable(topic), "indicators")) - else: - query_url = INDICATOR_URL - return WBSearchResult(fetcher.fetch(query_url, cache=cache)) - - -def search_indicators(query, source=None, topic=None, cache=True): - """ - Search indicators for a certain regular expression. Only one of source or - topic can be specified. In interactive mode, will return None and print ids - and names unless suppress_printing is True. - - :query: the term to match against indicator names - :source: if present, id of desired source - :topic: if present, id of desired topic - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing search - indicators - """ - indicators = get_indicator(source=source, topic=topic, cache=cache) - pattern = re.compile(query, re.IGNORECASE) - return WBSearchResult(i for i in indicators if pattern.search(i["name"])) - - -def search_countries(query, incomelevel=None, lendingtype=None, cache=True): - """ - Search countries by name. Very simple search. - - :query: the string to match against country names - :incomelevel: if present, search only the matching incomelevel - :lendingtype: if present, search only the matching lendingtype - :cache: use the cache - :returns: WBSearchResult containing dictionary objects representing - countries - """ - countries = get_country( - incomelevel=incomelevel, lendingtype=lendingtype, cache=cache - ) - pattern = re.compile(query, re.IGNORECASE) - return WBSearchResult(i for i in countries if pattern.search(i["name"])) - - -@uses_pandas -def get_dataframe( - indicators, - country="all", - data_date=None, - freq="Y", - source=None, - convert_date=False, - keep_levels=False, - cache=True, -): - """ - Convenience function to download a set of indicators and merge them into a - pandas DataFrame. The index will be the same as if calls were made to - get_data separately. 
- - :indicators: An dictionary where the keys are desired indicators and the - values are the desired column names - :country: a country code, sequence of country codes, or "all" (default) - :data_date: the desired date as a datetime object or a 2-sequence with - start and end dates - :freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' - (monthly), or 'Q' (quarterly). The indicator may or may not support the - specified frequency. - :source: the specific source to retrieve data from (defaults on API to 2, - World Development Indicators) - :convert_date: if True, convert date field to a datetime.datetime object. - :keep_levels: if True don't reduce the number of index levels returned if - only getting one date or country - :cache: use the cache - :returns: a WBDataFrame - """ - serieses = [ - ( - get_series( - indicator=indicator, - country=country, - data_date=data_date, - freq=freq, - source=source, - convert_date=convert_date, - keep_levels=keep_levels, - cache=cache, - ).rename(name) - ) - for indicator, name in indicators.items() - ] - result = None - for series in serieses: - if result is None: - result = series.to_frame() - else: - result = result.join(series.to_frame(), how="outer") - result = WBDataFrame(result) - result.last_updated = {i.name: i.last_updated for i in serieses} - return result diff --git a/wbdata/cache.py b/wbdata/cache.py new file mode 100644 index 0000000..63e6202 --- /dev/null +++ b/wbdata/cache.py @@ -0,0 +1,80 @@ +""" +Caching functionality + +""" +import datetime as dt +import logging +import os +from pathlib import Path +from typing import Union + +import appdirs +import cachetools +import shelved_cache # type: ignore[import-untyped] + +from wbdata import __version__ + +log = logging.getLogger(__name__) + +CACHE_PATH = os.getenv( + "WBDATA_CACHE_PATH", + os.path.join( + appdirs.user_cache_dir(appname="wbdata", version=__version__), "cache" + ), +) + +try: + TTL_DAYS = int(os.getenv("WBDATA_CACHE_TTL_DAYS", "7")) +except ValueError: + logging.warning("Couldn't parse WBDATA_CACHE_TTL_DAYS value, defaulting to 7") + TTL_DAYS = 7 + +try: + MAX_SIZE = int(os.getenv("WBDATA_CACHE_MAX_SIZE", "100")) +except ValueError: + logging.warning("Couldn't parse WBDATA_CACHE_MAX_SIZE value, defaulting to 100") + MAX_SIZE = 100 + + +def get_cache( + path: Union[str, Path, None] = None, + ttl_days: Union[int, None] = None, + max_size: Union[int, None] = None, +) -> cachetools.Cache: + """ + Create a persistent cache. + + + Default caching functionality can be controlled with environment variables: + + * `WBDATA_CACHE_PATH`: path for the cache (default: system default + application cache) + * `WBDATA_CACHE_TTL_DAYS`: number of days to cache results (default: 7) + * `WBDATA_CACHE_MAX_SIZE`: maximum number of items to cache (default: 100) + + + The cache returned is a `shelved_cache.PersistentCache` that wraps a + `cachetools.TTLCache` object with the desired parameters. The cache + is cleaned up on load. + + Parameters: + path: path to the cache. If `None`, value of `WBDATA_CACHE_PATH` + ttl_days: number of days to cache results. If `None`, value of + `WBDATA_CACHE_TTL_DAYS` + max_size: maximum number of items to cache. If `None`, value of + `WBDATA_CACHE_MAX_SIZE`. 
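For illustration, a hedged sketch of how this configuration could be exercised; the values and the `/tmp` path are arbitrary, and the environment variables must be set before `wbdata.cache` is first imported, since the module reads them at import time:

```python
import os

# Environment-driven defaults (read once, at import time).
os.environ["WBDATA_CACHE_TTL_DAYS"] = "1"
os.environ["WBDATA_CACHE_MAX_SIZE"] = "500"

from wbdata import cache

# Explicit arguments override the environment defaults.
c = cache.get_cache(path="/tmp/wbdata-demo/cache", ttl_days=1, max_size=500)
c[("https://example.invalid", ())] = "{}"  # the cache behaves like a dict
assert ("https://example.invalid", ()) in c
```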
+ + """ + path = path or CACHE_PATH + Path(path).parent.mkdir(parents=True, exist_ok=True) + ttl_days = ttl_days or TTL_DAYS + max_size = max_size or MAX_SIZE + cache: cachetools.TTLCache = shelved_cache.PersistentCache( + cachetools.TTLCache, + filename=str(path), + maxsize=max_size, + ttl=dt.timedelta(days=ttl_days), + timer=dt.datetime.now, + ) + cache.expire() + return cache diff --git a/wbdata/client.py b/wbdata/client.py new file mode 100644 index 0000000..91d6e9b --- /dev/null +++ b/wbdata/client.py @@ -0,0 +1,561 @@ +""" +The client class defines the wbdata client class and associated support classes. +""" + +import contextlib +import dataclasses +import datetime as dt +import re +from pathlib import Path +from typing import Any, Dict, Generator, Iterable, List, Sequence, Tuple, Union + +import decorator +import requests +import tabulate + +try: + import pandas as pd # type: ignore[import-untyped] +except ImportError: + pd = None + + +from . import cache, dates, fetcher + +BASE_URL = "https://api.worldbank.org/v2" +COUNTRIES_URL = f"{BASE_URL}/countries" +ILEVEL_URL = f"{BASE_URL}/incomeLevels" +INDICATOR_URL = f"{BASE_URL}/indicators" +LTYPE_URL = f"{BASE_URL}/lendingTypes" +SOURCE_URL = f"{BASE_URL}/sources" +TOPIC_URL = f"{BASE_URL}/topics" + + +class SearchResult(List): + """ + A list that prints out a user-friendly table when printed or returned on the + command line + + + Items are expected to be dict-like and have an "id" key and a "name" or + "value" key + """ + + def __repr__(self) -> str: + try: + return tabulate.tabulate( + [[o["id"], o["name"]] for o in self], + headers=["id", "name"], + tablefmt="simple", + ) + except KeyError: + return tabulate.tabulate( + [[o["id"], o["value"]] for o in self], + headers=["id", "value"], + tablefmt="simple", + ) + + +if pd: + + class Series(pd.Series): + """ + A `pandas.Series` with a `last_updated` attribute. + + + The `last_updated` attribute is set when the `Series` is created but not + automatically updated. Its value is either `None` or a `datetime.datetime` + object. + """ + + def __init__( + self, + *args, + last_updated: Union[None, dt.datetime] = None, + **kwargs, + ): + super().__init__(*args, **kwargs) + self.last_updated = last_updated + + _metadata = ["last_updated"] + + @property + def _constructor(self): + return Series + + class DataFrame(pd.DataFrame): + def __init__( + self, *args, serieses: Union[Dict[str, Series], None] = None, **kwargs + ): + """ + A `pandas.DataFrame` with a `last_updated` attribute + + + The `last_updated` attribute is set when the Series is created but not + automatically updated. Its value is a dictionary where the keys are the + column names and the values are `None` or a `datetime.datetime` object. 
+ """ + if serieses: + super().__init__(serieses) + self.last_updated: Union[Dict[str, Union[dt.datetime, None]], None] = { + name: s.last_updated for name, s in serieses.items() + } + else: + super().__init__(*args, **kwargs) + self.last_updated = None + + _metadata = ["last_updated"] + + @property + def _constructor(self): + return DataFrame +else: + Series = Any # type: ignore[misc, assignment] + DataFrame = Any # type: ignore[misc, assignment] + + +@decorator.decorator +def needs_pandas(f, *args, **kwargs): + if pd is None: + raise RuntimeError(f"{f.__name__} requires pandas") + return f(*args, **kwargs) + + +def _parse_value_or_iterable(arg: Any) -> str: + """ + If arg is a single value, return it as a string; if an iterable, return a + ;-joined string of all values + """ + if isinstance(arg, str): + return arg + if isinstance(arg, Iterable): + return ";".join(str(i) for i in arg) + return str(arg) + + +def _cast_float(value: str) -> Union[float, None]: + """Return a value coerced to float or None""" + with contextlib.suppress(ValueError, TypeError): + return float(value) + return None + + +def _filter_by_pattern( + rows: Iterable[Dict[str, Any]], pattern=Union[str, re.Pattern] +) -> Generator[Dict[str, Any], None, None]: + """Return a generator of rows matching the pattern""" + if isinstance(pattern, str): + pattern = re.compile(pattern, re.IGNORECASE) + return (row for row in rows if pattern.search(row["name"])) + + +@dataclasses.dataclass +class Client: + """ + The client object for the World Bank API. + + Most users will only need to create this if they need more than one cache, + want to specify a cache programmatically rather than through environment + variables, or want to specify a requests Session. + + Parameters: + cache_path: path to the cache file + cache_ttl_days: number of days to retain cached results + cache_max_size: number of items to retain in the cache + session: requests Session object to use to make requests + """ + + cache_path: Union[str, Path, None] = None + cache_ttl_days: Union[int, None] = None + cache_max_size: Union[int, None] = None + session: Union[requests.Session, None] = None + + def __post_init__(self): + self.fetcher = fetcher.Fetcher( + cache=cache.get_cache( + path=self.cache_path, + ttl_days=self.cache_ttl_days, + max_size=self.cache_max_size, + ) + ) + self.has_pandas = pd is None + + def get_data( + self, + indicator: str, + country: Union[str, Sequence[str]] = "all", + date: Union[ + str, + dt.datetime, + Tuple[Union[str, dt.datetime], Union[str, dt.datetime]], + None, + ] = None, + freq: str = "Y", + source: Union[int, str, Sequence[Union[int, str]], None] = None, + parse_dates: bool = False, + skip_cache: bool = False, + ) -> fetcher.Result: + """ + Retrieve indicators for given countries and years + + Parameters: + indicator: the desired indicator code + country: a country code, sequence of country codes, or "all" (default) + date: the desired date as a string, datetime object or a 2-tuple + with start and end dates + freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not + support the specified frequency. + source: the specific source to retrieve data from (defaults on API + to 2, World Development Indicators) + parse_dates: if True, convert date field to a datetime.datetime + object. 
+ skip_cache: bypass the cache when downloading + + Returns: + A list of dictionaries of observations + """ + url = COUNTRIES_URL + try: + c_part = _parse_value_or_iterable(country) + except TypeError as e: + raise TypeError("'country' must be a string or iterable") from e + url = "/".join((url, c_part, "indicators", indicator)) + params: Dict[str, Any] = {} + if date: + params["date"] = dates.format_dates(date, freq) + if source: + params["source"] = source + data = self.fetcher.fetch(url=url, params=params, skip_cache=skip_cache) + if parse_dates: + dates.parse_row_dates(data) + return data + + def _id_only_query(self, url: str, id_: Any, skip_cache: bool) -> SearchResult: + """ + Utility to retrieve information when ids are the only arguments + + Parameters: + url: the base url to use for the query + id_: an id or sequence of ids + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects describing results + """ + if id_: + url = "/".join((url, _parse_value_or_iterable(id_))) + return SearchResult(self.fetcher.fetch(url=url, skip_cache=skip_cache)) + + def get_sources( + self, + source_id: Union[int, str, Sequence[Union[int, str]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information on one or more sources + + Parameters: + source_id: a source id or sequence thereof. None returns all sources + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects describing selected sources + """ + return self._id_only_query(url=SOURCE_URL, id_=source_id, skip_cache=skip_cache) + + def get_incomelevels( + self, + level_id: Union[int, str, Sequence[Union[int, str]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information on one or more income level aggregates + + Parameters: + level_id: a level id or sequence thereof. None returns all income level + aggregates + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects describing selected + income level aggregates + """ + return self._id_only_query(ILEVEL_URL, level_id, skip_cache=skip_cache) + + def get_topics( + self, + topic_id: Union[int, str, Sequence[Union[int, str]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information on one or more topics + + Parameters: + topic_id: a topic id or sequence thereof. None returns all topics + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects describing selected topic + aggregates + """ + return self._id_only_query(TOPIC_URL, topic_id, skip_cache=skip_cache) + + def get_lendingtypes( + self, + type_id: Union[int, str, Sequence[Union[int, str]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information on one or more lending type aggregates + + Parameters: + type_id: lending type id or sequence thereof. None returns all lending + type aggregates + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects describing selected lending type aggregates + """ + return self._id_only_query(LTYPE_URL, type_id, skip_cache=skip_cache) + + def get_countries( + self, + country_id: Union[str, Sequence[str], None] = None, + query: Union[str, re.Pattern, None] = None, + incomelevel: Union[int, str, Sequence[Union[int, str]], None] = None, + lendingtype: Union[int, str, Sequence[Union[int, str]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information on one or more country or regional aggregates. 
+ + You can filter your results by specifying `query`, `incomelevel`, or + `lendingtype`. Specifying `query` will only return countries with + names that match the query as a regular expression. If a string is + supplied, the match is case insensitive. + + Specifying `query`, `incomelevel`, or `lendingtype` along with + `country_id` will raise a `ValueError`. + + Parameters: + country_id: a country id or sequence thereof. None returns all + countries and aggregates. + query: a regular expression on which to filter results + incomelevel: desired incomelevel id or ids on which to filter results + lendingtype: desired lendingtype id or ids on which to filter results + skip_cache: bypass cache when downloading + + Returns: + list of dictionaries describing countries + + """ + if country_id: + if incomelevel or lendingtype or query: + raise ValueError("Can't specify country_id and aggregates") + return self._id_only_query(COUNTRIES_URL, country_id, skip_cache=skip_cache) + params = {} + if incomelevel: + params["incomeLevel"] = _parse_value_or_iterable(incomelevel) + if lendingtype: + params["lendingType"] = _parse_value_or_iterable(lendingtype) + results = self.fetcher.fetch( + url=COUNTRIES_URL, params=params, skip_cache=skip_cache + ) + if query: + results = _filter_by_pattern(results, query) + return SearchResult(results) + + def get_indicators( + self, + indicator: Union[str, Sequence[str], None] = None, + query: Union[str, re.Pattern, None] = None, + source: Union[str, int, Sequence[Union[str, int]], None] = None, + topic: Union[str, int, Sequence[Union[str, int]], None] = None, + skip_cache: bool = False, + ) -> SearchResult: + """ + Retrieve information about an indicator or indicators. + + When called with no arguments, returns all indicators. You can specify + one or more indicators to retrieve, or you can specify a source or a + topic for which to list all indicators. Specifying more than one of + `indicator`, `source`, and `topic` will raise a ValueError. + + Specifying `query` will only return indicators with names that match + the query as a regular expression. If a string is supplied, the match + is case insensitive. Specifying both `query` and `indicator` will raise + a ValueError. 
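For illustration, a short sketch of the `query` filtering just described, using the module-level aliases bound in `wbdata/__init__.py`; the search terms and the topic id are arbitrary examples, not values from this patch:

```python
import re

import wbdata

# A plain string is compiled with re.IGNORECASE, so this matches
# "United States", "United Arab Emirates", etc.
print(wbdata.get_countries(query="united"))

# A precompiled pattern is used as-is, so this one stays case sensitive.
print(wbdata.get_indicators(topic=3, query=re.compile("GDP per capita")))
```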
+ + Parameters: + indicator: an indicator code or sequence thereof + query: a regular expression on which to filter results + source: a source id or sequence thereof + topic: a topic id or sequence thereof + skip_cache: bypass cache when downloading + + Returns: + list of dictionary objects representing indicators + """ + if query and indicator: + raise ValueError("Cannot specify indicator and query") + if sum(bool(i) for i in (indicator, source, topic)) > 1: + raise ValueError( + "Cannot specify more than one of indicator, source, and topic" + ) + if indicator: + url = "/".join((INDICATOR_URL, _parse_value_or_iterable(indicator))) + elif source: + url = "/".join((SOURCE_URL, _parse_value_or_iterable(source), "indicators")) + elif topic: + url = "/".join((TOPIC_URL, _parse_value_or_iterable(topic), "indicators")) + else: + url = INDICATOR_URL + results = self.fetcher.fetch(url=url, skip_cache=skip_cache) + if query: + results = _filter_by_pattern(results, query) + return SearchResult(results) + + @needs_pandas + def get_series( + self, + indicator: str, + country: Union[str, Sequence[str]] = "all", + date: Union[ + str, + dt.datetime, + Tuple[Union[str, dt.datetime], Union[str, dt.datetime]], + None, + ] = None, + freq: str = "Y", + source: Union[int, str, Sequence[Union[int, str]], None] = None, + parse_dates: bool = False, + name: str = "value", + keep_levels: bool = False, + skip_cache: bool = False, + ) -> Series: + """ + Retrieve data for a single indicator as a pandas Series. + + If pandas is not installed, a RuntimeError will be raised. + + Parameters: + indicator: the desired indicator code + country: a country code, sequence of country codes, or "all" (default) + date: the desired date as a string, datetime object or a 2-tuple + with start and end dates + freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not + support the specified frequency. + source: the specific source to retrieve data from (defaults on API + to 2, World Development Indicators) + parse_dates: if True, convert date field to a datetime.datetime + object. + name: the desired name for the pandas Series + keep_levels: if True don't reduce the number of index + levels returned if only getting one date or country + skip_cache: bypass the cache when downloading + + Returns: + Series with the requested data. The index of the series depends on + the data returned and the specified options. If the data spans + multiple dates and countries or if `keep_levels` is `True`, the + index will be a 2-level MultiIndex with levels "country" and + "date". If `keep_levels` is `False` (the default) and the data + only has one country or date, the level with only one value + will be dropped. If `keep_levels` is `False` and both levels + only have one value, the country level is dropped. 
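A usage sketch of the index behaviour described above; it assumes pandas is installed and a network connection is available, and `SP.POP.TOTL` is the World Bank's total-population indicator code:

```python
import wbdata

# One country, many dates: the single-valued "country" level is dropped...
pop = wbdata.get_series("SP.POP.TOTL", country="USA", date=("2000", "2010"))
print(pop.index.name)    # date

# ...unless keep_levels=True preserves the (country, date) MultiIndex.
pop = wbdata.get_series(
    "SP.POP.TOTL", country="USA", date=("2000", "2010"), keep_levels=True
)
print(pop.index.names)   # ['country', 'date']
print(pop.last_updated)  # datetime of the source's last update, or None
```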
+ """ + raw_data = self.get_data( + indicator=indicator, + country=country, + date=date, + freq=freq, + source=source, + parse_dates=parse_dates, + skip_cache=skip_cache, + ) + df = pd.DataFrame( + [[i["country"]["value"], i["date"], i["value"]] for i in raw_data], + columns=["country", "date", name], + ) + df[name] = df[name].map(_cast_float) + if not keep_levels and len(df["country"].unique()) == 1: + df = df.set_index("date") + elif not keep_levels and len(df["date"].unique()) == 1: + df = df.set_index("country") + else: + df = df.set_index(["country", "date"]) + return Series(df[name], last_updated=raw_data.last_updated) + + @needs_pandas + def get_dataframe( + self, + indicators: Dict[str, str], + country: Union[str, Sequence[str]] = "all", + date: Union[ + str, + dt.datetime, + Tuple[Union[str, dt.datetime], Union[str, dt.datetime]], + None, + ] = None, + freq: str = "Y", + source: Union[int, str, Sequence[Union[int, str]], None] = None, + parse_dates: bool = False, + keep_levels: bool = False, + skip_cache: bool = False, + ) -> DataFrame: + """ + Download a set of indicators and merge them into a pandas DataFrame. + + If pandas is not installed, a RuntimeError will be raised. + + Parameters: + indicators: An dictionary where the keys are desired indicators and the + values are the desired column names country: a country code, + sequence of country codes, or "all" (default) + date: the desired date as a string, datetime object or a 2-tuple + with start and end dates + freq: the desired periodicity of the data, one of 'Y' (yearly), 'M' + (monthly), or 'Q' (quarterly). The indicator may or may not + support the specified frequency. + source: the specific source to retrieve data from (defaults on API + to 2, World Development Indicators) + parse_dates: if True, convert date field to a datetime.datetime + object. + skip_cache: bypass the cache when downloading + keep_levels: if True don't reduce the number of index + levels returned if only getting one date or country + skip_cache: bypass the cache when downloading + + Returns: + DataFrame with one column per indicator. The index of the DataFrame + depends on the data returned and the specified options. If the + data spans multiple dates and countries or if `keep_levels` is + `True`, the index will be a 2-level MultiIndex with levels + "country" and "name". If `keep_levels` is `False` (the default) + and the data only has one country or date, the level with only + one value will be dropped. If `keep_levels` is `False` and both + levels only have one value, the country level is dropped. 
+ + """ + df = DataFrame( + serieses={ + name: self.get_series( + indicator=indicator, + country=country, + date=date, + freq=freq, + source=source, + parse_dates=parse_dates, + keep_levels=True, + skip_cache=skip_cache, + ) + for indicator, name in indicators.items() + } + ) + if not keep_levels and len(set(df.index.get_level_values(0))) == 1: + df.index = df.index.droplevel(0) + elif not keep_levels and len(set(df.index.get_level_values(1))) == 1: + df.index = df.index.droplevel(1) + return df diff --git a/wbdata/dates.py b/wbdata/dates.py new file mode 100644 index 0000000..660a27e --- /dev/null +++ b/wbdata/dates.py @@ -0,0 +1,122 @@ +""" +Miscellaneous data utilities +""" +import datetime as dt +import re +from typing import Any, Dict, Sequence, Tuple, Union + +import dateparser + +PATTERN_YEAR = re.compile(r"\d{4}") +PATTERN_MONTH = re.compile(r"\d{4}M\d{1,2}") +PATTERN_QUARTER = re.compile(r"\d{4}Q\d{1,2}") + +Date = Union[str, dt.datetime] +Dates = Union[Date, Tuple[Date, Date]] + + +def _parse_year(datestr: str) -> dt.datetime: + """return datetime.datetime object from %Y formatted string""" + return dt.datetime.strptime(datestr, "%Y") + + +def _parse_month(datestr: str) -> dt.datetime: + """return datetime.datetime object from %YM%m formatted string""" + split = datestr.split("M") + return dt.datetime(int(split[0]), int(split[1]), 1) + + +def _parse_quarter(datestr: str) -> dt.datetime: + """ + return datetime.datetime object from %YQ%# formatted string, where # is + the desired quarter + """ + split = datestr.split("Q") + quarter = int(split[1]) + month = quarter * 3 - 2 + return dt.datetime(int(split[0]), month, 1) + + +def parse_row_dates(data: Sequence[Dict[str, Any]]) -> None: + """ + Replace date strings in raw response with datetime objects, in-place. + + Does not replace "MRV" or "-". If we don't recognize the format, do nothing. 
+ + Parameters: + data: sequence of dictionaries with `date` keys to parse + """ + first = data[0]["date"] + if not isinstance(first, str): # Ignore unexpected cases + return + if PATTERN_MONTH.match(first): + converter = _parse_month + elif PATTERN_QUARTER.match(first): + converter = _parse_quarter + else: + converter = _parse_year + for datum in data: + datum_date = datum["date"] + if not isinstance(datum_date, str) or "MRV" in datum_date or "-" in datum_date: + continue + datum["date"] = converter(datum_date) + + +def _format_date(date: dt.datetime, freq: str) -> str: + """ + Convert date to the appropriate representation based on freq + + + :date: A datetime.datetime object to be formatted + :freq: One of 'Y' (year), 'M' (month) or 'Q' (quarter) + + """ + try: + return { + "Y": lambda x: x.strftime("%Y"), + "M": lambda x: x.strftime("%YM%m"), + "Q": lambda x: f"{x.year}Q{(x.month - 1) // 3 + 1}", + }[freq](date) + except KeyError as e: + raise ValueError(f"Unknown Frequency type: {freq}") from e + + +def _parse_date(date: Date) -> dt.datetime: + if isinstance(date, dt.datetime): + return date + if PATTERN_YEAR.fullmatch(date): + return _parse_year(date) + if PATTERN_MONTH.fullmatch(date): + return _parse_month(date) + if PATTERN_QUARTER.fullmatch(date): + return _parse_quarter(date) + last_chance = dateparser.parse(date) + if last_chance: + return last_chance + raise ValueError(f"Unable to parse date string {date}") + + +def _parse_and_format_date(date: Date, freq: str) -> str: + return _format_date(_parse_date(date), freq) + + +def format_dates(dates: Dates, freq: str) -> str: + """ + Given one or two date arguments, turn them into WB-accepted date parameters + + Parameters: + dates: a date or a tuple of two dates, where a date is either a string + or a datetime.datetime object. The date can be either a World-Bank + format string or anything that dateparser can handle. + freq: One of "Y", "M", or "Q" for year, month, or quarter respectively. + + Returns: + A string representing a date or date range according to the specified + frequency in the form the World Bank API expects. 
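Concretely, mirroring `tests/test_dates.py`:

```python
import datetime as dt

from wbdata import dates

assert dates.format_dates("May 2018", "M") == "2018M05"  # via dateparser
assert dates.format_dates(("2020M01", "2021M12"), "Q") == "2020Q1:2021Q4"
assert (
    dates.format_dates((dt.datetime(2004, 5, 6), dt.datetime(2005, 6, 7)), "M")
    == "2004M05:2005M06"
)
```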
diff --git a/wbdata/fetcher.py b/wbdata/fetcher.py
index 932ced7..f2840a2 100644
--- a/wbdata/fetcher.py
+++ b/wbdata/fetcher.py
@@ -2,149 +2,170 @@
 wbdata.fetcher: retrieve and cache queries
 """
-import datetime
+import contextlib
+import dataclasses
+import datetime as dt
 import json
 import logging
-import pickle
 import pprint
+from typing import Any, Dict, List, MutableMapping, NamedTuple, Tuple, Union
 
-import appdirs
+import backoff
 import requests
 
-import wbdata
+PER_PAGE = 1000
+TRIES = 3
 
-from pathlib import Path
-
-EXP = 7
-PER_PAGE = 1000
-TODAY = datetime.date.today()
-TRIES = 5
+def _strip_id(row: Dict[str, Any]) -> None:
+    with contextlib.suppress(KeyError):
+        row["id"] = row["id"].strip()  # type: ignore[union-attr]
 
 
-class WBResults(list):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.last_updated = None
+Response = Tuple[Dict[str, Any], List[Dict[str, Any]]]
 
 
-class Cache(object):
-    """Docstring for Cache """
+class ParsedResponse(NamedTuple):
+    rows: List[Dict[str, Any]]
+    page: int
+    pages: int
+    last_updated: Union[str, None]
 
-    def __init__(self):
-        self.path = Path(
-            appdirs.user_cache_dir(appname="wbdata", version=wbdata.__version__)
-        )
-        self.path.parent.mkdir(parents=True, exist_ok=True)
+    @classmethod
+    def from_response(cls, response: Response) -> "ParsedResponse":
         try:
-            with self.path.open("rb") as cachefile:
-                self.cache = {
-                    i: (date, json)
-                    for i, (date, json) in pickle.load(cachefile).items()
-                    if (TODAY - datetime.date.fromordinal(date)).days < EXP
-                }
-        except (IOError, EOFError):
-            self.cache = {}
-
-    def __getitem__(self, key):
-        return self.cache[key][1]
+            return ParsedResponse(
+                rows=response[1],
+                page=int(response[0]["page"]),
+                pages=int(response[0]["pages"]),
+                last_updated=response[0].get("lastupdated"),
+            )
+        except (IndexError, KeyError) as e:
+            try:
+                message = response[0]["message"][0]
+                raise RuntimeError(
+                    f"Got error {message['id']} ({message['key']}): "
+                    f"{message['value']}"
+                ) from e
+            except (IndexError, KeyError) as e:
+                raise RuntimeError(
+                    f"Got unexpected response:\n{pprint.pformat(response)}"
+                ) from e
 
-    def __setitem__(self, key, value):
-        self.cache[key] = TODAY.toordinal(), value
-        self.sync()
-
-    def __contains__(self, item):
-        return item in self.cache
+CacheKey = Tuple[str, Tuple[Tuple[str, Any], ...]]
 
-    def sync(self):
-        """Sync cache to disk"""
-        with self.path.open("wb") as cachefile:
-            pickle.dump(self.cache, cachefile)
 
+class Result(List[Dict[str, Any]]):
+    """
+    List with a `last_updated` attribute. The `last_updated` attribute is either
+    a datetime.datetime object or None.
+    """
 
-CACHE = Cache()
+    def __init__(self, *args, last_updated: Union[dt.datetime, None] = None, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.last_updated = last_updated
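A sketch of both parse outcomes with made-up payloads: a well-formed page yields rows plus paging metadata, while an API error payload surfaces as a `RuntimeError`:

```python
from wbdata.fetcher import ParsedResponse

ok = ({"page": "1", "pages": "2", "lastupdated": "2024-01-15"}, [{"id": " NY.GDP.MKTP.CD "}])
parsed = ParsedResponse.from_response(ok)
assert parsed.pages == 2 and parsed.last_updated == "2024-01-15"

bad = ({"message": [{"id": "120", "key": "Parameter", "value": "Invalid value"}]}, [])
ParsedResponse.from_response(bad)  # raises RuntimeError: Got error 120 (Parameter): ...
```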
 
 
-def get_json_from_url(url, args):
+@dataclasses.dataclass
+class Fetcher:
+    """
+    An object for making cached HTTP requests.
-
-    :url: the url to retrieve
-    :args: a dictionary of GET arguments
-    :returns: a string with the url contents
+
+    Parameters:
+        cache: a dictlike container for caching responses
+        session: a requests session to use to make the requests; if `None`,
+            a new session is created
     """
-    for _ in range(TRIES):
-        try:
-            return requests.get(url, args).text
-        except requests.ConnectionError:
-            continue
-    logging.error(f"Error connecting to {url}")
-    raise RuntimeError("Couldn't connect to API")
-
-
-def get_response(url, args, cache=True):
-    """
-    Get single page response from World Bank API or from cache
-    : query_url: the base url to be queried
-    : args: a dictionary of GET arguments
-    : cache: use the cache
-    : returns: a dictionary with the response from the API
-    """
-    logging.debug(f"fetching {url}")
-    key = (url, tuple(sorted(args.items())))
-    if cache and key in CACHE:
-        response = CACHE[key]
-    else:
-        response = get_json_from_url(url, args)
-        if cache:
-            CACHE[key] = response
-    return json.loads(response)
-
-
-def fetch(url, args=None, cache=True):
-    """Fetch data from the World Bank API or from cache.
-
-    Given the base url, keep fetching results until there are no more pages.
-
-    : query_url: the base url to be queried
-    : args: a dictionary of GET arguments
-    : cache: use the cache
-    : returns: a list of dictionaries containing the response to the query
-    """
-    if args is None:
-        args = {}
-    else:
-        args = dict(args)
-    args["format"] = "json"
-    args["per_page"] = PER_PAGE
-    results = []
-    pages, this_page = 0, 1
-    while pages != this_page:
-        response = get_response(url, args, cache=cache)
-        try:
-            results.extend(response[1])
-            this_page = response[0]["page"]
-            pages = response[0]["pages"]
-        except (IndexError, KeyError):
-            try:
-                message = response[0]["message"][0]
-                raise RuntimeError(
-                    f"Got error {message['id']} ({message['key']}): "
-                    f"{message['value']}"
-                )
-            except (IndexError, KeyError):
-                raise RuntimeError(
-                    f"Got unexpected response:\n{pprint.pformat(response)}"
-                )
-        logging.debug(f"Processed page {this_page} of {pages}")
-        args["page"] = int(this_page) + 1
-    for i in results:
-        if "id" in i:
-            i["id"] = i["id"].strip()
-    results = WBResults(results)
-    try:
-        results.last_updated = datetime.datetime.strptime(
-            response[0]["lastupdated"], "%Y-%m-%d"
+    cache: MutableMapping[CacheKey, str]
+    session: requests.Session = dataclasses.field(default_factory=requests.Session)
+
+    @backoff.on_exception(
+        wait_gen=backoff.expo,
+        exception=requests.exceptions.ConnectTimeout,
+        max_tries=TRIES,
+    )
+    def _get_response_body(
+        self,
+        url: str,
+        params: Dict[str, Any],
+    ) -> str:
+        """
+        Fetch a url directly from the World Bank.
+
+        Parameters:
+            url: the url to retrieve
+            params: a dictionary of GET parameters
+
+        Returns: a string with the response content
+        """
+        # The copy is for mocking; a bit depressing, but not expensive.
+        body = self.session.get(url=url, params={**params}).text
+        return body
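The decorator above retries only on connect timeouts, with exponentially growing waits, up to `TRIES` attempts. A standalone sketch of the same policy (`flaky_get` is hypothetical, not part of the module):

```python
import backoff
import requests

@backoff.on_exception(backoff.expo, requests.exceptions.ConnectTimeout, max_tries=3)
def flaky_get(url: str) -> str:
    # Connect timeouts are retried with exponential waits; any other
    # exception propagates immediately.
    return requests.get(url, timeout=5).text
```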
+
+    def _get_response(
+        self,
+        url: str,
+        params: Dict[str, Any],
+        skip_cache: bool = False,
+    ) -> ParsedResponse:
+        """
+        Get a single page response from the World Bank API or from cache.
+
+        Parameters:
+            url: the base url to be queried
+            params: a dictionary of GET arguments
+            skip_cache: bypass the cache
+
+        Returns: parsed version of the API response
+        """
+        key = (url, tuple(sorted(params.items())))
+        if not skip_cache and key in self.cache:
+            body = self.cache[key]
+        else:
+            body = self._get_response_body(url, params)
+            self.cache[key] = body
+        return ParsedResponse.from_response(tuple(json.loads(body)))
+
+    def fetch(
+        self,
+        url: str,
+        params: Union[Dict[str, Any], None] = None,
+        skip_cache: bool = False,
+    ) -> Result:
+        """Fetch data from the World Bank API or from cache.
+
+        Given the base url, keep fetching results until there are no more pages.
+
+        Parameters:
+            url: the base url to be queried
+            params: a dictionary of GET arguments
+            skip_cache: bypass the cache
+
+        Returns:
+            a list of dictionaries containing the response to the query
+        """
+        params = {**(params or {})}
+        params["format"] = "json"
+        params["per_page"] = PER_PAGE
+        page, pages = -1, -2
+        rows: List[Dict[str, Any]] = []
+        while pages != page:
+            response = self._get_response(
+                url=url,
+                params=params,
+                skip_cache=skip_cache,
+            )
+            rows.extend(response.rows)
+            page, pages = response.page, response.pages
+            logging.debug(f"Processed page {page} of {pages}")
+            params["page"] = page + 1
+        for row in rows:
+            _strip_id(row)
+        last_updated = (
+            None
+            if not response.last_updated
+            else dt.datetime.strptime(response.last_updated, "%Y-%m-%d")
+        )
+        return Result(rows, last_updated=last_updated)
diff --git a/wbdata/py.typed b/wbdata/py.typed
new file mode 100644
index 0000000..e69de29
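Putting the fetcher pieces together, a hypothetical end-to-end use of `Fetcher` (any `MutableMapping` works as the cache; the URL is illustrative):

```python
from wbdata.fetcher import Fetcher

fetcher = Fetcher(cache={})  # a plain dict satisfies MutableMapping[CacheKey, str]
result = fetcher.fetch("https://api.worldbank.org/v2/country")
print(len(result), result.last_updated)  # all pages concatenated; last_updated may be None
```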