Skip to content

Commit

Permalink
Merge pull request #7 from boostcampaitech4lv23cv2/Develop/Font_Classifier
Browse files Browse the repository at this point in the history

[FEAT] wandb 연결 & 시각화, Develop/font classifier -> master
  • Loading branch information
jane79 authored Jan 10, 2023
2 parents 936a105 + 2daa307 commit beb7f96
Show file tree
Hide file tree
Showing 3 changed files with 129 additions and 46 deletions.
63 changes: 38 additions & 25 deletions model/font_classifier/dataset_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,39 @@

class FontDataset(Dataset):
    """Font-classification dataset over a directory of per-font image folders.

    Expected layout: ``data_dir/<font_name>/<image_file>`` — each immediate
    subdirectory of *data_dir* is one class, labelled by its enumeration index.

    The train/val split is stored in CLASS-LEVEL lists: constructing the
    training instance (``is_train=True``) scans the disk and populates the
    shared lists; a later validation instance (``is_train=False``) skips
    ``setup()`` entirely and reads the same class-level lists.
    NOTE(review): this coupling is fragile — the train instance MUST be
    created before the val instance, and creating a second train instance
    appends duplicate entries. Consider instance-level storage + an explicit
    split object in a follow-up.
    """

    # Shared (class-level) storage for the split — see class docstring.
    image_paths_train = []
    image_labels_train = []

    image_paths_val = []
    image_labels_val = []

    def __init__(self, data_dir, val_ratio=0.2, is_train=True):
        """
        Args:
            data_dir: root folder containing one subfolder per font class.
            val_ratio: fraction of each class sampled into the val split
                (only used when ``is_train`` is True).
            is_train: True → scan disk and build the split; False → reuse the
                split already built by a previous train instance.
        """
        self.data_dir = data_dir
        self.transform = None
        self.is_train = is_train
        if is_train:
            self.val_ratio = val_ratio
            self.setup()

    def setup(self):
        """Scan ``data_dir`` and fill the shared train/val path+label lists.

        The split is done per class so every font contributes
        ``int(n * val_ratio)`` validation samples (stratified split).
        Uses the global ``random`` state — seed externally for reproducibility.
        """
        profiles = os.listdir(self.data_dir)

        for idx, profile in enumerate(profiles):
            paths = os.listdir(os.path.join(self.data_dir, profile))
            image_path = []
            image_label = []
            for path in paths:
                image_path.append(os.path.join(self.data_dir, profile, path))
                image_label.append(idx)
            # Per-class index split: val indices are sampled, the rest train.
            tmp_all = set(range(len(image_path)))
            tmp_val = set(random.sample(list(range(len(image_path))), int(len(image_path) * self.val_ratio)))
            tmp_train = tmp_all - tmp_val

            self.image_paths_train.extend([image_path[x] for x in tmp_train])
            self.image_labels_train.extend([image_label[x] for x in tmp_train])
            self.image_paths_val.extend([image_path[x] for x in tmp_val])
            self.image_labels_val.extend([image_label[x] for x in tmp_val])

    def set_transform(self, transform):
        """Set the transform applied to every image in ``__getitem__``."""
        self.transform = transform

    def __getitem__(self, index):
        """Return ``(transformed_image, label)`` for the active split.

        Raises TypeError if ``set_transform`` was never called
        (``self.transform`` is still None).
        """
        image = self.read_image(index)
        image_transform = self.transform(image)

        if self.is_train:
            label = self.image_labels_train[index]
        else:
            label = self.image_labels_val[index]

        return image_transform, label

    def read_image(self, index):
        """Load the image at *index* from the active split as RGB PIL.Image."""
        if self.is_train:
            image_path = self.image_paths_train[index]
        else:
            image_path = self.image_paths_val[index]
        return Image.open(image_path).convert('RGB')

    def __len__(self):
        """Size of the active (train or val) split."""
        if self.is_train:
            return len(self.image_paths_train)
        else:
            return len(self.image_paths_val)



