Skip to content

Commit

Permalink
Add jax support (#357)
Browse files Browse the repository at this point in the history
  • Loading branch information
janosg authored Jun 18, 2022
1 parent 9bdd7ff commit 33c47b6
Show file tree
Hide file tree
Showing 12 changed files with 221 additions and 32 deletions.
24 changes: 14 additions & 10 deletions .github/workflows/main.yml
Original file line number Diff line number Diff line change
Expand Up @@ -15,15 +15,15 @@ on:

jobs:

run-tests:
run-tests-linux:

name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}

strategy:
fail-fast: false
matrix:
os: ['ubuntu-latest', 'macos-latest', 'windows-latest']
os: ['ubuntu-latest']
python-version: ['3.8', '3.9', '3.10']

steps:
Expand All @@ -39,35 +39,39 @@ jobs:

- name: Run pytest.
shell: bash -l {0}
run: tox -e pytest -- -m "not slow and not jax" --cov-report=xml --cov=./
run: tox -e pytest-linux -- --cov-report=xml --cov=./

- name: Upload coverage report.
if: runner.os == 'Linux' && matrix.python-version == '3.8'
uses: codecov/codecov-action@v1
with:
token: ${{ secrets.CODECOV_TOKEN }}

run-tests:

run-jax-tests:
name: Run tests for ${{ matrix.os }} on ${{ matrix.python-version }}
runs-on: ${{ matrix.os }}

name: Run jax tests on Linux
runs-on: ubuntu-latest
strategy:
fail-fast: false
matrix:
os: ['macos-latest', 'windows-latest']
python-version: ['3.8', '3.9', '3.10']

steps:
- uses: actions/checkout@v2
- uses: conda-incubator/setup-miniconda@v2
with:
auto-update-conda: true
python-version: 3.8
python-version: ${{ matrix.python-version }}

- name: Install core dependencies.
shell: bash -l {0}
run: conda install -c conda-forge tox-conda

- name: Run pytest on jax tests.
- name: Run pytest.
shell: bash -l {0}
run: tox -e jax -- -m "jax"

run: tox -e pytest -- -m "not slow and not jax"

docs:

Expand Down
19 changes: 19 additions & 0 deletions CHANGES.rst
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,25 @@ releases are available on `Anaconda.org
<https://anaconda.org/OpenSourceEconomics/estimagic>`_.


0.3.3
-----

- :gh:`357` Adds jax support (:ghuser:`janosg`)
- :gh:`359` Improves error handling with violated constaints (:ghuser:`timmens`)
- :gh:`358` Improves cartis roberts set of test functions and improves the
default latex rendering of MultiIndex tables (:ghuser:`mpetrosian`)


0.3.2
-----

- :gh:`355` Improves test coverage of contraints processing (:ghuser:`janosg`)
- :gh:`354` Improves test coverage for bounds processing (:ghuser:`timmens`)
- :gh:`353` Improves history plots (:ghuser:`timmens`)
- :gh:`352` Improves scaling and benchmarking (:ghuser:`janosg`)
- :gh:`351` Improves estimation summaries (:ghuser:`timmens`)
- :gh:`350` Allow empty queries or selectors in constraints (:ghuser:`janosg`)

0.3.1
-----

Expand Down
3 changes: 1 addition & 2 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,12 @@ dependencies:
- sqlalchemy>=1.3
- tox-conda
- statsmodels
- dill
- pytask>=0.0.11
- nlopt
- sphinx-panels
- pygmo
- nb_black
- pybaum
- pybaum>=0.1.2

- pip:
- black
Expand Down
3 changes: 2 additions & 1 deletion src/estimagic/parameters/tree_conversion.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from estimagic.parameters.block_trees import block_tree_to_matrix
from estimagic.parameters.parameter_bounds import get_bounds
from estimagic.parameters.tree_registry import get_registry
from estimagic.utilities import isscalar
from pybaum import leaf_names
from pybaum import tree_flatten
from pybaum import tree_just_flatten
Expand Down Expand Up @@ -130,7 +131,7 @@ def params_unflatten(x):

def _get_func_flatten(registry, func_eval, primary_key):

