Skip to content

Commit

Permalink
feat: stratified split, split train/val pipeline
Browse files Browse the repository at this point in the history
Referenced Issue: #2
  • Loading branch information
jerry-ryu committed Jan 9, 2023
1 parent 8698bc5 commit b4eeb25
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 34 deletions.
59 changes: 36 additions & 23 deletions model/font_classifier/dataset_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,24 +12,39 @@

class FontDataset(Dataset):
    """Stratified font-classification dataset.

    Scans ``data_dir`` (one sub-directory per font class) and splits the
    images of every class into a train part and a validation part, so each
    class contributes ``val_ratio`` of its images to validation (stratified
    split). The split is deterministic (private seeded RNG), so a train
    instance and a val instance built separately agree on the partition and
    cannot leak images into each other.
    """

    def __init__(self, data_dir, val_ratio=0.2, is_train=True, seed=42):
        # data_dir:  root folder with one sub-folder per font class
        # val_ratio: fraction of each class assigned to validation
        # is_train:  which half of the split this instance serves
        # seed:      fixed seed so separate train/val instances agree on the split
        self.data_dir = data_dir
        # The original stored val_ratio only when is_train was True, which
        # made setup() raise AttributeError for the validation instance.
        self.val_ratio = val_ratio
        self.is_train = is_train
        self.seed = seed
        self.transform = None

        # Instance-level lists. The original used class-level lists, so
        # every new FontDataset appended to the same shared state and the
        # usual train+val pair of instances duplicated every entry.
        self.image_paths_train = []
        self.image_labels_train = []
        self.image_paths_val = []
        self.image_labels_val = []

        self.setup()

    def setup(self):
        """Build the stratified train/validation index lists."""
        rng = random.Random(self.seed)  # private RNG: reproducible split
        # sorted() so label assignment does not depend on filesystem order
        profiles = sorted(os.listdir(self.data_dir))

        for label, profile in enumerate(profiles):
            profile_dir = os.path.join(self.data_dir, profile)
            paths = [os.path.join(profile_dir, name)
                     for name in sorted(os.listdir(profile_dir))]

            # val_ratio of THIS class goes to validation. The original
            # sampled the *train* set with val_ratio (inverting the split)
            # and filled image_labels_val with paths instead of labels.
            n_val = int(len(paths) * self.val_ratio)
            val_idx = set(rng.sample(range(len(paths)), n_val))

            for i, path in enumerate(paths):
                if i in val_idx:
                    self.image_paths_val.append(path)
                    self.image_labels_val.append(label)
                else:
                    self.image_paths_train.append(path)
                    self.image_labels_train.append(label)

    def set_transform(self, transform):
        # Transform is injected after construction by the training script.
        self.transform = transform

    def __getitem__(self, index):
        # Returns (transformed image, integer class label).
        # NOTE(review): assumes set_transform() was called first — a
        # transform of None would raise TypeError here, as in the original.
        image = self.read_image(index)
        image_transform = self.transform(image)

        if self.is_train:
            label = self.image_labels_train[index]
        else:
            label = self.image_labels_val[index]

        return image_transform, label

    def read_image(self, index):
        """Load the image for ``index`` from the split this instance serves."""
        if self.is_train:
            image_path = self.image_paths_train[index]
        else:
            image_path = self.image_paths_val[index]
        return Image.open(image_path).convert('RGB')

    def __len__(self):
        if self.is_train:
            return len(self.image_paths_train)
        return len(self.image_paths_val)



Expand Down
37 changes: 26 additions & 11 deletions model/font_classifier/train_font.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,28 +75,42 @@ def train(data_dir, model_dir, args):


