Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove usages of orbax_utils.save_args_from_target, as this function does nothing (it used to control a checkpointing behavior that has since been optimized away). #4482

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 2 additions & 4 deletions flax/training/checkpoints.py
Original file line number Diff line number Diff line change
Expand Up @@ -690,9 +690,8 @@ def save_checkpoint(
' https://flax.readthedocs.io/en/latest/guides/training_techniques/use_checkpointing.html#if-you-don-t-save-pytrees'
)

save_args = orbax_utils.save_args_from_target(target)
orbax_checkpointer.save(
ckpt_path, target, save_args=save_args, force=overwrite
ckpt_path, target, force=overwrite
)
# Do a process check here in case people call this for multihost.
if process_index() == 0:
Expand Down Expand Up @@ -843,9 +842,8 @@ def save_checkpoint_multiprocess(
_remove_invalid_ckpts(
ckpt_path, base_path, keep, overwrite, keep_every_n_steps, True
)
save_args = orbax_utils.save_args_from_target(target)
orbax_checkpointer.save(
ckpt_path, target, save_args=save_args, force=overwrite
ckpt_path, target, force=overwrite
)
end_time = time.time()
monitoring.record_event_duration_secs(
Expand Down
Loading