if np.isscalar(func_eval):
if isscalar(func_eval):
if primary_key == "value":
func_flatten = lambda func_eval: float(func_eval)
else:
Expand Down
4 changes: 3 additions & 1 deletion src/estimagic/parameters/tree_registry.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,9 @@ def get_registry(extended=False, data_col="value"):
dict: The pytree registry.
"""
types = ["numpy.ndarray", "pandas.Series"] if extended else None
types = (
["numpy.ndarray", "pandas.Series", "jax.numpy.ndarray"] if extended else None
)
registry = get_pybaum_registry(types=types)
if extended:
registry[pd.DataFrame] = {
Expand Down
6 changes: 5 additions & 1 deletion src/estimagic/process_user_function.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,11 @@ def process_func_of_params(func, kwargs, name="your function", skip_checks=False
required_args = unpartialled_args.intersection(no_default_args)
too_many_required_arguments = len(required_args) > 1

if too_many_required_arguments:
# Try to discover if we have a jax calculated jacobian that has a weird
# signature that would not pass this test:
skip_because_of_jax = required_args == {"args", "kwargs"}

if too_many_required_arguments and not skip_because_of_jax:
raise InvalidKwargsError(
f"Too few keyword arguments for {name}. After applying all keyword "
"arguments at most one required argument (the params) should remain. "
Expand Down
11 changes: 11 additions & 0 deletions src/estimagic/utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -293,3 +293,14 @@ def to_pickle(obj, path):

def read_pickle(path):
return pd.read_pickle(path)


def isscalar(element):
"""Jax aware replacement for np.isscalar."""
if np.isscalar(element):
return True
# call anything a scalar that says it has 0 dimensions
elif getattr(element, "ndim", -1) == 0:
return True
else:
return False
122 changes: 122 additions & 0 deletions tests/optimization/test_jax_derivatives.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,122 @@
import numpy as np
import pytest
from estimagic.config import IS_JAX_INSTALLED
from estimagic.optimization.optimize import minimize
from numpy.testing import assert_array_almost_equal as aaae

if IS_JAX_INSTALLED:
import jax.numpy as jnp
import jax


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def test_scipy_conference_example():
def criterion(x):
first = (x["a"] - jnp.pi) ** 2
second = jnp.linalg.norm(x["b"] - jnp.arange(3))
third = jnp.linalg.norm(x["c"] - jnp.eye(2))
return first + second + third

start_params = {
"a": 1.0,
"b": jnp.ones(3).astype(float),
"c": jnp.ones((2, 2)).astype(float),
}

gradient = jax.grad(criterion)

res = minimize(
criterion=criterion,
derivative=gradient,
params=start_params,
algorithm="scipy_lbfgsb",
)

assert isinstance(res.params["b"], jnp.ndarray)
aaae(res.params["b"], jnp.arange(3))
aaae(res.params["c"], jnp.eye(2))
assert np.allclose(res.params["a"], np.pi, atol=1e-4)


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def test_params_is_jax_scalar():
def criterion(x):
return x**2

res = minimize(
criterion=criterion,
params=jnp.array(1.0),
algorithm="scipy_lbfgsb",
derivative=jax.grad(criterion),
)

assert isinstance(res.params, jnp.ndarray)
assert np.allclose(res.params, 0.0)


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def params_is_1d_array():
def criterion(x):
return x @ x

res = minimize(
criterion=criterion,
params=jnp.arange(3),
algorithm="scipy_lbfgsb",
derivative=jax.grad(criterion),
)

assert isinstance(res.params, jnp.ndarray)
assert aaae(res.params, jnp.arange(3))


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
@pytest.mark.parametrize("algorithm", ["scipy_lbfgsb", "scipy_ls_lm"])
def test_dict_output_works(algorithm):
def criterion(x):
return {"root_contributions": x, "value": x @ x}

def scalar_wrapper(x):
return criterion(x)["value"]

def ls_wrapper(x):
return criterion(x)["root_contributions"]

deriv_dict = {
"value": jax.grad(scalar_wrapper),
"root_contributions": jax.jacobian(ls_wrapper),
}

res = minimize(
criterion=criterion,
params=jnp.array([1.0, 2.0, 3.0]),
algorithm=algorithm,
derivative=deriv_dict,
)

assert isinstance(res.params, jnp.ndarray)
aaae(res.params, np.zeros(3))


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def test_least_squares_optimizer_pytree():
def criterion(x):
return {"root_contributions": x}

def ls_wrapper(x):
return criterion(x)["root_contributions"]

params = {"a": 1.0, "b": 2.0, "c": jnp.array([1.0, 2.0])}
jac = jax.jacobian(ls_wrapper)

res = minimize(
criterion=criterion,
params=params,
algorithm="scipy_ls_lm",
derivative=jac,
)

assert isinstance(res.params, dict)
assert np.allclose(res.params["a"], 0)
assert np.allclose(res.params["b"], 0)
aaae(res.params["c"], np.zeros(2))
1 change: 1 addition & 0 deletions tests/optimization/test_trust_region_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,7 @@ def test_get_next_trust_region_points_latin_hypercube_single_use(
aaae(sample.mean(axis=0), center, decimal=decimal)


@pytest.mark.slow
@pytest.mark.parametrize(
"optimality_criterion",
["a-optimal", "e-optimal", "d-optimal", "g-optimal", "maximin"],
Expand Down
2 changes: 2 additions & 0 deletions tests/optimization/test_with_nonlinear_constraints.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
import pytest
from estimagic import maximize
from estimagic import minimize
from estimagic.config import IS_CYIPOPT_INSTALLED
from estimagic.optimization import AVAILABLE_ALGORITHMS
from numpy.testing import assert_array_almost_equal as aaae

Expand Down Expand Up @@ -233,6 +234,7 @@ def constraint(selected):
TEST_CASES = list(itertools.product(["ipopt"], [True, False]))


@pytest.mark.skipif(not IS_CYIPOPT_INSTALLED, reason="Needs ipopt")
@pytest.mark.parametrize("algorithm, skip_checks", TEST_CASES)
def test_general_example(general_example, algorithm, skip_checks):

Expand Down
34 changes: 34 additions & 0 deletions tests/test_utilities.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import numpy as np
import pandas as pd
import pytest
from estimagic.config import IS_JAX_INSTALLED
from estimagic.utilities import calculate_trustregion_initial_radius
from estimagic.utilities import chol_params_to_lower_triangular_matrix
from estimagic.utilities import cov_matrix_to_params
Expand All @@ -9,6 +10,7 @@
from estimagic.utilities import cov_to_sds_and_corr
from estimagic.utilities import dimension_to_number_of_triangular_elements
from estimagic.utilities import hash_array
from estimagic.utilities import isscalar
from estimagic.utilities import number_of_triangular_elements_to_dimension
from estimagic.utilities import read_pickle
from estimagic.utilities import robust_cholesky
Expand All @@ -19,6 +21,9 @@
from estimagic.utilities import to_pickle
from numpy.testing import assert_array_almost_equal as aaae

if IS_JAX_INSTALLED:
import jax.numpy as jnp


def test_chol_params_to_lower_triangular_matrix():
calculated = chol_params_to_lower_triangular_matrix(pd.Series([1, 2, 3]))
Expand Down Expand Up @@ -176,3 +181,32 @@ def test_pickling(tmp_path):
to_pickle(a, path)
b = read_pickle(path)
assert a == b


SCALARS = [1, 2.0, np.pi, np.array(1), np.array(2.0), np.array(np.pi), np.nan]


@pytest.mark.parametrize("element", SCALARS)
def test_isscalar_true(element):
assert isscalar(element) is True


NON_SCALARS = [np.arange(3), {"a": 1}, [1, 2, 3]]


@pytest.mark.parametrize("element", NON_SCALARS)
def test_isscalar_false(element):
assert isscalar(element) is False


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def tets_isscalar_jax_true():
x = jnp.arange(3)
element = x @ x
assert isscalar(element) is True


@pytest.mark.skipif(not IS_JAX_INSTALLED, reason="Needs jax.")
def test_isscalar_jax_false():
element = jnp.arange(3)
assert isscalar(element) is False
Loading

0 comments on commit 33c47b6

Please sign in to comment.