From 33c47b65bacc46ea167d89958762b9e9bb338579 Mon Sep 17 00:00:00 2001 From: Janos Gabler Date: Sat, 18 Jun 2022 09:52:20 +0200 Subject: [PATCH] Add jax support (#357) --- .github/workflows/main.yml | 24 ++-- CHANGES.rst | 19 +++ environment.yml | 3 +- src/estimagic/parameters/tree_conversion.py | 3 +- src/estimagic/parameters/tree_registry.py | 4 +- src/estimagic/process_user_function.py | 6 +- src/estimagic/utilities.py | 11 ++ tests/optimization/test_jax_derivatives.py | 122 ++++++++++++++++++ .../test_trust_region_sampling.py | 1 + .../test_with_nonlinear_constraints.py | 2 + tests/test_utilities.py | 34 +++++ tox.ini | 24 +--- 12 files changed, 221 insertions(+), 32 deletions(-) create mode 100644 tests/optimization/test_jax_derivatives.py diff --git a/.github/workflows/main.yml b/.github/workflows/main.yml index b4916da5e..476d9501e 100644 --- a/.github/workflows/main.yml +++ b/.github/workflows/main.yml @@ -15,7 +15,7 @@ on: jobs: - run-tests: + run-tests-linux: name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} runs-on: ${{ matrix.os }} @@ -23,7 +23,7 @@ jobs: strategy: fail-fast: false matrix: - os: ['ubuntu-latest', 'macos-latest', 'windows-latest'] + os: ['ubuntu-latest'] python-version: ['3.8', '3.9', '3.10'] steps: @@ -39,7 +39,7 @@ jobs: - name: Run pytest. shell: bash -l {0} - run: tox -e pytest -- -m "not slow and not jax" --cov-report=xml --cov=./ + run: tox -e pytest-linux -- --cov-report=xml --cov=./ - name: Upload coverage report. if: runner.os == 'Linux' && matrix.python-version == '3.8' @@ -47,27 +47,31 @@ jobs: with: token: ${{ secrets.CODECOV_TOKEN }} + run-tests: - run-jax-tests: + name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }} + runs-on: ${{ matrix.os }} - name: Run jax tests on Linux - runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + os: ['macos-latest', 'windows-latest'] + python-version: ['3.8', '3.9', '3.10'] steps: - uses: actions/checkout@v2 - uses: conda-incubator/setup-miniconda@v2 with: auto-update-conda: true - python-version: 3.8 + python-version: ${{ matrix.python-version }} - name: Install core dependencies. shell: bash -l {0} run: conda install -c conda-forge tox-conda - - name: Run pytest on jax tests. + - name: Run pytest. shell: bash -l {0} - run: tox -e jax -- -m "jax" - + run: tox -e pytest -- -m "not slow and not jax" docs: diff --git a/CHANGES.rst b/CHANGES.rst index e8c4a7681..fc6567926 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -7,6 +7,25 @@ releases are available on `Anaconda.org `_. +0.3.3 +----- + +- :gh:`357` Adds jax support (:ghuser:`janosg`) +- :gh:`359` Improves error handling with violated constaints (:ghuser:`timmens`) +- :gh:`358` Improves cartis roberts set of test functions and improves the + default latex rendering of MultiIndex tables (:ghuser:`mpetrosian`) + + +0.3.2 +----- + +- :gh:`355` Improves test coverage of contraints processing (:ghuser:`janosg`) +- :gh:`354` Improves test coverage for bounds processing (:ghuser:`timmens`) +- :gh:`353` Improves history plots (:ghuser:`timmens`) +- :gh:`352` Improves scaling and benchmarking (:ghuser:`janosg`) +- :gh:`351` Improves estimation summaries (:ghuser:`timmens`) +- :gh:`350` Allow empty queries or selectors in constraints (:ghuser:`janosg`) + 0.3.1 ----- diff --git a/environment.yml b/environment.yml index eb90abb41..b45f7af54 100644 --- a/environment.yml +++ b/environment.yml @@ -38,13 +38,12 @@ dependencies: - sqlalchemy>=1.3 - tox-conda - statsmodels - - dill - pytask>=0.0.11 - nlopt - sphinx-panels - pygmo - nb_black - - pybaum + - pybaum>=0.1.2 - pip: - black diff --git a/src/estimagic/parameters/tree_conversion.py b/src/estimagic/parameters/tree_conversion.py index c2e220e51..47e181bd8 100644 --- a/src/estimagic/parameters/tree_conversion.py +++ b/src/estimagic/parameters/tree_conversion.py @@ -5,6 +5,7 @@ from estimagic.parameters.block_trees import block_tree_to_matrix from estimagic.parameters.parameter_bounds import get_bounds from estimagic.parameters.tree_registry import get_registry +from estimagic.utilities import isscalar from pybaum import leaf_names from pybaum import tree_flatten from pybaum import tree_just_flatten @@ -130,7 +131,7 @@ def params_unflatten(x): def _get_func_flatten(registry, func_eval, primary_key): - if np.isscalar(func_eval): + if isscalar(func_eval): if primary_key == "value": func_flatten = lambda func_eval: float(func_eval) else: diff --git a/src/estimagic/parameters/tree_registry.py b/src/estimagic/parameters/tree_registry.py index 291754cfb..688200f44 100644 --- a/src/estimagic/parameters/tree_registry.py +++ b/src/estimagic/parameters/tree_registry.py @@ -30,7 +30,9 @@ def get_registry(extended=False, data_col="value"): dict: The pytree registry. """ - types = ["numpy.ndarray", "pandas.Series"] if extended else None + types = ( + ["numpy.ndarray", "pandas.Series", "jax.numpy.ndarray"] if extended else None + ) registry = get_pybaum_registry(types=types) if extended: registry[pd.DataFrame] = { diff --git a/src/estimagic/process_user_function.py b/src/estimagic/process_user_function.py index 87c375021..64cbea8ea 100644 --- a/src/estimagic/process_user_function.py +++ b/src/estimagic/process_user_function.py @@ -47,7 +47,11 @@ def process_func_of_params(func, kwargs, name="your function", skip_checks=False required_args = unpartialled_args.intersection(no_default_args) too_many_required_arguments = len(required_args) > 1 - if too_many_required_arguments: + # Try to discover if we have a jax calculated jacobian that has a weird + # signature that would not pass this test: + skip_because_of_jax = required_args == {"args", "kwargs"} + + if too_many_required_arguments and not skip_because_of_jax: raise InvalidKwargsError( f"Too few keyword arguments for {name}. After applying all keyword " "arguments at most one required argument (the params) should remain. " diff --git a/src/estimagic/utilities.py b/src/estimagic/utilities.py index 95699e56e..595d40131 100644 --- a/src/estimagic/utilities.py +++ b/src/estimagic/utilities.py @@ -293,3 +293,14 @@ def to_pickle(obj, path): def read_pickle(path): return pd.read_pickle(path) + + +def isscalar(element): + """Jax aware replacement for np.isscalar.""" + if np.isscalar(element): + return True + # call anything a scalar that says it has 0 dimensions + elif getattr(element, "ndim", -1) == 0: + return True + else: + return False diff --git a/tests/optimization/test_jax_derivatives.py b/tests/optimization/test_jax_derivatives.py new file mode 100644 index 000000000..4d98e694b --- /dev/null +++ b/tests/optimization/test_jax_derivatives.py @@ -0,0 +1,122 @@ +import numpy as np +import pytest +from estimagic.config import IS_JAX_INSTALLED +from estimagic.optimization.optimize import minimize +from numpy.testing import assert_array_almost_equal as aaae + +if IS_JAX_INSTALLED: + import jax.numpy as jnp + import jax + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def test_scipy_conference_example(): + def criterion(x): + first = (x["a"] - jnp.pi) ** 2 + second = jnp.linalg.norm(x["b"] - jnp.arange(3)) + third = jnp.linalg.norm(x["c"] - jnp.eye(2)) + return first + second + third + + start_params = { + "a": 1.0, + "b": jnp.ones(3).astype(float), + "c": jnp.ones((2, 2)).astype(float), + } + + gradient = jax.grad(criterion) + + res = minimize( + criterion=criterion, + derivative=gradient, + params=start_params, + algorithm="scipy_lbfgsb", + ) + + assert isinstance(res.params["b"], jnp.ndarray) + aaae(res.params["b"], jnp.arange(3)) + aaae(res.params["c"], jnp.eye(2)) + assert np.allclose(res.params["a"], np.pi, atol=1e-4) + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def test_params_is_jax_scalar(): + def criterion(x): + return x**2 + + res = minimize( + criterion=criterion, + params=jnp.array(1.0), + algorithm="scipy_lbfgsb", + derivative=jax.grad(criterion), + ) + + assert isinstance(res.params, jnp.ndarray) + assert np.allclose(res.params, 0.0) + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def params_is_1d_array(): + def criterion(x): + return x @ x + + res = minimize( + criterion=criterion, + params=jnp.arange(3), + algorithm="scipy_lbfgsb", + derivative=jax.grad(criterion), + ) + + assert isinstance(res.params, jnp.ndarray) + assert aaae(res.params, jnp.arange(3)) + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +@pytest.mark.parametrize("algorithm", ["scipy_lbfgsb", "scipy_ls_lm"]) +def test_dict_output_works(algorithm): + def criterion(x): + return {"root_contributions": x, "value": x @ x} + + def scalar_wrapper(x): + return criterion(x)["value"] + + def ls_wrapper(x): + return criterion(x)["root_contributions"] + + deriv_dict = { + "value": jax.grad(scalar_wrapper), + "root_contributions": jax.jacobian(ls_wrapper), + } + + res = minimize( + criterion=criterion, + params=jnp.array([1.0, 2.0, 3.0]), + algorithm=algorithm, + derivative=deriv_dict, + ) + + assert isinstance(res.params, jnp.ndarray) + aaae(res.params, np.zeros(3)) + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def test_least_squares_optimizer_pytree(): + def criterion(x): + return {"root_contributions": x} + + def ls_wrapper(x): + return criterion(x)["root_contributions"] + + params = {"a": 1.0, "b": 2.0, "c": jnp.array([1.0, 2.0])} + jac = jax.jacobian(ls_wrapper) + + res = minimize( + criterion=criterion, + params=params, + algorithm="scipy_ls_lm", + derivative=jac, + ) + + assert isinstance(res.params, dict) + assert np.allclose(res.params["a"], 0) + assert np.allclose(res.params["b"], 0) + aaae(res.params["c"], np.zeros(2)) diff --git a/tests/optimization/test_trust_region_sampling.py b/tests/optimization/test_trust_region_sampling.py index 17533cd69..d08dd3a6c 100644 --- a/tests/optimization/test_trust_region_sampling.py +++ b/tests/optimization/test_trust_region_sampling.py @@ -132,6 +132,7 @@ def test_get_next_trust_region_points_latin_hypercube_single_use( aaae(sample.mean(axis=0), center, decimal=decimal) +@pytest.mark.slow @pytest.mark.parametrize( "optimality_criterion", ["a-optimal", "e-optimal", "d-optimal", "g-optimal", "maximin"], diff --git a/tests/optimization/test_with_nonlinear_constraints.py b/tests/optimization/test_with_nonlinear_constraints.py index 605f17729..5601d876d 100644 --- a/tests/optimization/test_with_nonlinear_constraints.py +++ b/tests/optimization/test_with_nonlinear_constraints.py @@ -5,6 +5,7 @@ import pytest from estimagic import maximize from estimagic import minimize +from estimagic.config import IS_CYIPOPT_INSTALLED from estimagic.optimization import AVAILABLE_ALGORITHMS from numpy.testing import assert_array_almost_equal as aaae @@ -233,6 +234,7 @@ def constraint(selected): TEST_CASES = list(itertools.product(["ipopt"], [True, False])) +@pytest.mark.skipif(not IS_CYIPOPT_INSTALLED, reason="Needs ipopt") @pytest.mark.parametrize("algorithm, skip_checks", TEST_CASES) def test_general_example(general_example, algorithm, skip_checks): diff --git a/tests/test_utilities.py b/tests/test_utilities.py index d98d85f70..ba70b4917 100644 --- a/tests/test_utilities.py +++ b/tests/test_utilities.py @@ -1,6 +1,7 @@ import numpy as np import pandas as pd import pytest +from estimagic.config import IS_JAX_INSTALLED from estimagic.utilities import calculate_trustregion_initial_radius from estimagic.utilities import chol_params_to_lower_triangular_matrix from estimagic.utilities import cov_matrix_to_params @@ -9,6 +10,7 @@ from estimagic.utilities import cov_to_sds_and_corr from estimagic.utilities import dimension_to_number_of_triangular_elements from estimagic.utilities import hash_array +from estimagic.utilities import isscalar from estimagic.utilities import number_of_triangular_elements_to_dimension from estimagic.utilities import read_pickle from estimagic.utilities import robust_cholesky @@ -19,6 +21,9 @@ from estimagic.utilities import to_pickle from numpy.testing import assert_array_almost_equal as aaae +if IS_JAX_INSTALLED: + import jax.numpy as jnp + def test_chol_params_to_lower_triangular_matrix(): calculated = chol_params_to_lower_triangular_matrix(pd.Series([1, 2, 3])) @@ -176,3 +181,32 @@ def test_pickling(tmp_path): to_pickle(a, path) b = read_pickle(path) assert a == b + + +SCALARS = [1, 2.0, np.pi, np.array(1), np.array(2.0), np.array(np.pi), np.nan] + + +@pytest.mark.parametrize("element", SCALARS) +def test_isscalar_true(element): + assert isscalar(element) is True + + +NON_SCALARS = [np.arange(3), {"a": 1}, [1, 2, 3]] + + +@pytest.mark.parametrize("element", NON_SCALARS) +def test_isscalar_false(element): + assert isscalar(element) is False + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def tets_isscalar_jax_true(): + x = jnp.arange(3) + element = x @ x + assert isscalar(element) is True + + +@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.") +def test_isscalar_jax_false(): + element = jnp.arange(3) + assert isscalar(element) is False diff --git a/tox.ini b/tox.ini index 8fc0cc11e..7bed42a00 100644 --- a/tox.ini +++ b/tox.ini @@ -1,16 +1,15 @@ [tox] -envlist = pytest, jax, sphinx +envlist = pytest-linux, pytest, sphinx skipsdist = True skip_missing_interpreters = True [testenv] basepython = python -[testenv:pytest] +[testenv:pytest-linux] setenv = CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 conda_channels = - opensourceeconomics conda-forge defaults deps = @@ -35,25 +34,21 @@ conda_deps = statsmodels seaborn plotly - dill cyipopt nlopt pygmo - pybaum + pybaum >= 0.1.2 + jax + petsc4py commands = pytest {posargs} -[testenv:jax] +[testenv:pytest] setenv = CONDA_DLL_SEARCH_MODIFICATION_ENABLE = 1 conda_channels = - opensourceeconomics conda-forge defaults -deps = - Py-BOBYQA - DFO-LS - fides == 0.7.4 conda_deps = bokeh >= 1.3 click @@ -72,12 +67,7 @@ conda_deps = statsmodels seaborn plotly - dill - cyipopt - nlopt - pygmo - pybaum - jax + pybaum >= 0.1.2 commands = pytest {posargs}