diff --git a/llamax/integration_test.py b/llamax/integration_test.py index 82d4f52..a80ab14 100644 --- a/llamax/integration_test.py +++ b/llamax/integration_test.py @@ -125,7 +125,7 @@ def setUpClass(cls): # Convert checkpoint to double precision checkpoint = {k: v.double() for k, v in checkpoint.items()} - + jax.tree.map( lambda x, y: np.testing.assert_array_equal(x.shape, y.shape), dict(cls.torch_model.state_dict()),