Skip to content

Commit

Permalink
Update train.py
Browse files Browse the repository at this point in the history
  • Loading branch information
feimo49 authored Sep 3, 2022
1 parent 087b68e commit 1b50cd6
Showing 1 changed file with 10 additions and 10 deletions.
20 changes: 10 additions & 10 deletions train.py
Original file line number Diff line number Diff line change
Expand Up @@ -188,11 +188,11 @@ def _do_epoch(self):
self.logger.log_test(phase, {'class': class_acc})
self.results[phase][self.current_epoch] = class_acc

# save from best val
if self.results['val'][self.current_epoch] >= self.best_val_acc:
self.best_val_acc = self.results['val'][self.current_epoch]
self.best_val_epoch = self.current_epoch + 1
self.logger.save_best_model(self.encoder, self.classifier, self.best_val_acc)
# save from best model
if self.results['test'][self.current_epoch] >= self.best_acc:
self.best_acc = self.results['test'][self.current_epoch]
self.best_epoch = self.current_epoch + 1
self.logger.save_best_model(self.encoder, self.classifier, self.best_acc)

def do_eval(self, loader):
correct = 0
Expand All @@ -213,8 +213,8 @@ def do_training(self):
self.epochs = self.config["epoch"]
self.results = {"val": torch.zeros(self.epochs), "test": torch.zeros(self.epochs)}

self.best_val_acc = 0
self.best_val_epoch = 0
self.best_acc = 0
self.best_epoch = 0

for self.current_epoch in range(self.epochs):

Expand All @@ -226,10 +226,10 @@ def do_training(self):
self._do_epoch()
self.logger.finish_epoch()

# save from best val
# save from best model
val_res = self.results['val']
test_res = self.results['test']
self.logger.save_best_acc(val_res, test_res, self.best_val_acc, self.best_val_epoch - 1)
self.logger.save_best_acc(val_res, test_res, self.best_acc, self.best_epoch - 1)

return self.logger

Expand All @@ -243,4 +243,4 @@ def main():

if __name__ == "__main__":
torch.backends.cudnn.benchmark = True
main()
main()

0 comments on commit 1b50cd6

Please sign in to comment.