Skip to content

Commit

Permalink
fix diffs
Browse files Browse the repository at this point in the history
  • Loading branch information
matbun committed Oct 16, 2024
1 parent b48beaf commit c877bfc
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 10 deletions.
16 changes: 7 additions & 9 deletions use-cases/eurac/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@
from itwinai.torch.type import Metric
from itwinai.torch.profiling.profiler import profile_torch_trainer


class RNNDistributedTrainer(TorchTrainer):
"""Trainer class for RNN model using pytorch.
Expand Down Expand Up @@ -90,13 +91,12 @@ def __init__(
)
self.save_parameters(**self.locals2params(locals()))


@suppress_workers_print
@profile_torch_trainer
def execute(
self,
train_dataset: Dataset,
validation_dataset: Optional[Dataset] = None,
self,
train_dataset: Dataset,
validation_dataset: Optional[Dataset] = None,
test_dataset: Optional[Dataset] = None
) -> Tuple[Dataset, Dataset, Dataset, Any]:
return super().execute(train_dataset, validation_dataset, test_dataset)
Expand Down Expand Up @@ -138,8 +138,8 @@ def create_model_loss_optimizer(self) -> None:
**distribute_kwargs,
)

def set_epoch(self, epoch: int):
if self.profiler is not None:
def set_epoch(self, epoch: int):
if self.profiler is not None:
self.profiler.step()

if self.strategy.is_distributed:
Expand Down Expand Up @@ -180,7 +180,6 @@ def train(self):
for epoch in tqdm(range(self.epochs)):
epoch_start_time = timer()
self.set_epoch(epoch)

self.model.train()

# set time indices for training
Expand Down Expand Up @@ -386,7 +385,6 @@ def create_model_loss_optimizer(self) -> None:
patience=self.config.lr_reduction_patience
)


target_weights = {
t: 1 / len(self.config.target_names) for t in self.config.target_names
}
Expand Down Expand Up @@ -507,7 +505,7 @@ def create_dataloaders(self, train_dataset, validation_dataset, test_dataset):
processing=(
"multi-gpu" if self.config.distributed else "single-gpu"
),
)
)

val_sampler_builder = SamplerBuilder(
validation_dataset,
Expand Down
2 changes: 1 addition & 1 deletion use-cases/virgo/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ def __init__(
) -> None:
super().__init__(
epochs=num_epochs,
config={},
config=config,
strategy=strategy,
logger=logger,
random_seed=random_seed,
Expand Down

0 comments on commit c877bfc

Please sign in to comment.