Commit
return total_loss moved from training_step to validation_step
NiranjanChaudhari0929 committed Dec 10, 2024
1 parent b402f9b commit 1bde7b2
Showing 1 changed file with 7 additions and 11 deletions: src/lmcontrol/nn/clf.py
@@ -163,24 +163,20 @@ def training_step(self, batch, batch_idx):

def training_step(self, batch, batch_idx):
    images, labels = batch
    outputs = self.forward(images)
    loss_components = self.criterion(outputs, labels)

    self._score("train", outputs, loss_components, labels)
    total_loss = sum(loss_components)
    self.log('total_train_loss', total_loss, on_step=False, on_epoch=True)

    return total_loss

def validation_step(self, batch, batch_idx):
    images, labels = batch
    outputs = self.forward(images)
    loss_components = self.criterion(outputs, labels)

    self._score("val", outputs, loss_components, labels)

    total_loss = sum(loss_components)
    self.log('total_val_loss', total_loss, on_step=False, on_epoch=True)

    return total_loss

def configure_optimizers(self):
    optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
    scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, milestones=[10, 20, 30], gamma=self.gamma)
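
For readers skimming the hunk above: a minimal, self-contained sketch of the pattern the file ends up with — compute the loss components, sum them, log the epoch-level total, and return the scalar from both steps. The toy network and single-component criterion below are hypothetical stand-ins for the module's own; only the step/logging/optimizer structure mirrors the diff.

    # Minimal runnable sketch of the step/return pattern above.
    # The criterion here is a hypothetical stand-in; the real one lives in clf.py.
    import torch
    import torch.nn as nn
    import pytorch_lightning as pl

    class SketchClassifier(pl.LightningModule):
        def __init__(self, num_classes=10, lr=1e-3, gamma=0.1):
            super().__init__()
            self.lr = lr
            self.gamma = gamma
            self.net = nn.Sequential(nn.Flatten(), nn.LazyLinear(num_classes))

        def forward(self, x):
            return self.net(x)

        def criterion(self, outputs, labels):
            # stand-in: a tuple of loss components, as sum(loss_components) implies
            return (nn.functional.cross_entropy(outputs, labels),)

        def _shared_step(self, batch, stage):
            images, labels = batch
            outputs = self.forward(images)
            loss_components = self.criterion(outputs, labels)
            total_loss = sum(loss_components)
            # on_epoch=True aggregates over the epoch; on_step=False drops per-batch points
            self.log(f'total_{stage}_loss', total_loss, on_step=False, on_epoch=True)
            return total_loss  # Lightning backpropagates the value returned by training_step

        def training_step(self, batch, batch_idx):
            return self._shared_step(batch, 'train')

        def validation_step(self, batch, batch_idx):
            return self._shared_step(batch, 'val')

        def configure_optimizers(self):
            optimizer = torch.optim.Adam(self.parameters(), lr=self.lr)
            scheduler = torch.optim.lr_scheduler.MultiStepLR(
                optimizer, milestones=[10, 20, 30], gamma=self.gamma)
            return [optimizer], [scheduler]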
@@ -258,7 +254,7 @@ def _get_loaders_and_model(args, logger=None):

     logger.info(f"Loading validation data from: {len(split_files)} files")
     val_dataset = LMDataset(split_files, label_classes=train_dataset.label_classes, transform=transform_val, logger=logger, return_labels=True, label_type=args.labels, n_samples=n, return_embeddings=args.return_embeddings, split='validate', val_size=args.val_frac, seed=args.seed)
     for i in range(len(val_dataset.labels)):
-        current_labels = train_dataset.labels[i]
+        current_labels = val_dataset.labels[i]
         logger.info(val_dataset.label_type[i] + " - " + str(torch.unique(current_labels)))

elif args.validation:
@@ -280,7 +276,7 @@ def _get_loaders_and_model(args, logger=None):

     logger.info(f"Loading validation data: {len(val_files)} files")
     val_dataset = LMDataset(val_files, label_classes=train_dataset.label_classes, transform=transform_val, logger=logger, return_labels=True, label_type=args.labels, n_samples=n, return_embeddings=args.return_embeddings)
     for i in range(len(val_dataset.labels)):
-        current_labels = train_dataset.labels[i]
+        current_labels = val_dataset.labels[i]
         logger.info(val_dataset.label_type[i] + " - " + str(torch.unique(current_labels)))

else:
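
The two hunks above fix the same off-by-dataset bug: the loop iterates over the validation labels but read them from the training set, so the logged class values described the wrong split. A hedged sketch of the corrected loop in isolation, with toy tensors standing in for val_dataset.labels (a list of per-label-type tensors) and val_dataset.label_type (the matching list of names) — both inferred from the diff, not the actual LMDataset API:

    import torch
    import logging

    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("clf")

    class ToyValDataset:
        # hypothetical stand-in for LMDataset's label attributes
        label_type = ["time", "condition"]
        labels = [torch.tensor([0, 1, 1, 2]), torch.tensor([0, 0, 1, 1])]

    val_dataset = ToyValDataset()
    for i in range(len(val_dataset.labels)):
        # the fix: read from val_dataset, not train_dataset, so the log
        # reflects the classes actually present in the validation split
        current_labels = val_dataset.labels[i]
        logger.info(val_dataset.label_type[i] + " - " + str(torch.unique(current_labels)))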
@@ -305,7 +301,7 @@ def _get_trainer(args, trial=None):

    callbacks = []

    targs = dict(max_epochs=args.epochs, devices=1, accelerator=accelerator, check_val_every_n_epoch=4, callbacks=callbacks)

    if args.checkpoint:
        checkpoint_callback = ModelCheckpoint(
            dirpath=args.checkpoint,
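
The ModelCheckpoint call is truncated in this view, so here is a hedged sketch of how these trainer arguments typically fit together. The monitor key and save_top_k below are assumptions (monitoring the total_val_loss metric logged in the first hunk would be a natural choice); the dirpath argument and the targs dict mirror the hunk.

    import pytorch_lightning as pl
    from pytorch_lightning.callbacks import ModelCheckpoint

    def get_trainer(checkpoint_dir=None, epochs=50, accelerator="auto"):
        callbacks = []
        if checkpoint_dir:
            callbacks.append(ModelCheckpoint(
                dirpath=checkpoint_dir,
                monitor="total_val_loss",  # assumed: matches the metric logged above
                mode="min",
                save_top_k=1,
            ))
        # mirrors the targs dict in the diff; validation runs every 4 epochs
        targs = dict(max_epochs=epochs, devices=1, accelerator=accelerator,
                     check_val_every_n_epoch=4, callbacks=callbacks)
        return pl.Trainer(**targs)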
