Skip to content

Commit

Permalink
feat: connect wandb
Browse files Browse the repository at this point in the history
Referenced Issue: #2
  • Loading branch information
jerry-ryu committed Jan 9, 2023
1 parent 816d897 commit deb8edb
Showing 1 changed file with 29 additions and 30 deletions.
59 changes: 29 additions & 30 deletions model/font_classifier/train_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import wandb

from scheduler import scheduler_module
from torch.utils.data import DataLoader
Expand Down Expand Up @@ -133,25 +134,6 @@ def train(data_dir, model_dir, args):
# --scheduler
if args.scheduler != "None":
scheduler = scheduler_module.get_scheduler(scheduler_module,args.scheduler, optimizer)

# -- logging
logger = SummaryWriter(log_dir=save_dir)
with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
json.dump(vars(args), f, ensure_ascii=False, indent=4)

layout = {
"Train_Val": {
"accuracy": ["Multiline", ["Train/accuracy_epcoh", "Val/accuracy"]],
"f1_score": ["Multiline", ["Val/f1_score"]],
"loss": ["Multiline", ['Train/loss_epoch', 'Val/loss']]
},
}

logger.add_custom_scalars(layout)

best_val_acc = 0
best_val_loss = np.inf
best_f1_score = 0

for epoch in tqdm(range(args.epochs)):

Expand Down Expand Up @@ -186,15 +168,21 @@ def train(data_dir, model_dir, args):
f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
)
# logs
# logger.add_scalar("Train/loss", train_loss, epoch * len(train_loader) + idx)
# logger.add_scalar("Train/accuracy", train_acc, epoch * len(train_loader) + idx)
# logger.add_scalar("lr", current_lr, epoch)
wandb.log({
"Train/loss": train_loss,
"Train/accuracy": train_acc,
"step" : epoch * len(train_loader) + idx + 1
})


loss_value_sum = loss_value / args.log_interval
train_acc_sum = matches / args.batch_size / len(train_loader)
logger.add_scalar("Train/loss_epoch", loss_value_sum, epoch)
logger.add_scalar("Train/accuracy_epcoh", train_acc_sum, epoch)
wandb.log({
"Train/loss_epoch": loss_value_sum,
"Train/accuracy_epcoh": train_acc_sum,
"lr": current_lr,
"epoch": epoch
})


# val loop
Expand Down Expand Up @@ -231,9 +219,11 @@ def train(data_dir, model_dir, args):

print(f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2}, f1: {f1_score:4.4} ")

logger.add_scalar("Val/loss", val_loss, epoch)
logger.add_scalar("Val/accuracy", val_acc, epoch)
logger.add_scalar("Val/f1_score", f1_score, epoch)
wandb.log({
"Val/loss": val_loss,
"Val/accuracy": val_acc,
"Val/f1_score": f1_score
})

if (epoch + 1) % args.save_interval == 0:
# model save
Expand Down Expand Up @@ -266,13 +256,22 @@ def train(data_dir, model_dir, args):
parser.add_argument('--save_interval', type=int, default=5, help='how many epochs to wait before save pth')

# Container environment
parser.add_argument('--data_dir', type=str, default='/opt/level3_productserving-level3-cv-11/data/words/ko/images')
parser.add_argument('--data_dir', type=str, default='/opt/level3_productserving-level3-cv-11/data/words/ko/typical_images')
parser.add_argument('--model_dir', type=str, default = './experiment/classifier')


# wandb
parser.add_argument('--tags', default= None, nargs='+',type=str, help = "프로젝트 태그 할당")
args = parser.parse_args()
print(args)

data_dir = args.data_dir
model_dir = args.model_dir


wandb.init(entity = "miho",
project = "Final-Project",
sync_tensorboard=True,
name = args.name,
tags = args.tags)
wandb.config.update(args)

train(data_dir, model_dir, args)

0 comments on commit deb8edb

Please sign in to comment.