diff --git a/raygun/jax/tests/network_test_jax.py b/raygun/jax/tests/network_test_jax.py index 37b3f139..21cbd981 100644 --- a/raygun/jax/tests/network_test_jax.py +++ b/raygun/jax/tests/network_test_jax.py @@ -175,7 +175,7 @@ def create_network(): 'gt': gt, 'mask': mask, } - rng = jax.random.PRNGKey(42) + rng: KeyArray= jax.random.PRNGKey(42) # init model if n_devices > 1: @@ -186,7 +186,7 @@ def create_network(): single_device_inputs = { 'raw': raw, 'gt': gt, - 'mask': mask, + 'mask': mask } rng = jnp.broadcast_to(rng, (n_devices,) + rng.shape) model_params = jax.pmap(my_model.initialize)(rng, single_device_inputs)