jax.Array must be fully replicated to be saved in aggregate file #16343
I'm trying to save a checkpoint in Flax and I get this error message. Saving code:
What could be causing this? I tried disabling jax.Array with
Edit: state is an instance of flax.training.train_state.TrainState
Replies: 1 comment 1 reply
Looks like this belongs in the flax repo? Can you file an issue here: https://github.com/google/flax?
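When filing the issue, it may help to report which arrays in the state are not fully replicated, since that is what the error message complains about. The following is only a rough sketch: the helper `non_replicated_leaves` and the dictionary standing in for the train state are hypothetical and not taken from the original post. It relies on the `is_fully_replicated` property of `jax.Array`.

```python
import jax
import jax.numpy as jnp


def non_replicated_leaves(tree):
    """Return key paths of jax.Array leaves that are not fully replicated.

    Hypothetical helper for diagnosis only; not part of Flax or JAX.
    """
    flagged = []
    for path, leaf in jax.tree_util.tree_leaves_with_path(tree):
        # Non-array leaves (e.g. Python ints) are skipped.
        if isinstance(leaf, jax.Array) and not leaf.is_fully_replicated:
            flagged.append(jax.tree_util.keystr(path))
    return flagged


# Hypothetical stand-in for the train state; a real TrainState is also a pytree,
# so it could be passed to the helper in the same way.
state = {"params": {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}, "step": 0}

print(non_replicated_leaves(state))
```

Any paths printed here would point at arrays that are sharded across devices rather than replicated, which is useful context to include in the Flax issue.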