Commit

add face tasks
andytu28 committed Dec 17, 2019
1 parent 6bfc5f1 commit 419f7d2
Showing 31 changed files with 3,498 additions and 2,546 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -0,0 +1 @@
*.pyc
587 changes: 0 additions & 587 deletions CPG_cifar100_main_finetune_sparse.py

This file was deleted.

656 changes: 0 additions & 656 deletions CPG_cifar100_with_one_mask.py

This file was deleted.

388 changes: 222 additions & 166 deletions CPG_cifar100_main.py → CPG_face_main.py

Large diffs are not rendered by default.

69 changes: 69 additions & 0 deletions FACE_UTILS/LFWDataset.py
@@ -0,0 +1,69 @@
# -*- coding: utf-8 -*-
import torchvision.datasets as datasets
import os
import numpy as np

class LFWDataset(datasets.ImageFolder):
    '''ImageFolder-based dataset that yields LFW verification pairs
    (img1, img2, issame) for evaluating face recognition models.'''

    def __init__(self, dir, pairs_path, file_ext='jpg', transform=None):
        super(LFWDataset, self).__init__(dir, transform)
        self.pairs_path = pairs_path
        # LFW dir contains 2 folders: faces and lists
        self.validation_images = self.get_lfw_paths(dir, file_ext=file_ext)

    def read_lfw_pairs(self, pairs_filename):
        # Parse the LFW pairs file; the first line is a header and is skipped.
        pairs = []
        with open(pairs_filename, 'r') as f:
            for line in f.readlines()[1:]:
                pair = line.strip().split()
                pairs.append(pair)
        return np.array(pairs)

    # NOTE: make sure file_ext matches the extension of the images on disk (e.g. 'jpg' vs. 'png').
    def get_lfw_paths(self, lfw_dir, file_ext="jpg"):
        pairs = self.read_lfw_pairs(self.pairs_path)
        nrof_skipped_pairs = 0
        path_list = []
        issame_list = []
        for i in range(len(pairs)):
            pair = pairs[i]
            if len(pair) == 3:
                # Same identity: (name, img_idx1, img_idx2)
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[2]) + '.' + file_ext)
                issame = True
            elif len(pair) == 4:
                # Different identities: (name1, img_idx1, name2, img_idx2)
                path0 = os.path.join(lfw_dir, pair[0], pair[0] + '_' + '%04d' % int(pair[1]) + '.' + file_ext)
                path1 = os.path.join(lfw_dir, pair[2], pair[2] + '_' + '%04d' % int(pair[3]) + '.' + file_ext)
                issame = False
            if os.path.exists(path0) and os.path.exists(path1):  # Only add the pair if both paths exist
                path_list.append((path0, path1, issame))
                issame_list.append(issame)
            else:
                nrof_skipped_pairs += 1
        if nrof_skipped_pairs > 0:
            print('Skipped %d image pairs' % nrof_skipped_pairs)
        return path_list

    def __getitem__(self, index):
        '''
        Args:
            index: Index of the verification pair to load (not of a single image).
        Returns:
            (img1, img2, issame): the two transformed images and whether they
            belong to the same identity.
        '''
        def transform(img_path):
            """Load the image with the ImageFolder loader and apply self.transform,
            so the output is consistent with the other datasets."""
            img = self.loader(img_path)
            return self.transform(img)

        (path_1, path_2, issame) = self.validation_images[index]
        img1, img2 = transform(path_1), transform(path_2)
        return img1, img2, issame

    def __len__(self):
        return len(self.validation_images)
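
For context, a minimal usage sketch of this dataset for LFW verification. The paths, pairs file name, and batch size below are illustrative assumptions, not part of this commit:

import torch
import torchvision.transforms as transforms
from FACE_UTILS.LFWDataset import LFWDataset

# Hypothetical locations of the LFW images and pairs file.
lfw_transform = transforms.Compose([
    transforms.Resize(112),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
])
lfw_set = LFWDataset('data/lfw', 'data/lfw_pairs.txt',
                     file_ext='jpg', transform=lfw_transform)
lfw_loader = torch.utils.data.DataLoader(lfw_set, batch_size=32,
                                         shuffle=False, num_workers=4)

for img1, img2, issame in lfw_loader:
    # Embed both images with the face model and compare the embeddings;
    # `issame` holds the ground-truth label for each pair.
    pass

Each item is a labeled pair rather than a single image, which is why the dataset overrides __getitem__ instead of reusing the ImageFolder sample list.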
9 changes: 9 additions & 0 deletions FACE_UTILS/__init__.py
@@ -0,0 +1,9 @@
"""Contains a bunch of utility functions."""
import numpy as np
import pdb


def set_dataset_paths(args):
    """Set the train and validation data paths based on the dataset name."""
    args.train_path = 'data/%s/train' % (args.dataset)
    args.val_path = 'data/%s/val' % (args.dataset)
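
A short sketch of how this helper is typically wired to argparse; the dataset name 'face_verification' is only an example, not a name defined by this commit:

import argparse
from FACE_UTILS import set_dataset_paths

parser = argparse.ArgumentParser()
parser.add_argument('--dataset', type=str, default='face_verification')
args = parser.parse_args([])

set_dataset_paths(args)
print(args.train_path)  # -> 'data/face_verification/train'
print(args.val_path)    # -> 'data/face_verification/val'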
55 changes: 55 additions & 0 deletions FACE_UTILS/dataset.py
@@ -0,0 +1,55 @@
import collections
import glob
import os

import numpy as np
from PIL import Image

import torch
import torch.backends.cudnn as cudnn
import torch.nn as nn
import torch.optim as optim
import torchvision.datasets as datasets
import torchvision.models as models
import torchvision.transforms as transforms
import pdb

# With ToTensor() producing values in [0, 1], these mean/std values normalize images to [-1, 1].
VGGFACE_MEAN = [0.5, 0.5, 0.5]
VGGFACE_STD = [0.5, 0.5, 0.5]


def train_loader(path, train_batch_size, num_workers=4, pin_memory=False, normalize=None):
    """Return a shuffled DataLoader over an ImageFolder of training face images."""
    if normalize is None:
normalize = transforms.Normalize(
mean=VGGFACE_MEAN, std=VGGFACE_STD)

train_transform = transforms.Compose([
transforms.Resize(112),
transforms.RandomHorizontalFlip(),
transforms.ToTensor(),
normalize,
])

train_dataset = datasets.ImageFolder(path, train_transform)

return torch.utils.data.DataLoader(train_dataset,
batch_size=train_batch_size, shuffle=True, sampler=None,
num_workers=num_workers, pin_memory=pin_memory)


def val_loader(path, val_batch_size, num_workers=4, pin_memory=False, normalize=None):
    """Return a non-shuffled DataLoader over an ImageFolder of validation face images."""
    if normalize is None:
normalize = transforms.Normalize(
mean=VGGFACE_MEAN, std=VGGFACE_STD)

val_transform = transforms.Compose([
transforms.Resize(112),
transforms.ToTensor(),
normalize,
])

val_dataset = datasets.ImageFolder(path, val_transform)

return torch.utils.data.DataLoader(val_dataset,
batch_size=val_batch_size, shuffle=False, sampler=None,
num_workers=num_workers, pin_memory=pin_memory)
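
These loaders are meant to be pointed at the data/&lt;dataset&gt;/{train,val} paths set by set_dataset_paths. A minimal sketch, assuming a hypothetical dataset name and illustrative batch sizes:

import argparse
from FACE_UTILS import dataset as face_dataset
from FACE_UTILS import set_dataset_paths

# 'vggface2_subset' is a placeholder dataset name for illustration only.
args = argparse.Namespace(dataset='vggface2_subset')
set_dataset_paths(args)

train_data_loader = face_dataset.train_loader(args.train_path, train_batch_size=64,
                                              num_workers=4, pin_memory=True)
val_data_loader = face_dataset.val_loader(args.val_path, val_batch_size=100,
                                          num_workers=4, pin_memory=True)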
