Skip to content

Commit fb2f2d7

Browse files
author
Jakub Pieszczek
committed
save state_dict.pt as a separate file
1 parent d10780d commit fb2f2d7

File tree

1 file changed

+5
-0
lines changed

1 file changed

+5
-0
lines changed

timm/utils/checkpoint_saver.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,7 +167,12 @@ def save_checkpoint(self, epoch, metric=None):
167167
"model_kwargs": self.args.model_kwargs,
168168
}
169169
torch.save(model_dict, temp_location)
170+
torch.save(
171+
get_state_dict(self.model, self.unwrap_fn),
172+
os.path.join(temp_dir, "state_dict.pt"),
173+
)
170174
mlflow.log_artifact(temp_location)
175+
mlflow.log_artifact(os.path.join(temp_dir, "state_dict.pt"))
171176

172177
return (None, None) if self.best_metric is None else (self.best_metric, self.best_epoch)
173178

0 commit comments

Comments
 (0)