Skip to content

Commit

Permalink
Debug network_test_jax.py #13
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 10, 2022
1 parent 90a5df2 commit 369656d
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions raygun/jax/tests/network_test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ def create_network():
'gt': gt,
'mask': mask,
}
rng = jax.random.PRNGKey(42)
rng: KeyArray= jax.random.PRNGKey(42)

# init model
if n_devices > 1:
Expand All @@ -186,7 +186,7 @@ def create_network():
single_device_inputs = {
'raw': raw,
'gt': gt,
'mask': mask,
'mask': mask
}
rng = jnp.broadcast_to(rng, (n_devices,) + rng.shape)
model_params = jax.pmap(my_model.initialize)(rng, single_device_inputs)
Expand Down

0 comments on commit 369656d

Please sign in to comment.