diff --git a/keras/src/backend/jax/distribution_lib.py b/keras/src/backend/jax/distribution_lib.py index 124027938dc..5dc5c057d29 100644 --- a/keras/src/backend/jax/distribution_lib.py +++ b/keras/src/backend/jax/distribution_lib.py @@ -118,7 +118,6 @@ def distribute_data_input(per_process_batch, layout, batch_dim_name): layout = _to_jax_layout(layout) num_model_replicas_total = layout.mesh.shape[batch_dim_name] - mesh_shape = list(layout.mesh.shape.values()) mesh_model_dim_size = 1 for name, dim_size in layout.mesh.shape.items(): diff --git a/keras/src/backend/jax/distribution_lib_test.py b/keras/src/backend/jax/distribution_lib_test.py index 2605f058487..81ceddfd305 100644 --- a/keras/src/backend/jax/distribution_lib_test.py +++ b/keras/src/backend/jax/distribution_lib_test.py @@ -337,7 +337,9 @@ def test_distribute_data_input(self): mesh, jax.sharding.PartitionSpec("batch", None) ) - result = backend_dlib.distribute_data_input(per_process_batch, layout, "batch") + result = backend_dlib.distribute_data_input( + per_process_batch, layout, "batch" + ) # Check the shape of the global batch array self.assertEqual(