diff --git a/src/lmcontrol/nn/clf.py b/src/lmcontrol/nn/clf.py index 7810bd7..99e36dd 100644 --- a/src/lmcontrol/nn/clf.py +++ b/src/lmcontrol/nn/clf.py @@ -163,24 +163,20 @@ def training_step(self, batch, batch_idx): images, labels = batch outputs = self.forward(images) loss_components = self.criterion(outputs, labels) - self._score("train", outputs, loss_components, labels) total_loss = sum(loss_components) self.log('total_train_loss', total_loss, on_step=False, on_epoch=True) - - return total_loss - - + + def validation_step(self, batch, batch_idx): images, labels = batch outputs = self.forward(images) loss_components = self.criterion(outputs, labels) - self._score("val", outputs, loss_components, labels) - total_loss = sum(loss_components) self.log('total_val_loss', total_loss, on_step=False, on_epoch=True) - + return total_loss + def configure_optimizers(self): optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=self.gamma) @@ -258,7 +254,7 @@ def _get_loaders_and_model(args, logger=None): logger.info(f"Loading validation data from: {len(split_files)} files") val_dataset = LMDataset(split_files, label_classes=train_dataset.label_classes, transform=transform_val, logger=logger, return_labels=True, label_type=args.labels, n_samples=n, return_embeddings=args.return_embeddings, split='validate', val_size=args.val_frac, seed=args.seed) for i in range(len(val_dataset.labels)): - current_labels = train_dataset.labels[i] + current_labels = val_dataset.labels[i] logger.info(val_dataset.label_type[i] + " - " + str(torch.unique(current_labels))) elif args.validation: @@ -280,7 +276,7 @@ def _get_loaders_and_model(args, logger=None): logger.info(f"Loading validation data: {len(val_files)} files") val_dataset = LMDataset(val_files, label_classes=train_dataset.label_classes, transform=transform_val, logger=logger, return_labels=True, label_type=args.labels, n_samples=n, return_embeddings=args.return_embeddings) for i in range(len(val_dataset.labels)): - current_labels = train_dataset.labels[i] + current_labels = val_dataset.labels[i] logger.info(val_dataset.label_type[i] + " - " + str(torch.unique(current_labels))) else: @@ -305,7 +301,7 @@ def _get_trainer(args, trial=None): callbacks = [] targs = dict(max_epochs=args.epochs, devices=1, accelerator=accelerator, check_val_every_n_epoch=4, callbacks=callbacks) - + if args.checkpoint: checkpoint_callback = ModelCheckpoint( dirpath=args.checkpoint,