Make the codebase compatible with pytest, pylint, pytype.
PiperOrigin-RevId: 368041519
hbq1 authored and DistraxDev committed Apr 12, 2021
1 parent d2f1c05 commit 547c0db
Showing 29 changed files with 237 additions and 169 deletions.
1 change: 0 additions & 1 deletion distrax/_src/bijectors/bijector.py
@@ -135,7 +135,6 @@ def inverse_log_det_jacobian(self, y: Array) -> Array:
@abc.abstractmethod
def forward_and_log_det(self, x: Array) -> Tuple[Array, Array]:
"""Computes y = f(x) and log|det J(f)(x)|."""
- pass

def inverse_and_log_det(self, y: Array) -> Tuple[Array, Array]:
"""Computes x = f^{-1}(y) and log|det J(f^{-1})(y)|."""
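The removed `pass` is what pylint's unnecessary-pass check (W0107) flags: a docstring is already a complete function body, so nothing else is needed. A minimal standalone sketch of the pattern (an illustrative class, not the full Distrax bijector):

import abc
from typing import Any, Tuple

class Bijector(abc.ABC):

  @abc.abstractmethod
  def forward_and_log_det(self, x: Any) -> Tuple[Any, Any]:
    """Computes y = f(x) and log|det J(f)(x)|."""
    # The docstring alone is a valid body; a trailing `pass` is redundant.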
1 change: 1 addition & 0 deletions distrax/_src/bijectors/bijector_from_tfp.py
@@ -64,6 +64,7 @@ def _ensure_batch_shape(self,
event_ndims_out: int,
forward_fn: Callable[[Array], Array],
x: Array) -> Array:
"""Broadcasts logdet to the batch shape as required."""
if self._tfp_bijector.is_constant_jacobian:
# If the Jacobian is constant, TFP may return a log det that doesn't have
# full batch shape, but is broadcastable to it. Distrax assumes that the
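The new docstring documents an existing subtlety: when a TFP bijector has a constant Jacobian, the log det it returns may have a smaller shape that is only broadcastable to the full batch shape, so Distrax broadcasts it explicitly. A standalone sketch of that broadcast (hypothetical helper name, not the actual method body):

import jax.numpy as jnp

def broadcast_logdet(logdet, x, event_ndims):
  # Batch shape is the input shape with the trailing event dims removed.
  batch_shape = x.shape[:x.ndim - event_ndims]
  return jnp.broadcast_to(logdet, batch_shape)

x = jnp.zeros((4, 3, 2))   # Batch shape (4, 3), event shape (2,).
logdet = jnp.log(2.)       # Scalar log det from a constant-Jacobian bijector.
print(broadcast_logdet(logdet, x, event_ndims=1).shape)  # (4, 3)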
183 changes: 98 additions & 85 deletions distrax/_src/bijectors/bijector_from_tfp_test.py
@@ -24,41 +24,52 @@
import numpy as np
from tensorflow_probability.substrates import jax as tfp


tfb = tfp.bijectors


- def _batched_chain():
- return tfb.Chain([
- tfb.Shift(jnp.zeros((4, 2, 3))),
- tfb.ScaleMatvecDiag([[1., 2., 3.], [4., 5., 6.]])
- ])


class BijectorFromTFPTest(parameterized.TestCase):

+ def setUp(self):
+ super().setUp()
+ bjs = {}
+ bjs['BatchedChain'] = tfb.Chain([
+ tfb.Shift(jnp.zeros((4, 2, 3))),
+ tfb.ScaleMatvecDiag([[1., 2., 3.], [4., 5., 6.]])
+ ])
+ bjs['Square'] = tfb.Square()
+ bjs['ScaleScalar'] = tfb.Scale(2.)
+ bjs['ScaleMatrix'] = tfb.Scale(2. * jnp.ones((3, 2)))
+ bjs['Reshape'] = tfb.Reshape((2, 3), (6,))
+
+ # To parallelize pytest runs.
+ # See https://github.com/pytest-dev/pytest-xdist/issues/432.
+ for name, bij in bjs.items():
+ bij.__repr__ = lambda _, name_=name: name_
+
+ self._test_bijectors = bjs

@chex.all_variants
@parameterized.parameters(
- (tfb.Square, (), (), (), ()),
- (tfb.Square, (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2.), (), (), (), ()),
- (lambda: tfb.Scale(2.), (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (), (), (3, 2), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (2,), (), (3, 2), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (1, 1), (), (3, 2), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 1, 1), (), (4, 3, 2), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 3, 2), (), (4, 3, 2), ()),
- (lambda: tfb.Reshape((2, 3), (6,)), (), (6,), (), (2, 3)),
- (lambda: tfb.Reshape((2, 3), (6,)), (10,), (6,), (10,), (2, 3)),
- (_batched_chain, (), (3,), (4, 2), (3,)),
- (_batched_chain, (2,), (3,), (4, 2), (3,)),
- (_batched_chain, (4, 1), (3,), (4, 2), (3,)),
- (_batched_chain, (5, 1, 2), (3,), (5, 4, 2), (3,)),
+ ('Square', (), (), (), ()),
+ ('Square', (2, 3), (), (2, 3), ()),
+ ('ScaleScalar', (), (), (), ()),
+ ('ScaleScalar', (2, 3), (), (2, 3), ()),
+ ('ScaleMatrix', (), (), (3, 2), ()),
+ ('ScaleMatrix', (2,), (), (3, 2), ()),
+ ('ScaleMatrix', (1, 1), (), (3, 2), ()),
+ ('ScaleMatrix', (4, 1, 1), (), (4, 3, 2), ()),
+ ('ScaleMatrix', (4, 3, 2), (), (4, 3, 2), ()),
+ ('Reshape', (), (6,), (), (2, 3)),
+ ('Reshape', (10,), (6,), (10,), (2, 3)),
+ ('BatchedChain', (), (3,), (4, 2), (3,)),
+ ('BatchedChain', (2,), (3,), (4, 2), (3,)),
+ ('BatchedChain', (4, 1), (3,), (4, 2), (3,)),
+ ('BatchedChain', (5, 1, 2), (3,), (5, 4, 2), (3,)),
)
- def test_forward_methods_are_correct(self, tfp_bij,
- batch_shape_in, event_shape_in,
- batch_shape_out, event_shape_out):
- tfp_bij = tfp_bij()
+ def test_forward_methods_are_correct(self, tfp_bij_name, batch_shape_in,
+ event_shape_in, batch_shape_out,
+ event_shape_out):
+ tfp_bij = self._test_bijectors[tfp_bij_name]
bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
key = jax.random.PRNGKey(42)
x = jax.random.uniform(key, batch_shape_in + event_shape_in)
@@ -76,26 +87,26 @@ def test_forward_methods_are_correct(self, tfp_bij,

@chex.all_variants
@parameterized.parameters(
- (tfb.Square, (), (), (), ()),
- (tfb.Square, (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2.), (), (), (), ()),
- (lambda: tfb.Scale(2.), (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (3, 2), (), (), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (3, 2), (), (2,), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (3, 2), (), (1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 3, 2), (), (4, 1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 3, 2), (), (4, 3, 2), ()),
- (lambda: tfb.Reshape((2, 3), (6,)), (), (6,), (), (2, 3)),
- (lambda: tfb.Reshape((2, 3), (6,)), (10,), (6,), (10,), (2, 3)),
- (_batched_chain, (4, 2), (3,), (), (3,)),
- (_batched_chain, (4, 2), (3,), (2,), (3,)),
- (_batched_chain, (4, 2), (3,), (4, 1), (3,)),
- (_batched_chain, (5, 4, 2), (3,), (5, 1, 2), (3,)),
+ ('Square', (), (), (), ()),
+ ('Square', (2, 3), (), (2, 3), ()),
+ ('ScaleScalar', (), (), (), ()),
+ ('ScaleScalar', (2, 3), (), (2, 3), ()),
+ ('ScaleMatrix', (3, 2), (), (), ()),
+ ('ScaleMatrix', (3, 2), (), (2,), ()),
+ ('ScaleMatrix', (3, 2), (), (1, 1), ()),
+ ('ScaleMatrix', (4, 3, 2), (), (4, 1, 1), ()),
+ ('ScaleMatrix', (4, 3, 2), (), (4, 3, 2), ()),
+ ('Reshape', (), (6,), (), (2, 3)),
+ ('Reshape', (10,), (6,), (10,), (2, 3)),
+ ('BatchedChain', (4, 2), (3,), (), (3,)),
+ ('BatchedChain', (4, 2), (3,), (2,), (3,)),
+ ('BatchedChain', (4, 2), (3,), (4, 1), (3,)),
+ ('BatchedChain', (5, 4, 2), (3,), (5, 1, 2), (3,)),
)
- def test_inverse_methods_are_correct(self, tfp_bij,
- batch_shape_in, event_shape_in,
- batch_shape_out, event_shape_out):
- tfp_bij = tfp_bij()
+ def test_inverse_methods_are_correct(self, tfp_bij_name, batch_shape_in,
+ event_shape_in, batch_shape_out,
+ event_shape_out):
+ tfp_bij = self._test_bijectors[tfp_bij_name]
bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
key = jax.random.PRNGKey(42)
y = jax.random.uniform(key, batch_shape_out + event_shape_out)
@@ -113,27 +124,28 @@ def test_inverse_methods_are_correct(self, tfp_bij,

@chex.all_variants
@parameterized.parameters(
- (tfb.Square, (), (), (), ()),
- (tfb.Square, (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2.), (), (), (), ()),
- (lambda: tfb.Scale(2.), (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (), (), (), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (2,), (), (2,), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (1, 1), (), (1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 1, 1), (), (4, 1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 3, 2), (), (4, 3, 2), ()),
- (lambda: tfb.Reshape((2, 3), (6,)), (), (6,), (), (2, 3)),
- (lambda: tfb.Reshape((2, 3), (6,)), (10,), (6,), (10,), (2, 3)),
- (_batched_chain, (), (3,), (), (3,)),
- (_batched_chain, (2,), (3,), (2,), (3,)),
- (_batched_chain, (4, 1), (3,), (4, 1), (3,)),
- (_batched_chain, (5, 1, 2), (3,), (5, 1, 2), (3,)),
+ ('Square', (), (), (), ()),
+ ('Square', (2, 3), (), (2, 3), ()),
+ ('ScaleScalar', (), (), (), ()),
+ ('ScaleScalar', (2, 3), (), (2, 3), ()),
+ ('ScaleMatrix', (), (), (), ()),
+ ('ScaleMatrix', (2,), (), (2,), ()),
+ ('ScaleMatrix', (1, 1), (), (1, 1), ()),
+ ('ScaleMatrix', (4, 1, 1), (), (4, 1, 1), ()),
+ ('ScaleMatrix', (4, 3, 2), (), (4, 3, 2), ()),
+ ('Reshape', (), (6,), (), (2, 3)),
+ ('Reshape', (10,), (6,), (10,), (2, 3)),
+ ('BatchedChain', (), (3,), (), (3,)),
+ ('BatchedChain', (2,), (3,), (2,), (3,)),
+ ('BatchedChain', (4, 1), (3,), (4, 1), (3,)),
+ ('BatchedChain', (5, 1, 2), (3,), (5, 1, 2), (3,)),
)
- def test_composite_methods_are_consistent(self, tfp_bij,
- batch_shape_in, event_shape_in,
- batch_shape_out, event_shape_out):
+ def test_composite_methods_are_consistent(self, tfp_bij_name, batch_shape_in,
+ event_shape_in, batch_shape_out,
+ event_shape_out):
key1, key2 = jax.random.split(jax.random.PRNGKey(42))
- bij = bijector_from_tfp.BijectorFromTFP(tfp_bij())
+ tfp_bij = self._test_bijectors[tfp_bij_name]
+ bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)

# Forward methods.
x = jax.random.uniform(key1, batch_shape_in + event_shape_in)
@@ -157,26 +169,26 @@ def test_composite_methods_are_consistent(self, tfp_bij,

@chex.all_variants
@parameterized.parameters(
- (tfb.Square, (), (), (), ()),
- (tfb.Square, (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2.), (), (), (), ()),
- (lambda: tfb.Scale(2.), (2, 3), (), (2, 3), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (), (), (), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (2,), (), (2,), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (1, 1), (), (1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 1, 1), (), (4, 1, 1), ()),
- (lambda: tfb.Scale(2. * jnp.ones((3, 2))), (4, 3, 2), (), (4, 3, 2), ()),
- (lambda: tfb.Reshape((2, 3), (6,)), (), (6,), (), (2, 3)),
- (lambda: tfb.Reshape((2, 3), (6,)), (10,), (6,), (10,), (2, 3)),
- (_batched_chain, (), (3,), (), (3,)),
- (_batched_chain, (2,), (3,), (2,), (3,)),
- (_batched_chain, (4, 1), (3,), (4, 1), (3,)),
- (_batched_chain, (5, 1, 2), (3,), (5, 1, 2), (3,)),
+ ('Square', (), (), (), ()),
+ ('Square', (2, 3), (), (2, 3), ()),
+ ('ScaleScalar', (), (), (), ()),
+ ('ScaleScalar', (2, 3), (), (2, 3), ()),
+ ('ScaleMatrix', (), (), (), ()),
+ ('ScaleMatrix', (2,), (), (2,), ()),
+ ('ScaleMatrix', (1, 1), (), (1, 1), ()),
+ ('ScaleMatrix', (4, 1, 1), (), (4, 1, 1), ()),
+ ('ScaleMatrix', (4, 3, 2), (), (4, 3, 2), ()),
+ ('Reshape', (), (6,), (), (2, 3)),
+ ('Reshape', (10,), (6,), (10,), (2, 3)),
+ ('BatchedChain', (), (3,), (), (3,)),
+ ('BatchedChain', (2,), (3,), (2,), (3,)),
+ ('BatchedChain', (4, 1), (3,), (4, 1), (3,)),
+ ('BatchedChain', (5, 1, 2), (3,), (5, 1, 2), (3,)),
)
- def test_works_with_tfp_caching(self, tfp_bij,
- batch_shape_in, event_shape_in,
- batch_shape_out, event_shape_out):
- tfp_bij = tfp_bij()
+ def test_works_with_tfp_caching(self, tfp_bij_name, batch_shape_in,
+ event_shape_in, batch_shape_out,
+ event_shape_out):
+ tfp_bij = self._test_bijectors[tfp_bij_name]
bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
key1, key2 = jax.random.split(jax.random.PRNGKey(42))

@@ -203,7 +215,7 @@ def test_works_with_tfp_caching(self, tfp_bij,
np.testing.assert_allclose(logdet1, logdet2, atol=1e-8)

def test_access_properties_tfp_bijector(self):
- tfp_bij = _batched_chain()
+ tfp_bij = self._test_bijectors['BatchedChain']
bij = bijector_from_tfp.BijectorFromTFP(tfp_bij)
# Access the attribute `bijectors`
np.testing.assert_allclose(
@@ -212,6 +224,7 @@ def test_access_properties_tfp_bijector(self):
bij.bijectors[1].scale.diag, tfp_bij.bijectors[1].scale.diag, atol=1e-8)

def test_jittable(self):

@jax.jit
def f(x, b):
return b.forward(x)
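The `__repr__` override in `setUp` above works around a pytest-xdist constraint (see the linked issue): all workers must collect identical test IDs, but the default repr of a freshly built object can embed a memory address such as `<function <lambda> at 0x7f...>`, which differs between processes. A small sketch of the idea, using a hypothetical wrapper class rather than the patching done in the diff:

class NamedValue:
  """Wraps a value so its repr is a stable, process-independent name."""

  def __init__(self, name, value):
    self._name = name
    self.value = value

  def __repr__(self):
    return self._name

bij = NamedValue('ScaleScalar', object())
assert repr(bij) == 'ScaleScalar'  # Same parameterized test ID on every worker.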
2 changes: 1 addition & 1 deletion distrax/_src/bijectors/chain_test.py
@@ -57,7 +57,7 @@ def _with_base_dists(*all_named_parameters):
class ChainTest(parameterized.TestCase):

def setUp(self):
- super(ChainTest, self).setUp()
+ super().setUp()
self.seed = jax.random.PRNGKey(1234)

def test_properties(self):
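This change (and the identical one in lambda_bijector_test.py below) switches to the zero-argument super() that Python 3 supports; pylint flags the explicit form as super-with-arguments (R1725). The two forms are equivalent, as this small sketch shows:

class Base:
  def set_up(self):
    print('Base.set_up')

class Child(Base):
  def set_up(self):
    super().set_up()  # Same as super(Child, self).set_up() in Python 3.

Child().set_up()  # Prints 'Base.set_up'.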
2 changes: 1 addition & 1 deletion distrax/_src/bijectors/lambda_bijector_test.py
@@ -54,7 +54,7 @@ def _with_base_dists(*all_named_parameters):
class LambdaTest(parameterized.TestCase):

def setUp(self):
- super(LambdaTest, self).setUp()
+ super().setUp()
self.seed = jax.random.PRNGKey(1234)

@parameterized.named_parameters(_with_base_dists(
7 changes: 6 additions & 1 deletion distrax/_src/bijectors/rational_quadratic_spline_float64_test.py
@@ -27,9 +27,15 @@
import jax.numpy as jnp


+ def setUpModule():
+ jax_config.update('jax_enable_x64', True)


class RationalQuadraticSplineFloat64Test(chex.TestCase):
"""Tests for rational quadratic spline that use float64."""

def _assert_dtypes(self, bij, x, dtype):
"""Asserts dtypes."""
# Sanity check to make sure float64 is enabled.
x_64 = jnp.zeros([])
self.assertEqual(jnp.float64, x_64.dtype)
@@ -59,5 +65,4 @@ def test_dtypes(self, dtypes, boundary_slopes):


if __name__ == '__main__':
- jax_config.update('jax_enable_x64', True)
absltest.main()
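Moving the flag update out of the `__main__` guard matters because pytest imports test modules without running that block, so under pytest the x64 flag was never set; `setUpModule` is invoked by both pytest and absltest before any test in the module runs. A minimal sketch, assuming the same `jax_config` alias the file already imports:

from absl.testing import absltest
import jax.numpy as jnp
from jax.config import config as jax_config

def setUpModule():
  # Runs once per module, under pytest and absltest alike.
  jax_config.update('jax_enable_x64', True)

class Float64EnabledTest(absltest.TestCase):

  def test_default_dtype_is_float64(self):
    self.assertEqual(jnp.float64, jnp.zeros([]).dtype)

if __name__ == '__main__':
  absltest.main()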
2 changes: 2 additions & 0 deletions distrax/_src/bijectors/rational_quadratic_spline_test.py
@@ -30,6 +30,7 @@ def _make_bijector(params_shape,
range_min=0.,
range_max=1.,
boundary_slopes='unconstrained'):
"""Returns RationalQuadraticSpline bijector."""
params_shape += (3 * num_bins + 1,)
if zero_params:
params = jnp.zeros(params_shape)
@@ -44,6 +45,7 @@


class RationalQuadraticSplineTest(parameterized.TestCase):
"""Tests for rational quadratic spline."""

def test_properties(self):
bijector = _make_bijector(params_shape=(4, 5), num_bins=8)
1 change: 1 addition & 0 deletions distrax/_src/bijectors/split_coupling.py
@@ -162,6 +162,7 @@ def _recombine(self, x1: Array, x2: Array) -> Array:
return jnp.concatenate([x1, x2], self._split_axis)

def _inner_bijector(self, params: BijectorParams) -> base.Bijector:
"""Returns an inner bijector for the passed params."""
bijector = conversion.as_bijector(self._bijector(params))
if bijector.event_ndims_in != bijector.event_ndims_out:
raise ValueError(
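For context on `_inner_bijector`: in a split coupling layer one half of the input parameterizes a bijector that transforms the other half, and the result of the user-supplied callable is converted with `as_bijector`; the event dimensionality must be preserved so the two halves can be recombined. A usage sketch against the public API (illustrative conditioner and parameters, not Distrax internals):

import distrax
import jax.numpy as jnp

def bijector_fn(params):
  # `params` is the conditioner output computed from the untransformed half.
  return distrax.ScalarAffine(shift=params)

layer = distrax.SplitCoupling(
    split_index=2,             # The first two features pass through unchanged.
    event_ndims=1,
    conditioner=lambda x: x,   # Trivial conditioner, for illustration only.
    bijector=bijector_fn)
y, logdet = layer.forward_and_log_det(jnp.arange(4.))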
12 changes: 6 additions & 6 deletions distrax/_src/distributions/bernoulli.py
@@ -109,18 +109,18 @@ def _sample_n(self, key: PRNGKey, n: int) -> Array:
key=key, shape=new_shape, dtype=probs.dtype, minval=0., maxval=1.)
return jnp.less(uniform, probs).astype(self._dtype)

- def log_prob(self, event: Array) -> Array:
+ def log_prob(self, value: Array) -> Array:
"""See `Distribution.log_prob`."""
log_probs0, log_probs1 = self._log_probs_parameter()
- return (math.multiply_no_nan(log_probs0, 1 - event) +
- math.multiply_no_nan(log_probs1, event))
+ return (math.multiply_no_nan(log_probs0, 1 - value) +
+ math.multiply_no_nan(log_probs1, value))

- def prob(self, event: Array) -> Array:
+ def prob(self, value: Array) -> Array:
"""See `Distribution.prob`."""
probs1 = self.probs
probs0 = 1 - probs1
- return (math.multiply_no_nan(probs0, 1 - event) +
- math.multiply_no_nan(probs1, event))
+ return (math.multiply_no_nan(probs0, 1 - value) +
+ math.multiply_no_nan(probs1, value))

def entropy(self) -> Array:
"""See `Distribution.entropy`."""
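The rename from `event` to `value` matches TFP's argument naming; the selector logic is unchanged. The reason for `multiply_no_nan` is that with degenerate probabilities one of the log-probs is -inf, and an ordinary product would turn the switched-off branch into nan. A standalone sketch of the semantics (not Distrax's implementation):

import jax.numpy as jnp

def multiply_no_nan(x, y):
  # x * y, except that y == 0 forces the result to 0 even if x is inf or nan.
  return jnp.where(y == 0, 0., x * y)

log_probs0, log_probs1 = jnp.log(0.), jnp.log(1.)  # probs = 1.0
value = 1.
print(multiply_no_nan(log_probs0, 1 - value) +
      multiply_no_nan(log_probs1, value))  # 0.0 = log_prob(1) when probs = 1.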
26 changes: 13 additions & 13 deletions distrax/_src/distributions/categorical.py
@@ -95,15 +95,15 @@ def _sample_n(self, key: PRNGKey, n: int) -> Array:
key=key, logits=self.logits, axis=-1, shape=new_shape)
return draws.astype(self._dtype)

- def log_prob(self, event: Array) -> Array:
+ def log_prob(self, value: Array) -> Array:
"""See `Distribution.log_prob`."""
- event_one_hot = jax.nn.one_hot(event, self.num_categories)
- return jnp.sum(math.multiply_no_nan(self.logits, event_one_hot), axis=-1)
+ value_one_hot = jax.nn.one_hot(value, self.num_categories)
+ return jnp.sum(math.multiply_no_nan(self.logits, value_one_hot), axis=-1)

- def prob(self, event: Array) -> Array:
+ def prob(self, value: Array) -> Array:
"""See `Distribution.prob`."""
- event_one_hot = jax.nn.one_hot(event, self.num_categories)
- return jnp.sum(math.multiply_no_nan(self.probs, event_one_hot), axis=-1)
+ value_one_hot = jax.nn.one_hot(value, self.num_categories)
+ return jnp.sum(math.multiply_no_nan(self.probs, value_one_hot), axis=-1)

def entropy(self) -> Array:
"""See `Distribution.entropy`."""
@@ -131,15 +131,15 @@ def mode(self) -> Array:
parameter = self._probs if self._logits is None else self._logits
return jnp.argmax(parameter, axis=-1).astype(self._dtype)

- def cdf(self, event: Array) -> Array:
+ def cdf(self, value: Array) -> Array:
"""See `Distribution.cdf`."""
- # For event < 0 the output should be zero because support = {0, ..., K-1}.
- should_be_zero = event < 0
- # Will use event as an index below, so clip it to {0, ..., K-1}.
- event = jnp.clip(event, 0, self.num_categories - 1)
- event_one_hot = jax.nn.one_hot(event, self.num_categories)
+ # For value < 0 the output should be zero because support = {0, ..., K-1}.
+ should_be_zero = value < 0
+ # Will use value as an index below, so clip it to {0, ..., K-1}.
+ value = jnp.clip(value, 0, self.num_categories - 1)
+ value_one_hot = jax.nn.one_hot(value, self.num_categories)
cdf = jnp.sum(math.multiply_no_nan(
- jnp.cumsum(self.probs, axis=-1), event_one_hot), axis=-1)
+ jnp.cumsum(self.probs, axis=-1), value_one_hot), axis=-1)
return jnp.where(should_be_zero, jnp.array(0.), cdf)

def logits_parameter(self) -> Array:
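A standalone walk-through of the `cdf` logic above, assuming K = 3 categories: `cumsum` yields P(X <= k) for every k, the one-hot selects the entry at the clipped `value`, and the final `where` zeroes out values below the support; values above K - 1 clip to K - 1 and return 1.

import jax
import jax.numpy as jnp

probs = jnp.array([0.2, 0.5, 0.3])
num_categories = probs.shape[-1]

def cdf(value):
  should_be_zero = value < 0
  value = jnp.clip(value, 0, num_categories - 1)
  value_one_hot = jax.nn.one_hot(value, num_categories)
  out = jnp.sum(jnp.cumsum(probs, axis=-1) * value_one_hot, axis=-1)
  return jnp.where(should_be_zero, jnp.array(0.), out)

print(cdf(jnp.array(1)))   # 0.7 = P(X <= 1)
print(cdf(jnp.array(-2)))  # 0.0, below the support {0, 1, 2}
print(cdf(jnp.array(5)))   # 1.0, clipped to the top category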
(Diff truncated: the remaining changed files are not shown.)
