diff --git a/scOT/train.py b/scOT/train.py index 1d6967f..e2ba209 100644 --- a/scOT/train.py +++ b/scOT/train.py @@ -407,7 +407,8 @@ def get_statistics(errors): ) trainer.train(resume_from_checkpoint=params.resume_training) - + trainer.save_model(train_config.output_dir) + if (RANK == 0 or RANK == -1) and params.push_to_hf_hub is not None: model.push_to_hub(params.push_to_hf_hub)