From ef8edb15de3853c33fc91321d3f14c45ded79e6b Mon Sep 17 00:00:00 2001 From: John Franklin Crenshaw Date: Thu, 17 Mar 2022 21:42:34 -0700 Subject: [PATCH] Updated depecrated jax methods. --- pzflow/bijectors/bijectors.py | 49 ++++++++++------------------------- pzflow/flow.py | 16 ++++-------- pzflow/flowEnsemble.py | 2 +- pzflow/utils.py | 2 +- pzflow/version.py | 2 +- tests/test_bijectors.py | 7 +++-- tests/test_utils.py | 6 ++--- 7 files changed, 27 insertions(+), 57 deletions(-) diff --git a/pzflow/bijectors/bijectors.py b/pzflow/bijectors/bijectors.py index aabcd93..5aff658 100644 --- a/pzflow/bijectors/bijectors.py +++ b/pzflow/bijectors/bijectors.py @@ -2,7 +2,7 @@ from typing import Callable, Sequence, Tuple, Union import jax.numpy as np -from jax import ops, random +from jax import random # define a type alias for Jax Pytrees Pytree = Union[tuple, list] @@ -254,9 +254,7 @@ def mag0(outputs): else: def mag0(outputs): - return ops.index_update( - outputs, - ops.index[:, mag0_idx], + return outputs.at[:, mag0_idx].set( outputs[:, mag0_idx] + outputs[:, new_ref], indices_are_sorted=True, unique_indices=True, @@ -291,9 +289,7 @@ def inverse_fun(params, inputs, **kwargs): # calculate mag[0] outputs = mag0(outputs) # mag[i] = mag[0] - (mag[0] - mag[i]) - outputs = ops.index_update( - outputs, - ops.index[:, mag0_idx + 1 :], + outputs = outputs.at[:, mag0_idx + 1 :].set( outputs[:, mag0_idx, None] - outputs[:, mag0_idx + 1 :], indices_are_sorted=True, unique_indices=True, @@ -351,22 +347,18 @@ def InvSoftplus( def init_fun(rng, input_dim, **kwargs): @ForwardFunction def forward_fun(params, inputs, **kwargs): - outputs = ops.index_update( - inputs, - ops.index[:, idx], + outputs = inputs.at[:, idx].set( np.log(-1 + np.exp(k * inputs[:, idx])) / k, ) - log_det = np.log(1 + np.exp(-k * outputs[ops.index[:, idx]])).sum(axis=1) + log_det = np.log(1 + np.exp(-k * outputs[:, idx])).sum(axis=1) return outputs, log_det @InverseFunction def inverse_fun(params, inputs, **kwargs): - outputs = ops.index_update( - inputs, - ops.index[:, idx], + outputs = inputs.at[:, idx].set( np.log(1 + np.exp(k * inputs[:, idx])) / k, ) - log_det = -np.log(1 + np.exp(-k * inputs[ops.index[:, idx]])).sum(axis=1) + log_det = -np.log(1 + np.exp(-k * inputs[:, idx])).sum(axis=1) return outputs, log_det return (), forward_fun, inverse_fun @@ -640,7 +632,7 @@ def inverse_fun(params, inputs, **kwargs): @Bijector -def UniformDequantizer(column_idx: int = None) -> Tuple[InitFunction, Bijector_Info]: +def UniformDequantizer(column_idx: int) -> Tuple[InitFunction, Bijector_Info]: """Bijector that dequantizes discrete variables with uniform noise. Dequantizers are necessary for modeling discrete values with a flow. @@ -662,36 +654,21 @@ def UniformDequantizer(column_idx: int = None) -> Tuple[InitFunction, Bijector_I """ bijector_info = ("UniformDequantizer", (column_idx,)) - - if column_idx is None: - idx = ops.index[:, :] - else: - idx = ops.index[:, column_idx] + column_idx = np.array(column_idx) @InitFunction def init_fun(rng, input_dim, **kwargs): @ForwardFunction def forward_fun(params, inputs, **kwargs): - u = random.uniform(random.PRNGKey(0), shape=inputs[idx].shape) - outputs = ops.index_update( - inputs.astype(float), - idx, - inputs[idx].astype(float) + u, - indices_are_sorted=True, - unique_indices=True, - ) + u = random.uniform(random.PRNGKey(0), shape=inputs[:, column_idx].shape) + outputs = inputs.astype(float) + outputs.at[:, column_idx].set(outputs[:, column_idx] + u) log_det = np.zeros(inputs.shape[0]) return outputs, log_det @InverseFunction def inverse_fun(params, inputs, **kwargs): - outputs = ops.index_update( - inputs, - idx, - np.floor(inputs[idx]), - indices_are_sorted=True, - unique_indices=True, - ) + outputs = inputs.at[:, column_idx].set(np.floor(inputs[:, column_idx])) log_det = np.zeros(inputs.shape[0]) return outputs, log_det diff --git a/pzflow/flow.py b/pzflow/flow.py index 80c2ec5..c6eeb10 100644 --- a/pzflow/flow.py +++ b/pzflow/flow.py @@ -6,7 +6,7 @@ import numpy as onp import pandas as pd from jax import grad, jit, ops, random -from jax.experimental.optimizers import Optimizer, adam +from jax.example_libraries.optimizers import Optimizer, adam from pzflow import distributions from pzflow.bijectors import Bijector_Info, InitFunction, Pytree @@ -246,7 +246,7 @@ def _get_err_samples( type: str = "data", skip: str = None, ) -> np.ndarray: - """Draw error samples for each row of inputs. """ + """Draw error samples for each row of inputs.""" X = inputs.copy() @@ -476,9 +476,7 @@ def check_flags(data): ) # save these pdfs in the big array - pdfs = ops.index_update( - pdfs, - ops.index[unflagged_idx, :], + pdfs = pdfs.at[unflagged_idx, :].set( unflagged_pdfs, indices_are_sorted=True, unique_indices=True, @@ -543,9 +541,7 @@ def check_flags(data): marg_pdfs = marg_pdfs.sum(axis=1) # save the new pdfs in the big array - pdfs = ops.index_update( - pdfs, - ops.index[flagged_idx, :], + pdfs = pdfs.at[flagged_idx, :].set( marg_pdfs, indices_are_sorted=True, unique_indices=True, @@ -614,9 +610,7 @@ def check_flags(data): prob = prob.reshape(-1, err_samples, len(grid)) prob = prob.mean(axis=1) # add the pdfs to the bigger list - pdfs = ops.index_update( - pdfs, - ops.index[batch_idx : batch_idx + batch_size, :], + pdfs = pdfs.at[batch_idx : batch_idx + batch_size, :].set( prob, indices_are_sorted=True, unique_indices=True, diff --git a/pzflow/flowEnsemble.py b/pzflow/flowEnsemble.py index 3d760e2..78bba11 100644 --- a/pzflow/flowEnsemble.py +++ b/pzflow/flowEnsemble.py @@ -5,7 +5,7 @@ import numpy as onp import pandas as pd from jax import random -from jax.experimental.optimizers import Optimizer +from jax.example_libraries.optimizers import Optimizer from pzflow import Flow from pzflow.bijectors import Bijector_Info, InitFunction diff --git a/pzflow/utils.py b/pzflow/utils.py index ca9ee54..af5c412 100644 --- a/pzflow/utils.py +++ b/pzflow/utils.py @@ -2,7 +2,7 @@ import jax.numpy as np from jax import random -from jax.experimental.stax import Dense, LeakyRelu, serial +from jax.example_libraries.stax import Dense, LeakyRelu, serial from pzflow import bijectors diff --git a/pzflow/version.py b/pzflow/version.py index e4c0eff..36a3365 100644 --- a/pzflow/version.py +++ b/pzflow/version.py @@ -2,4 +2,4 @@ # 1. We can control the version number in one location # 2. It is accessible from the package (see __init__.py) # 3. We can access it from setup.py without loading pzflow. -__version__ = "2.0.6" +__version__ = "2.0.7" diff --git a/tests/test_bijectors.py b/tests/test_bijectors.py index 2b1738e..f0d93cc 100644 --- a/tests/test_bijectors.py +++ b/tests/test_bijectors.py @@ -99,9 +99,8 @@ def test_bad_inputs(bijector, args): bijector(*args) -@pytest.mark.parametrize("column_idx", [(None), ([1, 3, 5])]) -def test_uniform_dequantizer_returns_correct_shape(column_idx): - init_fun, bijector_info = UniformDequantizer(column_idx) +def test_uniform_dequantizer_returns_correct_shape(): + init_fun, bijector_info = UniformDequantizer([1, 3, 4]) params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), x.shape[-1]) conditions = np.zeros((3, 1)) @@ -111,4 +110,4 @@ def test_uniform_dequantizer_returns_correct_shape(column_idx): inv_outputs, inv_log_det = inverse_fun(params, x, conditions=conditions) assert inv_outputs.shape == x.shape - assert inv_log_det.shape == x.shape[:1] \ No newline at end of file + assert inv_log_det.shape == x.shape[:1] diff --git a/tests/test_utils.py b/tests/test_utils.py index 5f5cf52..2ecb0af 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,5 +1,5 @@ import jax.numpy as np -from jax import random, ops +from jax import random from pzflow.bijectors import * from pzflow.utils import * import pytest @@ -37,7 +37,7 @@ def test_sub_diag_indices_correct(): x = np.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[2, 2], [2, 2]]]) y = np.array([[[1, 0], [0, 1]], [[2, 1], [1, 2]], [[3, 2], [2, 3]]]) idx = sub_diag_indices(x) - x = ops.index_update(x, idx, x[idx] + 1) + x = x.at[idx].set(x[idx] + 1) assert np.allclose(x, y) @@ -48,4 +48,4 @@ def test_sub_diag_indices_correct(): ) def test_sub_diag_indices_bad_input(x): with pytest.raises(ValueError): - idx = sub_diag_indices(x) \ No newline at end of file + idx = sub_diag_indices(x)