Expand All @@ -70,7 +83,7 @@ def __init__(self, img_paths, resize, mean=(0.548, 0.504, 0.479), std=(0.237, 0.
self.transform = transforms.Compose([
transforms.Resize(resize, Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
#transforms.Normalize(mean=mean, std=std),
])

def __getitem__(self, index):
Expand All @@ -91,7 +104,7 @@ def __init__(self, resize, mean, std, **args):
self.transform = transforms.Compose([
transforms.Resize(resize, Image.BILINEAR),
transforms.ToTensor(),
transforms.Normalize(mean=mean, std=std),
#transforms.Normalize(mean=mean, std=std),
])

def __call__(self, image):
Expand Down
10 changes: 7 additions & 3 deletions model/font_classifier/inference_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@

from dataset_font import FontDataset

import torch.nn.functional as F

# 경고 off
import warnings
warnings.filterwarnings(action='ignore')
Expand Down Expand Up @@ -45,20 +47,22 @@ def inference(data_dir, args):
for idx, images in enumerate(loader):
images = images.to(device)
predict = model(images)
#pred = predict.argmax(dim=-1)
pred_topk = torch.topk(predict, k= 2, dim = -1)
scores = F.softmax(predict.data, dim=1)
pred_topk = torch.topk(scores, k= 2, dim = -1)
pred = pred_topk.indices
value = pred_topk.values
preds.extend(pred.cpu().numpy())
values.extend(value.cpu().numpy())



labels_list = os.listdir(args.train_data_dir)
print(labels_list)

for i, pred in enumerate(preds):
print(f"#{i}_{image_names[i]}")
for j, idx in enumerate(pred):
print(f"{labels_list[idx]}, {values[i][j]}")
print(f"{labels_list[idx]}, {values[i][j]:.2%}")


if __name__ == '__main__':
Expand Down
102 changes: 84 additions & 18 deletions model/font_classifier/train_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@

from scheduler import scheduler_module
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter

from PIL import Image
from loss import create_criterion
Expand All @@ -26,6 +25,8 @@
import matplotlib.pyplot as plt
from tqdm import tqdm

import torch.nn.functional as F

# 경고 off
import warnings
warnings.filterwarnings(action='ignore')
Expand Down Expand Up @@ -63,40 +64,86 @@ def increment_path(path, exist_ok=False):
n = max(i) + 1 if i else 2
return f"{path}{n}"

def createDirectory(directory):
    """Create *directory* (including parent dirs) if it does not exist.

    Best-effort: failures (e.g. permissions) are printed, not raised, so a
    missing save dir never aborts training — matching the original behavior.
    """
    try:
        # exist_ok=True folds the exists-check into the create call, removing
        # the check-then-act race of `if not os.path.exists(...): makedirs`.
        os.makedirs(directory, exist_ok=True)
    except OSError:
        print("Error: Failed to create the directory.")

# Convenience function: log one batch of validation predictions to a wandb table.
def log_test_predictions(images, labels, outputs, predicted, test_table, labels_list):
    """Append one row per image to *test_table*.

    Row layout: id, image, predicted index, predicted name, true index,
    true name, followed by the softmax score for every class.
    Ids restart at 0 for each call (they index within the batch).
    """
    # Per-class confidence scores from the raw logits.
    probs = F.softmax(outputs.data, dim=1)
    batch = zip(
        images.cpu().numpy(),
        labels.cpu().numpy(),
        predicted.cpu().numpy(),
        probs.cpu().numpy(),
    )
    for row_id, (img, truth, guess, scores) in enumerate(batch):
        # CHW -> HWC for wandb.Image; assumes pixel values are already in a
        # displayable range (no de-normalization is applied here).
        img = np.transpose(img, (1, 2, 0)).astype(np.uint8).copy()
        test_table.add_data(str(row_id), wandb.Image(img), guess,
                            labels_list[guess], truth, labels_list[truth], *scores)


# -- train
def train(data_dir, model_dir, args):
seed_everything(args.seed)

save_dir = increment_path(os.path.join(model_dir, args.name))
createDirectory(save_dir)

with open(os.path.join(save_dir, 'config.json'), 'w', encoding='utf-8') as f:
json.dump(vars(args), f, ensure_ascii=False, indent=4)

use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")


dataset_module = getattr(import_module("dataset_font"), args.dataset)
dataset = dataset_module(
dataset_train = dataset_module(
data_dir=data_dir,
val_ratio = args.val_ratio
val_ratio = args.val_ratio,
is_train = True
)

num_classes = len(os.listdir(args.data_dir)) # font의 개수

# -- augmentation
transform_module = getattr(import_module("dataset_font"), args.augmentation) # default: BaseAugmentation
transform_module = getattr(import_module("dataset_font"), args.train_augmentation) # default: BaseAugmentation
transform = transform_module(
resize=args.resize,
mean=(0.548, 0.504, 0.479),
std=(0.237, 0.247, 0.246)
)
dataset_train.set_transform(transform)


dataset_val = dataset_module(
data_dir=data_dir,
val_ratio = args.val_ratio,
is_train = False
)

transform = transform_module(
resize=args.resize,
mean=(0.548, 0.504, 0.479),
std=(0.237, 0.247, 0.246)
)
dataset_val.set_transform(transform)


# -- data_loader & sampler
dataset.set_transform(transform)
train_set, val_set = dataset.split_dataset()


train_loader = DataLoader(
train_set,
dataset_train,
batch_size=args.batch_size,
num_workers=multiprocessing.cpu_count() // 2,
shuffle=True,
Expand All @@ -105,7 +152,7 @@ def train(data_dir, model_dir, args):
)

val_loader = DataLoader(
val_set,
dataset_val,
batch_size=args.valid_batch_size,
num_workers=multiprocessing.cpu_count() // 2,
shuffle=False,
Expand Down Expand Up @@ -135,6 +182,9 @@ def train(data_dir, model_dir, args):
if args.scheduler != "None":
scheduler = scheduler_module.get_scheduler(scheduler_module,args.scheduler, optimizer)

# labels_list
labels_list = os.listdir(args.data_dir)

for epoch in tqdm(range(args.epochs)):

# -- train loop
Expand All @@ -143,6 +193,7 @@ def train(data_dir, model_dir, args):
matches = 0
loss_value_sum = 0
train_acc_sum = 0
current_lr = get_lr(optimizer)
for idx, train_batch in enumerate(train_loader):
inputs, labels = train_batch
inputs = inputs.to(device)
Expand All @@ -162,7 +213,6 @@ def train(data_dir, model_dir, args):
if (idx + 1) % args.log_interval == 0:
train_loss = loss_value / (idx +1)
train_acc = matches / args.batch_size / (idx +1)
current_lr = get_lr(optimizer)
print(
f"Epoch[{epoch}/{args.epochs}]({idx + 1}/{len(train_loader)}) || "
f"training loss {train_loss:4.4} || training accuracy {train_acc:4.2%} || lr {current_lr}"
Expand All @@ -174,12 +224,12 @@ def train(data_dir, model_dir, args):
"step" : epoch * len(train_loader) + idx + 1
})


loss_value_sum = loss_value / args.log_interval
train_acc_sum = matches / args.batch_size / len(train_loader)
wandb.log({
"Train/loss_epoch": loss_value_sum,
"Train/accuracy_epcoh": train_acc_sum,
"Train/accuracy_epoch": train_acc_sum,
"lr": current_lr,
"epoch": epoch
})
Expand All @@ -188,13 +238,22 @@ def train(data_dir, model_dir, args):
# val loop
with torch.no_grad():
print("Calculating validation results...")

# wandb artifacts
val_data_at = wandb.Artifact("val_samples_" + str(wandb.run.id), type="predictions")
columns = ["id", "image", "guess_num" ,"guess", "truth_num", "truth"]
for classtype in range(num_classes):
columns.append("score_" + str(classtype))

val_table = wandb.Table(columns=columns)

model.eval()
val_loss_items = []
val_acc_items = []
preds_expand = torch.tensor([])
labels_expand = torch.tensor([])

for val_batch in val_loader:
for val_batch in tqdm(val_loader):
inputs, labels = val_batch
inputs = inputs.to(device)
labels = labels.to(device)
Expand All @@ -210,12 +269,14 @@ def train(data_dir, model_dir, args):

preds_expand = torch.cat((preds_expand, preds.detach().cpu()),-1)
labels_expand = torch.cat((labels_expand, labels.detach().cpu()),-1)
if (epoch + 1) % args.save_interval == 0:
log_test_predictions(inputs, labels, outs, preds, val_table, labels_list)

# -- evaluation
f1 = MulticlassF1Score(num_classes=num_classes)
f1_score = f1(preds_expand.type(torch.LongTensor), labels_expand.type(torch.LongTensor)).item()
val_loss = np.sum(val_loss_items) / len(val_set)
val_acc = np.sum(val_acc_items) / len(val_set)
val_loss = np.sum(val_loss_items) / len(dataset_val)
val_acc = np.sum(val_acc_items) / len(dataset_val)

print(f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2}, f1: {f1_score:4.4} ")

Expand All @@ -229,6 +290,10 @@ def train(data_dir, model_dir, args):
# model save
torch.save(model, f"{save_dir}/{epoch}.pth")

# wandb artifact
val_data_at.add(val_table, "predictions")
wandb.run.log_artifact(val_data_at)

# --scheduler
if args.scheduler != "None":
scheduler.step()
Expand All @@ -241,14 +306,15 @@ def train(data_dir, model_dir, args):
parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train (default: 200)')
parser.add_argument('--dataset', type=str, default='FontDataset', help='dataset augmentation type (default: Ma skBaseDataset)')
parser.add_argument('--augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)') ##
parser.add_argument('--train_augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
parser.add_argument('--val_augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
parser.add_argument("--resize", nargs="+", type=int, default=[256, 256], help='resize size for image when training')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: 64)')
parser.add_argument('--valid_batch_size', type=int, default=10, help='input batch size for validing (default: 1000)')
parser.add_argument('--valid_batch_size', type=int, default=100, help='input batch size for validing (default: 1000)')
parser.add_argument('--model', type=str, default='ResNet50', help='model type (default: ResNet50)')
parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer type (default: Adam)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
parser.add_argument('--val_ratio', type=float, default=0.2, help='ratio for validaton (default: 0.2)')
parser.add_argument('--val_ratio', type=float, default=0.001, help='ratio for validaton (default: 0.2)')
parser.add_argument('--criterion', type=str, default='cross_entropy', help='criterion type (default: cross_entropy)')
parser.add_argument('--log_interval', type=int, default=20, help='how many batches to wait before logging training status')
parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')
Expand Down

0 comments on commit beb7f96

Please sign in to comment.