dataset_module = getattr(import_module("dataset_font"), args.dataset)
dataset = dataset_module(
dataset_train = dataset_module(
data_dir=data_dir,
val_ratio = args.val_ratio
val_ratio = args.val_ratio,
is_train = True
)

num_classes = len(os.listdir(args.data_dir)) # font의 개수

# -- augmentation
transform_module = getattr(import_module("dataset_font"), args.augmentation) # default: BaseAugmentation
transform_module = getattr(import_module("dataset_font"), args.train_augmentation) # default: BaseAugmentation
transform = transform_module(
resize=args.resize,
mean=(0.548, 0.504, 0.479),
std=(0.237, 0.247, 0.246)
)

# -- data_loader & sampler
dataset.set_transform(transform)
train_set, val_set = dataset.split_dataset()
dataset_train.set_transform(transform)

dataset_val = dataset_module(
data_dir=data_dir,
val_ratio = args.val_ratio,
is_train = False
)
dataset_val.set_transform(transform)


transform = transform_module(
resize=args.resize,
mean=(0.548, 0.504, 0.479),
std=(0.237, 0.247, 0.246)
)


train_loader = DataLoader(
train_set,
dataset_train,
batch_size=args.batch_size,
num_workers=multiprocessing.cpu_count() // 2,
shuffle=True,
Expand All @@ -105,7 +119,7 @@ def train(data_dir, model_dir, args):
)

val_loader = DataLoader(
val_set,
dataset_val,
batch_size=args.valid_batch_size,
num_workers=multiprocessing.cpu_count() // 2,
shuffle=False,
Expand Down Expand Up @@ -214,8 +228,8 @@ def train(data_dir, model_dir, args):
# -- evaluation
f1 = MulticlassF1Score(num_classes=num_classes)
f1_score = f1(preds_expand.type(torch.LongTensor), labels_expand.type(torch.LongTensor)).item()
val_loss = np.sum(val_loss_items) / len(val_set)
val_acc = np.sum(val_acc_items) / len(val_set)
val_loss = np.sum(val_loss_items) / len(dataset_val)
val_acc = np.sum(val_acc_items) / len(dataset_val)

print(f"[Val] acc : {val_acc:4.2%}, loss: {val_loss:4.2}, f1: {f1_score:4.4} ")

Expand All @@ -241,14 +255,15 @@ def train(data_dir, model_dir, args):
# Command-line arguments for the font-classifier training script.
parser.add_argument('--seed', type=int, default=42, help='random seed (default: 42)')
parser.add_argument('--epochs', type=int, default=200, help='number of epochs to train (default: 200)')
# NOTE(review): help text "Ma skBaseDataset" is garbled — presumably "MaskBaseDataset"; confirm intended default name.
parser.add_argument('--dataset', type=str, default='FontDataset', help='dataset augmentation type (default: Ma skBaseDataset)')
# NOTE(review): the next line appears to be a leftover pre-change definition
# (replaced by --train_augmentation / --val_augmentation below); keeping it
# alongside them would be redundant — verify against the committed file.
parser.add_argument('--augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)') ##
parser.add_argument('--train_augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
# NOTE(review): --val_augmentation is defined but the visible train() hunks
# only build a transform from --train_augmentation — confirm it is wired up.
parser.add_argument('--val_augmentation', type=str, default='BaseAugmentation', help='data augmentation type (default: BaseAugmentation)')
parser.add_argument("--resize", nargs="+", type=int, default=[256, 256], help='resize size for image when training')
parser.add_argument('--batch_size', type=int, default=64, help='input batch size for training (default: 64)')
# NOTE(review): default is 10 but the help text claims 1000 — reconcile.
parser.add_argument('--valid_batch_size', type=int, default=10, help='input batch size for validing (default: 1000)')
parser.add_argument('--model', type=str, default='ResNet50', help='model type (default: ResNet50)')
parser.add_argument('--optimizer', type=str, default='Adam', help='optimizer type (default: Adam)')
parser.add_argument('--lr', type=float, default=1e-3, help='learning rate (default: 1e-3)')
# NOTE(review): two --val_ratio definitions appear here (old 0.2 / new 0.01 —
# diff residue); argparse would reject the duplicate at runtime. The new
# default 0.01 also contradicts the help text "(default: 0.2)".
parser.add_argument('--val_ratio', type=float, default=0.2, help='ratio for validaton (default: 0.2)')
parser.add_argument('--val_ratio', type=float, default=0.01, help='ratio for validaton (default: 0.2)')
parser.add_argument('--criterion', type=str, default='cross_entropy', help='criterion type (default: cross_entropy)')
parser.add_argument('--log_interval', type=int, default=20, help='how many batches to wait before logging training status')
parser.add_argument('--name', default='exp', help='model save at {SM_MODEL_DIR}/{name}')
Expand Down

0 comments on commit b4eeb25

Please sign in to comment.