diff --git a/sleap_nn/training/model_trainer.py b/sleap_nn/training/model_trainer.py index 8e557d5b..c1f5b629 100644 --- a/sleap_nn/training/model_trainer.py +++ b/sleap_nn/training/model_trainer.py @@ -681,6 +681,11 @@ def train( ) try: + if self.config.trainer_config.use_wandb: + wandb_logger.experiment.config.update({"run_name": wandb_config.name}) + wandb_logger.experiment.config.update(dict(self.config)) + wandb_logger.experiment.config.update({"model_params": total_params}) + self.trainer.fit( self.model, self.train_data_loader, @@ -688,10 +693,6 @@ def train( ckpt_path=self.config.trainer_config.resume_ckpt_path, ) - if self.config.trainer_config.use_wandb: - wandb_logger.experiment.config.update(dict(self.config)) - wandb_logger.experiment.config.update({"model_params": total_params}) - except KeyboardInterrupt: print("Stopping training...")