Skip to content

Commit

Permalink
Merge pull request #69 from jfcrenshaw/dev
Browse files Browse the repository at this point in the history
Updated depecrated jax methods.
  • Loading branch information
jfcrenshaw authored Mar 18, 2022
2 parents 66268fe + ef8edb1 commit 7e75e8d
Show file tree
Hide file tree
Showing 7 changed files with 27 additions and 57 deletions.
49 changes: 13 additions & 36 deletions pzflow/bijectors/bijectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
from typing import Callable, Sequence, Tuple, Union

import jax.numpy as np
from jax import ops, random
from jax import random

# define a type alias for Jax Pytrees
Pytree = Union[tuple, list]
Expand Down Expand Up @@ -254,9 +254,7 @@ def mag0(outputs):
else:

def mag0(outputs):
return ops.index_update(
outputs,
ops.index[:, mag0_idx],
return outputs.at[:, mag0_idx].set(
outputs[:, mag0_idx] + outputs[:, new_ref],
indices_are_sorted=True,
unique_indices=True,
Expand Down Expand Up @@ -291,9 +289,7 @@ def inverse_fun(params, inputs, **kwargs):
# calculate mag[0]
outputs = mag0(outputs)
# mag[i] = mag[0] - (mag[0] - mag[i])
outputs = ops.index_update(
outputs,
ops.index[:, mag0_idx + 1 :],
outputs = outputs.at[:, mag0_idx + 1 :].set(
outputs[:, mag0_idx, None] - outputs[:, mag0_idx + 1 :],
indices_are_sorted=True,
unique_indices=True,
Expand Down Expand Up @@ -351,22 +347,18 @@ def InvSoftplus(
def init_fun(rng, input_dim, **kwargs):
@ForwardFunction
def forward_fun(params, inputs, **kwargs):
outputs = ops.index_update(
inputs,
ops.index[:, idx],
outputs = inputs.at[:, idx].set(
np.log(-1 + np.exp(k * inputs[:, idx])) / k,
)
log_det = np.log(1 + np.exp(-k * outputs[ops.index[:, idx]])).sum(axis=1)
log_det = np.log(1 + np.exp(-k * outputs[:, idx])).sum(axis=1)
return outputs, log_det

@InverseFunction
def inverse_fun(params, inputs, **kwargs):
outputs = ops.index_update(
inputs,
ops.index[:, idx],
outputs = inputs.at[:, idx].set(
np.log(1 + np.exp(k * inputs[:, idx])) / k,
)
log_det = -np.log(1 + np.exp(-k * inputs[ops.index[:, idx]])).sum(axis=1)
log_det = -np.log(1 + np.exp(-k * inputs[:, idx])).sum(axis=1)
return outputs, log_det

return (), forward_fun, inverse_fun
Expand Down Expand Up @@ -640,7 +632,7 @@ def inverse_fun(params, inputs, **kwargs):


@Bijector
def UniformDequantizer(column_idx: int = None) -> Tuple[InitFunction, Bijector_Info]:
def UniformDequantizer(column_idx: int) -> Tuple[InitFunction, Bijector_Info]:
"""Bijector that dequantizes discrete variables with uniform noise.
Dequantizers are necessary for modeling discrete values with a flow.
Expand All @@ -662,36 +654,21 @@ def UniformDequantizer(column_idx: int = None) -> Tuple[InitFunction, Bijector_I
"""

bijector_info = ("UniformDequantizer", (column_idx,))

if column_idx is None:
idx = ops.index[:, :]
else:
idx = ops.index[:, column_idx]
column_idx = np.array(column_idx)

@InitFunction
def init_fun(rng, input_dim, **kwargs):
@ForwardFunction
def forward_fun(params, inputs, **kwargs):
u = random.uniform(random.PRNGKey(0), shape=inputs[idx].shape)
outputs = ops.index_update(
inputs.astype(float),
idx,
inputs[idx].astype(float) + u,
indices_are_sorted=True,
unique_indices=True,
)
u = random.uniform(random.PRNGKey(0), shape=inputs[:, column_idx].shape)
outputs = inputs.astype(float)
outputs.at[:, column_idx].set(outputs[:, column_idx] + u)
log_det = np.zeros(inputs.shape[0])
return outputs, log_det

@InverseFunction
def inverse_fun(params, inputs, **kwargs):
outputs = ops.index_update(
inputs,
idx,
np.floor(inputs[idx]),
indices_are_sorted=True,
unique_indices=True,
)
outputs = inputs.at[:, column_idx].set(np.floor(inputs[:, column_idx]))
log_det = np.zeros(inputs.shape[0])
return outputs, log_det

Expand Down
16 changes: 5 additions & 11 deletions pzflow/flow.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import numpy as onp
import pandas as pd
from jax import grad, jit, ops, random
from jax.experimental.optimizers import Optimizer, adam
from jax.example_libraries.optimizers import Optimizer, adam

from pzflow import distributions
from pzflow.bijectors import Bijector_Info, InitFunction, Pytree
Expand Down Expand Up @@ -246,7 +246,7 @@ def _get_err_samples(
type: str = "data",
skip: str = None,
) -> np.ndarray:
"""Draw error samples for each row of inputs. """
"""Draw error samples for each row of inputs."""

X = inputs.copy()

Expand Down Expand Up @@ -476,9 +476,7 @@ def check_flags(data):
)

# save these pdfs in the big array
pdfs = ops.index_update(
pdfs,
ops.index[unflagged_idx, :],
pdfs = pdfs.at[unflagged_idx, :].set(
unflagged_pdfs,
indices_are_sorted=True,
unique_indices=True,
Expand Down Expand Up @@ -543,9 +541,7 @@ def check_flags(data):
marg_pdfs = marg_pdfs.sum(axis=1)

# save the new pdfs in the big array
pdfs = ops.index_update(
pdfs,
ops.index[flagged_idx, :],
pdfs = pdfs.at[flagged_idx, :].set(
marg_pdfs,
indices_are_sorted=True,
unique_indices=True,
Expand Down Expand Up @@ -614,9 +610,7 @@ def check_flags(data):
prob = prob.reshape(-1, err_samples, len(grid))
prob = prob.mean(axis=1)
# add the pdfs to the bigger list
pdfs = ops.index_update(
pdfs,
ops.index[batch_idx : batch_idx + batch_size, :],
pdfs = pdfs.at[batch_idx : batch_idx + batch_size, :].set(
prob,
indices_are_sorted=True,
unique_indices=True,
Expand Down
2 changes: 1 addition & 1 deletion pzflow/flowEnsemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import numpy as onp
import pandas as pd
from jax import random
from jax.experimental.optimizers import Optimizer
from jax.example_libraries.optimizers import Optimizer

from pzflow import Flow
from pzflow.bijectors import Bijector_Info, InitFunction
Expand Down
2 changes: 1 addition & 1 deletion pzflow/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import jax.numpy as np
from jax import random
from jax.experimental.stax import Dense, LeakyRelu, serial
from jax.example_libraries.stax import Dense, LeakyRelu, serial

from pzflow import bijectors

Expand Down
2 changes: 1 addition & 1 deletion pzflow/version.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,4 +2,4 @@
# 1. We can control the version number in one location
# 2. It is accessible from the package (see __init__.py)
# 3. We can access it from setup.py without loading pzflow.
__version__ = "2.0.6"
__version__ = "2.0.7"
7 changes: 3 additions & 4 deletions tests/test_bijectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,9 +99,8 @@ def test_bad_inputs(bijector, args):
bijector(*args)


@pytest.mark.parametrize("column_idx", [(None), ([1, 3, 5])])
def test_uniform_dequantizer_returns_correct_shape(column_idx):
init_fun, bijector_info = UniformDequantizer(column_idx)
def test_uniform_dequantizer_returns_correct_shape():
init_fun, bijector_info = UniformDequantizer([1, 3, 4])
params, forward_fun, inverse_fun = init_fun(random.PRNGKey(0), x.shape[-1])

conditions = np.zeros((3, 1))
Expand All @@ -111,4 +110,4 @@ def test_uniform_dequantizer_returns_correct_shape(column_idx):

inv_outputs, inv_log_det = inverse_fun(params, x, conditions=conditions)
assert inv_outputs.shape == x.shape
assert inv_log_det.shape == x.shape[:1]
assert inv_log_det.shape == x.shape[:1]
6 changes: 3 additions & 3 deletions tests/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import jax.numpy as np
from jax import random, ops
from jax import random
from pzflow.bijectors import *
from pzflow.utils import *
import pytest
Expand Down Expand Up @@ -37,7 +37,7 @@ def test_sub_diag_indices_correct():
x = np.array([[[0, 0], [0, 0]], [[1, 1], [1, 1]], [[2, 2], [2, 2]]])
y = np.array([[[1, 0], [0, 1]], [[2, 1], [1, 2]], [[3, 2], [2, 3]]])
idx = sub_diag_indices(x)
x = ops.index_update(x, idx, x[idx] + 1)
x = x.at[idx].set(x[idx] + 1)

assert np.allclose(x, y)

Expand All @@ -48,4 +48,4 @@ def test_sub_diag_indices_correct():
)
def test_sub_diag_indices_bad_input(x):
with pytest.raises(ValueError):
idx = sub_diag_indices(x)
idx = sub_diag_indices(x)

0 comments on commit 7e75e8d

Please sign in to comment.