diff --git a/use-cases/eurac/trainer.py b/use-cases/eurac/trainer.py index abe85471..88ac42f5 100644 --- a/use-cases/eurac/trainer.py +++ b/use-cases/eurac/trainer.py @@ -29,6 +29,7 @@ from itwinai.torch.type import Metric from itwinai.torch.profiling.profiler import profile_torch_trainer + class RNNDistributedTrainer(TorchTrainer): """Trainer class for RNN model using pytorch. @@ -90,13 +91,12 @@ def __init__( ) self.save_parameters(**self.locals2params(locals())) - @suppress_workers_print @profile_torch_trainer def execute( - self, - train_dataset: Dataset, - validation_dataset: Optional[Dataset] = None, + self, + train_dataset: Dataset, + validation_dataset: Optional[Dataset] = None, test_dataset: Optional[Dataset] = None ) -> Tuple[Dataset, Dataset, Dataset, Any]: return super().execute(train_dataset, validation_dataset, test_dataset) @@ -138,8 +138,8 @@ def create_model_loss_optimizer(self) -> None: **distribute_kwargs, ) - def set_epoch(self, epoch: int): - if self.profiler is not None: + def set_epoch(self, epoch: int): + if self.profiler is not None: self.profiler.step() if self.strategy.is_distributed: @@ -180,7 +180,6 @@ def train(self): for epoch in tqdm(range(self.epochs)): epoch_start_time = timer() self.set_epoch(epoch) - self.model.train() # set time indices for training @@ -386,7 +385,6 @@ def create_model_loss_optimizer(self) -> None: patience=self.config.lr_reduction_patience ) - target_weights = { t: 1 / len(self.config.target_names) for t in self.config.target_names } @@ -507,7 +505,7 @@ def create_dataloaders(self, train_dataset, validation_dataset, test_dataset): processing=( "multi-gpu" if self.config.distributed else "single-gpu" ), - ) + ) val_sampler_builder = SamplerBuilder( validation_dataset, diff --git a/use-cases/virgo/trainer.py b/use-cases/virgo/trainer.py index 8226e070..5fd2a3f9 100644 --- a/use-cases/virgo/trainer.py +++ b/use-cases/virgo/trainer.py @@ -43,7 +43,7 @@ def __init__( ) -> None: super().__init__( epochs=num_epochs, - config={}, + config=config, strategy=strategy, logger=logger, random_seed=random_seed,