diff --git a/model/font_classifier/dataset_font.py b/model/font_classifier/dataset_font.py
index 65e1b4453b..fd0c424b77 100644
--- a/model/font_classifier/dataset_font.py
+++ b/model/font_classifier/dataset_font.py
@@ -12,24 +12,39 @@ class FontDataset(Dataset):
-    image_paths = []
-    image_labels = []
+    image_paths_train = []
+    image_labels_train = []
 
-    def __init__(self, data_dir, val_ratio=0.2):
+    image_paths_val = []
+    image_labels_val = []
+
+    def __init__(self, data_dir, val_ratio=0.2, is_train=True):
         self.data_dir = data_dir
-        self.val_ratio = val_ratio
-        self.transform = None
-        self.setup()
+        self.is_train = is_train
+        if is_train:
+            self.val_ratio = val_ratio
+            self.setup()
+
     def setup(self):
         profiles = os.listdir(self.data_dir)
         for idx, profile in enumerate(profiles):
             paths = os.listdir(os.path.join(self.data_dir, profile))
+            image_path = []
+            image_label = []
             for path in paths:
-                self.image_paths.append(os.path.join(self.data_dir,profile,path))
-                self.image_labels.append(idx)
+                image_path.append(os.path.join(self.data_dir, profile, path))
+                image_label.append(idx)
+            tmp_all = set(range(len(image_path)))
+            tmp_val = set(random.sample(list(range(len(image_path))), int(len(image_path) * self.val_ratio)))
+            tmp_train = tmp_all - tmp_val
+
+            self.image_paths_train.extend([image_path[x] for x in tmp_train])
+            self.image_labels_train.extend([image_label[x] for x in tmp_train])
+            self.image_paths_val.extend([image_path[x] for x in tmp_val])
+            self.image_labels_val.extend([image_label[x] for x in tmp_val])
 
     def set_transform(self, transform):
         self.transform = transform
@@ -40,27 +55,25 @@ def __getitem__(self, index):
         image = self.read_image(index)
         image_transform = self.transform(image)
-        label = self.image_labels[index]
+        if self.is_train:
+            label = self.image_labels_train[index]
+        else:
+            label = self.image_labels_val[index]
+
         return image_transform, label
 
     def read_image(self, index):
-        image_path = self.image_paths[index]
+        if self.is_train:
+            image_path = self.image_paths_train[index]
+        else:
+            image_path = self.image_paths_val[index]
         return Image.open(image_path).convert('RGB')
 
     def __len__(self):
-        return len(self.image_paths)
-
-    def split_dataset(self) -> Tuple[Subset, Subset]:
-        """
-        데이터셋을 train 과 val 로 나눕니다,
-        pytorch 내부의 torch.utils.data.random_split 함수를 사용하여
-        torch.utils.data.Subset 클래스 둘로 나눕니다.
-        구현이 어렵지 않으니 구글링 혹은 IDE (e.g. pycharm) 의 navigation 기능을 통해 코드를 한 번 읽어보는 것을 추천드립니다^^
-        """
-        n_val = int(len(self) * self.val_ratio)
-        n_train = len(self) - n_val
-        train_set, val_set = random_split(self, [n_train, n_val])
-        return train_set, val_set
+        if self.is_train:
+            return len(self.image_paths_train)
+        else:
+            return len(self.image_paths_val)
@@ -70,7 +83,7 @@ def __init__(self, img_paths, resize, mean=(0.548, 0.504, 0.479), std=(0.237, 0.247, 0.246)):
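
Note on the new split: FontDataset now keeps both splits in class-level lists, so the `is_train=True` instance must be constructed first (it runs `setup()` and fills the train and val lists); the `is_train=False` instance only reads them. A minimal usage sketch, assuming the `data_dir/<font_name>/<image>` layout that `setup()` expects (the `data/fonts` path is illustrative, not from this PR):

```python
from dataset_font import FontDataset, BaseAugmentation

# Order matters: the train instance performs the per-class random split and
# populates the shared class-level lists used by both instances.
train_set = FontDataset("data/fonts", val_ratio=0.2, is_train=True)
val_set = FontDataset("data/fonts", is_train=False)  # reuses the stored split

transform = BaseAugmentation(resize=[256, 256],
                             mean=(0.548, 0.504, 0.479),
                             std=(0.237, 0.247, 0.246))
train_set.set_transform(transform)
val_set.set_transform(transform)
```

Because the lists are class attributes, constructing a second `is_train=True` instance in the same process would append a second copy of every sample; resetting the lists (or making them instance attributes keyed off a shared split) would avoid that.
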
         self.transform = transforms.Compose([
             transforms.Resize(resize, Image.BILINEAR),
             transforms.ToTensor(),
-            transforms.Normalize(mean=mean, std=std),
+            #transforms.Normalize(mean=mean, std=std),
         ])
 
     def __getitem__(self, index):
@@ -91,7 +104,7 @@ def __init__(self, resize, mean, std, **args):
         self.transform = transforms.Compose([
             transforms.Resize(resize, Image.BILINEAR),
             transforms.ToTensor(),
-            transforms.Normalize(mean=mean, std=std),
+            #transforms.Normalize(mean=mean, std=std),
         ])
 
     def __call__(self, image):
diff --git a/model/font_classifier/inference_font.py b/model/font_classifier/inference_font.py
index cc2d39bf95..e71632df81 100644
--- a/model/font_classifier/inference_font.py
+++ b/model/font_classifier/inference_font.py
@@ -9,6 +9,8 @@
 
 from dataset_font import FontDataset
 
+import torch.nn.functional as F
+
 # 경고 off
 import warnings
 warnings.filterwarnings(action='ignore')
@@ -45,20 +47,22 @@ def inference(data_dir, args):
     for idx, images in enumerate(loader):
         images = images.to(device)
         predict = model(images)
-        #pred = predict.argmax(dim=-1)
-        pred_topk = torch.topk(predict, k= 2, dim = -1)
+        scores = F.softmax(predict.data, dim=1)
+        pred_topk = torch.topk(scores, k=2, dim=-1)
         pred = pred_topk.indices
         value = pred_topk.values
         preds.extend(pred.cpu().numpy())
         values.extend(value.cpu().numpy())
 
+
+    labels_list = os.listdir(args.train_data_dir)
     print(labels_list)
     for i, pred in enumerate(preds):
         print(f"#{i}_{image_names[i]}")
         for j, idx in enumerate(pred):
-            print(f"{labels_list[idx]}, {values[i][j]}")
+            print(f"{labels_list[idx]}, {values[i][j]:.2%}")
 
 if __name__ == '__main__':
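
The inference path now reports softmaxed top-2 confidences as percentages. One caveat: `labels_list` comes from `os.listdir(args.train_data_dir)`, whose order is filesystem-dependent, so sorting it (and the listing in `FontDataset.setup`) would make the index-to-font mapping deterministic across machines. A standalone sketch of the new scoring logic, with a dummy logits tensor standing in for the model output:

```python
import torch
import torch.nn.functional as F

logits = torch.randn(4, 10)             # pretend batch of 4 images, 10 font classes
scores = F.softmax(logits, dim=1)       # per-class confidences; each row sums to 1
topk = torch.topk(scores, k=2, dim=-1)  # two most likely fonts per image

for indices, values in zip(topk.indices, topk.values):
    for idx, val in zip(indices, values):
        print(f"class {idx.item()}, {val.item():.2%}")  # e.g. "class 3, 87.21%"
```
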
diff --git a/model/font_classifier/train_font.py b/model/font_classifier/train_font.py
index 2bce17ab6a..c3a07763c5 100644
--- a/model/font_classifier/train_font.py
+++ b/model/font_classifier/train_font.py
@@ -15,7 +15,6 @@
 from scheduler import scheduler_module
 
 from torch.utils.data import DataLoader
-from torch.utils.tensorboard import SummaryWriter
 from PIL import Image
 from loss import create_criterion
 
@@ -26,6 +25,8 @@
 import matplotlib.pyplot as plt
 from tqdm import tqdm
 
+import torch.nn.functional as F
+
 # 경고 off
 import warnings
 warnings.filterwarnings(action='ignore')
@@ -63,40 +64,86 @@ def increment_path(path, exist_ok=False):
         n = max(i) + 1 if i else 2
         return f"{path}{n}"
 
+def createDirectory(directory):
+    try:
+        if not os.path.exists(directory):
+            os.makedirs(directory)
+    except OSError:
+        print("Error: Failed to create the directory.")
+
+# convenience function to log predictions for a batch of test images
+def log_test_predictions(images, labels, outputs, predicted, test_table, labels_list):
+    # obtain confidence scores for all classes
+    scores = F.softmax(outputs.data, dim=1)
+    log_scores = scores.cpu().numpy()
+    log_images = images.cpu().numpy()
+    log_labels = labels.cpu().numpy()
+    log_preds = predicted.cpu().numpy()
+    # adding ids based on the order of the images
+    _id = 0
+    for i, l, p, s in zip(log_images, log_labels, log_preds, log_scores):
+        # add required info to data table:
+        # id, image pixels, model's guess, true label, scores for all classes
+        img_id = str(_id)
+        i = np.transpose(i, (1, 2, 0))
+        # IMAGENET_MEAN, IMAGENET_STD = np.array([0.485, 0.456, 0.406]), np.array([0.229, 0.224, 0.225])
+        # i = np.clip(255.0 * (i * IMAGENET_STD + IMAGENET_MEAN), 0, 255)
+        # tensors are un-normalized floats in [0, 1]; scale to 0-255 before casting
+        i = np.clip(i * 255.0, 0, 255).astype(np.uint8).copy()
+        test_table.add_data(img_id, wandb.Image(i), p, labels_list[p], l, labels_list[l], *s)
+        _id += 1
+
 # -- train
 def train(data_dir, model_dir, args):
     seed_everything(args.seed)
 
     save_dir = increment_path(os.path.join(model_dir, args.name))
+    createDirectory(save_dir)
+    with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
+        json.dump(vars(args), f, ensure_ascii=False, indent=4)
+
     use_cuda = torch.cuda.is_available()
     device = torch.device("cuda" if use_cuda else "cpu")
 
     dataset_module = getattr(import_module("dataset_font"), args.dataset)
-    dataset = dataset_module(
+    dataset_train = dataset_module(
         data_dir=data_dir,
-        val_ratio = args.val_ratio
+        val_ratio=args.val_ratio,
+        is_train=True
     )
     num_classes = len(os.listdir(args.data_dir)) # font의 개수
 
     # -- augmentation
-    transform_module = getattr(import_module("dataset_font"), args.augmentation)  # default: BaseAugmentation
+    transform_module = getattr(import_module("dataset_font"), args.train_augmentation)  # default: BaseAugmentation
     transform = transform_module(
         resize=args.resize,
         mean=(0.548, 0.504, 0.479),
         std=(0.237, 0.247, 0.246)
     )
+    dataset_train.set_transform(transform)
+
+    dataset_val = dataset_module(
+        data_dir=data_dir,
+        val_ratio=args.val_ratio,
+        is_train=False
+    )
+
+    val_transform_module = getattr(import_module("dataset_font"), args.val_augmentation)
+    val_transform = val_transform_module(
+        resize=args.resize,
+        mean=(0.548, 0.504, 0.479),
+        std=(0.237, 0.247, 0.246)
+    )
+    dataset_val.set_transform(val_transform)
 
-    # -- data_loader & sampler
-    dataset.set_transform(transform)
-    train_set, val_set = dataset.split_dataset()
     train_loader = DataLoader(
-        train_set,
+        dataset_train,
         batch_size=args.batch_size,
         num_workers=multiprocessing.cpu_count() // 2,
         shuffle=True,
@@ -105,7 +152,7 @@ def train(data_dir, model_dir, args):
     val_loader = DataLoader(
-        val_set,
+        dataset_val,
         batch_size=args.valid_batch_size,
         num_workers=multiprocessing.cpu_count() // 2,
         shuffle=False,
@@ -135,6 +182,9 @@ def train(data_dir, model_dir, args):
     if args.scheduler != "None":
         scheduler = scheduler_module.get_scheduler(scheduler_module,args.scheduler, optimizer)
 
+    # labels_list
+    labels_list = os.listdir(args.data_dir)
+
     for epoch in tqdm(range(args.epochs)):
 
         # -- train loop
@@ -143,6 +193,7 @@ def train(data_dir, model_dir, args):
         matches = 0
         loss_value_sum = 0
         train_acc_sum = 0
+        current_lr = get_lr(optimizer)
         for idx, train_batch in enumerate(train_loader):
             inputs, labels = train_batch
             inputs = inputs.to(device)
@@ -162,7 +213,6 @@ def train(data_dir, model_dir, args):
             if (idx + 1) % args.log_interval == 0:
                 train_loss = loss_value / (idx +1)
                 train_acc = matches / args.batch_size / (idx +1)
-                current_lr = get_lr(optimizer)
                 print(
                     f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
                     f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
@@ -174,12 +224,12 @@ def train(data_dir, model_dir, args):
                     "step" : epoch * len(train_loader) + idx + 1
                 })
 
-
+
         loss_value_sum = loss_value / args.log_interval
         train_acc_sum = matches / args.batch_size / len(train_loader)
         wandb.log({
             "Train/loss_epoch": loss_value_sum,
-            "Train/accuracy_epcoh": train_acc_sum,
+            "Train/accuracy_epoch": train_acc_sum,
             "lr": current_lr,
             "epoch": epoch
         })
@@ -188,13 +238,22 @@ def train(data_dir, model_dir, args):
         # val loop
         with torch.no_grad():
             print("Calculating validation results...")
+
+            # wandb artifacts
+            val_data_at = wandb.Artifact("val_samples_" + str(wandb.run.id), type="predictions")
+            columns = ["id", "image", "guess_num", "guess", "truth_num", "truth"]
+            for classtype in range(num_classes):
+                columns.append("score_" + str(classtype))
+
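+            # table schema: one row per logged validation image (id, the image,
+            # predicted index/name, true index/name, and one softmax score
+            # column per font class, filled in by log_test_predictions)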
+            val_table = wandb.Table(columns=columns)
+
             model.eval()
             val_loss_items = []
             val_acc_items = []
 
             preds_expand = torch.tensor([])
             labels_expand = torch.tensor([])
-            for val_batch in val_loader:
+            for val_batch in tqdm(val_loader):
                 inputs, labels = val_batch
                 inputs = inputs.to(device)
                 labels = labels.to(device)
@@ -210,12 +269,14 @@ def train(data_dir, model_dir, args):
                 preds_expand = torch.cat((preds_expand, preds.detach().cpu()),-1)
                 labels_expand = torch.cat((labels_expand, labels.detach().cpu()),-1)
 
+                if (epoch + 1) % args.save_interval == 0:
+                    log_test_predictions(inputs, labels, outs, preds, val_table, labels_list)
 
             # -- evaluation
             f1 = MulticlassF1Score(num_classes=num_classes)
             f1_score = f1(preds_expand.type(torch.LongTensor), labels_expand.type(torch.LongTensor)).item()
 
-            val_loss = np.sum(val_loss_items) / len(val_set)
-            val_acc = np.sum(val_acc_items) / len(val_set)
+            val_loss = np.sum(val_loss_items) / len(dataset_val)
+            val_acc = np.sum(val_acc_items) / len(dataset_val)
 
             print(f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2}, f1: {f1_score:4.4} ")
@@ -229,6 +290,10 @@ def train(data_dir, model_dir, args):
             # model save
             torch.save(model, f"{save_dir}/{epoch}.pth")
 
+            # wandb artifact
+            val_data_at.add(val_table, "predictions")
+            wandb.run.log_artifact(val_data_at)
+
         # --scheduler
         if args.scheduler != "None":
             scheduler.step()
@@ -241,14 +306,15 @@
     parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
     parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train (default: 200)')
     parser.add_argument('--dataset', type=str, default='FontDataset', help='dataset augmentation type (default: MaskBaseDataset)')
-    parser.add_argument('--augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)') ##
+    parser.add_argument('--train_augmentation', type=str, default='BaseAugmentation', help='train data augmentation type (default: BaseAugmentation)')
+    parser.add_argument('--val_augmentation', type=str, default='BaseAugmentation', help='validation data augmentation type (default: BaseAugmentation)')
     parser.add_argument("--resize", nargs="+", type=int, default=[256, 256], help='resize size for image when training')
     parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: 64)')
-    parser.add_argument('--valid_batch_size', type=int, default=10, help='input batch size for validing (default: 1000)')
+    parser.add_argument('--valid_batch_size', type=int, default=100, help='input batch size for validation (default: 100)')
     parser.add_argument('--model', type=str, default='ResNet50', help='model type (default: ResNet50)')
     parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer type (default: Adam)')
     parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
-    parser.add_argument('--val_ratio', type=float, default=0.2, help='ratio for validaton (default: 0.2)')
+    parser.add_argument('--val_ratio', type=float, default=0.001, help='ratio for validation (default: 0.001)')
    parser.add_argument('--criterion', type=str, default='cross_entropy', help='criterion type (default: cross_entropy)')
    parser.add_argument('--log_interval', type=int, default=20, help='how many batches to wait before logging training status')
    parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')
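
Since train() now snapshots vars(args) to config.json and saves whole-module checkpoints with torch.save(model, ...), a finished run can be restored for inference. A minimal sketch, assuming an existing run directory; the "./model/exp" path and the epoch number 199 are illustrative, not from this PR:

```python
import json
import torch

save_dir = "./model/exp"  # hypothetical save_dir produced by increment_path

# config.json is the vars(args) dump written at the start of train()
with open(f"{save_dir}/config.json", encoding="utf-8") as f:
    cfg = json.load(f)
print(cfg["model"], cfg["resize"], cfg["val_ratio"])

# torch.save(model, ...) pickles the full module, so torch.load returns a
# ready nn.Module, provided the model class is importable at load time
# (newer PyTorch versions may additionally require weights_only=False)
model = torch.load(f"{save_dir}/199.pth", map_location="cpu")
model.eval()
```

Saving state_dict() instead of the whole module would make the checkpoints robust to refactors of the model class, at the cost of having to rebuild the model from config.json before loading.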