Skip to content

Commit

Permalink
initial commit
Browse files Browse the repository at this point in the history
  • Loading branch information
carpedm20 committed Apr 6, 2017
0 parents commit f515b16
Show file tree
Hide file tree
Showing 14 changed files with 921 additions and 0 deletions.
128 changes: 128 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
# Data
data/hand
data/gaze
data/*
samples
outputs

# ipython checkpoints
.ipynb_checkpoints

# Log
logs

# ETC
paper.pdf
.DS_Store

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


### Vim ###
# swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
# session
Session.vim
# temporary
.netrwhist
*~
# auto-generated tag files
tags

# End of https://www.gitignore.io/api/python,vim
57 changes: 57 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# BEGAN in Tensorflow

Tensorflow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717).

![alt tag](./assets/model.png)


## Requirements

- Python 2.7
- [Pillow](https://pillow.readthedocs.io/en/4.0.x/)
- [tqdm](https://github.com/tqdm/tqdm)
- [TensorFlow](https://github.com/tensorflow/tensorflow)
- [requests](https://github.com/kennethreitz/requests) (Only used for downloading CelebA dataset)


## Usage

First download [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) datasets with:

$ apt-get install p7zip-full # ubuntu
$ brew install p7zip # Mac
$ python download.py

or you can use your own dataset by placing images like:

data
└── YOUR_DATASET_NAME
├── xxx.jpg (name doesn't matter)
├── yyy.jpg
└── ...

To train a model:

$ python main.py --dataset=CelebA --num_gpu=1
$ python main.py --dataset=YOUR_DATASET_NAME --num_gpu=4

To test a model (use your `load_path`):

$ python main.py --dataset=CelebA --load_path=./logs/CelebA_0405_124806 --num_gpu=0 --is_train=False --split valid


## Results

- [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) at least can generate human faces but [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) can't.
- Both [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) and [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) show **modal collapses**, and I guess this is due to wrong scheduling of the lr (the paper mentions that *simply reducing the lr was sufficient to avoid them*).

### Results after 82400 steps

![alt tag](./assets/82400_1.png) ![alt tag](./assets/82400_2.png)

(in progress)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io)
Binary file added assets/82400_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/82400_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file added assets/model.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
66 changes: 66 additions & 0 deletions config.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,66 @@
#-*- coding: utf-8 -*-
import argparse

def str2bool(v):
    """Interpret a command-line value as a boolean.

    Intended for use as an ``argparse`` ``type=`` converter.

    Args:
        v: Usually the raw string taken from the command line. A value that
            is already a ``bool`` is returned unchanged instead of failing
            on the missing ``.lower()`` method.

    Returns:
        True when ``v`` (ignoring case and surrounding whitespace) is
        ``'true'`` or ``'1'``; False for every other string.
    """
    if isinstance(v, bool):
        # Already a boolean (e.g. the converter was called by hand).
        return v
    return v.strip().lower() in ('true', '1')

# Module-wide parser, plus a registry of every argument group created on it.
parser = argparse.ArgumentParser()
arg_lists = []


def add_argument_group(name):
    """Create a titled argument group on ``parser`` and remember it.

    Args:
        name: Title passed to ``parser.add_argument_group``.

    Returns:
        The newly created argument group, which is also appended to
        ``arg_lists``.
    """
    group = parser.add_argument_group(name)
    arg_lists.append(group)
    return group

# Network: options describing the model architecture.
net_arg = add_argument_group('Network')
net_arg.add_argument(
    '--input_scale_size',
    default=64,
    type=int,
    help='input image will be resized with the given value as width and height',
)
net_arg.add_argument(
    '--conv_hidden_num',
    default=128,
    type=int,
    help='n in the paper',
)
net_arg.add_argument('--z_num', default=128, type=int)

# Data: which dataset/split to read and how to batch it.
data_arg = add_argument_group('Data')
data_arg.add_argument('--dataset', default='CelebA', type=str)
data_arg.add_argument('--split', default='train', type=str)
data_arg.add_argument('--batch_size', default=16, type=int)
# Boolean flag: parsed from text by str2bool.
data_arg.add_argument('--grayscale', default=False, type=str2bool)
data_arg.add_argument('--num_worker', default=12, type=int)

# Training / test parameters
train_arg = add_argument_group('Training')
# Boolean flags are parsed from text by str2bool ('true'/'1' -> True).
train_arg.add_argument('--is_train', type=str2bool, default=True)
train_arg.add_argument('--optimizer', type=str, default='adam')
train_arg.add_argument('--max_step', type=int, default=250000)
train_arg.add_argument('--lr_update_step', type=int, default=3000)
# Separate learning rates for the discriminator (d) and generator (g).
train_arg.add_argument('--d_lr', type=float, default=0.0001)
train_arg.add_argument('--g_lr', type=float, default=0.0001)
train_arg.add_argument('--beta1', type=float, default=0.5)
train_arg.add_argument('--beta2', type=float, default=0.999)
train_arg.add_argument('--gamma', type=float, default=0.5)
train_arg.add_argument('--lambda_k', type=float, default=0.001)
# Was type=float, which rejected '--use_gpu=False' ("invalid float value").
# Use the same boolean converter as every other on/off flag.
train_arg.add_argument('--use_gpu', type=str2bool, default=True)

# Misc: checkpointing, logging, paths and reproducibility.
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--load_path', default='', type=str)
misc_arg.add_argument('--log_step', default=50, type=int)
misc_arg.add_argument('--save_step', default=5000, type=int)
misc_arg.add_argument('--num_log_samples', default=3, type=int)
misc_arg.add_argument(
    '--log_level',
    default='INFO',
    type=str,
    choices=['INFO', 'DEBUG', 'WARN'],
)
misc_arg.add_argument('--log_dir', default='logs', type=str)
misc_arg.add_argument('--data_dir', default='data', type=str)
misc_arg.add_argument(
    '--test_data_path',
    default=None,
    type=str,
    help='directory with images which will be used in test sample generation',
)
misc_arg.add_argument(
    '--sample_per_image',
    default=64,
    type=int,
    help='# of sample per image during test sample generation',
)
misc_arg.add_argument('--random_seed', default=123, type=int)

def get_config():
    """Parse the command line and derive the tensor layout.

    Returns:
        A ``(config, unparsed)`` tuple. ``config`` is the namespace from
        ``parser.parse_known_args()`` with an added ``data_format``
        attribute: 'NCHW' when ``config.use_gpu`` is truthy, otherwise
        'NHWC'. ``unparsed`` is the list of arguments the parser did not
        recognise.
    """
    parsed, leftovers = parser.parse_known_args()
    parsed.data_format = 'NCHW' if parsed.use_gpu else 'NHWC'
    return parsed, leftovers
56 changes: 56 additions & 0 deletions data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import os
from PIL import Image
from glob import glob
import tensorflow as tf

def get_loader(root, batch_size, scale, data_format, split=None, is_grayscale=False, seed=None):
    """Build a queue-based TensorFlow input pipeline over a directory of images.

    Args:
        root: Dataset directory. When its basename is 'CelebA' and ``split``
            is given, images are read from ``<root>/splits/<split>``.
        batch_size: Number of images per returned batch.
        scale: Output height and width for datasets other than CelebA.
            CelebA batches are always resized to 78x64 and then cropped to
            64x64, regardless of this value.
        data_format: 'NCHW' or 'NHWC'; selects the layout of the result.
        split: Optional split name (only used for CelebA).
        is_grayscale: When true, images are converted to a single channel.
        seed: Forwarded to ``tf.train.string_input_producer``.

    Returns:
        A float tensor holding one shuffled batch of images in the requested
        layout.

    Raises:
        ValueError: If ``root`` contains neither ``*.jpg`` nor ``*.png`` files.
        Exception: If ``data_format`` is not 'NCHW' or 'NHWC'.
    """
    dataset_name = os.path.basename(root)
    if dataset_name in ['CelebA'] and split:
        root = os.path.join(root, 'splits', split)

    # Use the first extension that matches at least one file.
    for ext in ["jpg", "png"]:
        paths = glob("{}/*.{}".format(root, ext))
        if paths:
            break
    else:
        # Previously this fell through to paths[0] and died with IndexError.
        raise ValueError("[!] No jpg or png images found in {}".format(root))

    tf_decode = tf.image.decode_jpeg if ext == "jpg" else tf.image.decode_png

    # Every image is assumed to share the size of the first one; that size
    # becomes the static shape of the decoded tensor.
    with Image.open(paths[0]) as img:
        w, h = img.size
        # rgb_to_grayscale leaves a single channel, so the static shape must
        # say 1 rather than 3 in that case.
        channels = 1 if is_grayscale else 3
        shape = [h, w, channels]

    filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed)
    reader = tf.WholeFileReader()
    filename, data = reader.read(filename_queue)
    image = tf_decode(data, channels=3)

    if is_grayscale:
        image = tf.image.rgb_to_grayscale(image)
    image.set_shape(shape)

    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3 * batch_size

    queue = tf.train.shuffle_batch(
        [image], batch_size=batch_size,
        num_threads=4, capacity=capacity,
        min_after_dequeue=min_after_dequeue, name='synthetic_inputs')

    if dataset_name in ['CelebA']:
        # Resize to 78x64, then drop 7 rows from the top and bottom to get a
        # vertically centred 64x64 crop.
        queue = tf.image.resize_nearest_neighbor(queue, [78, 64])
        queue = tf.image.crop_to_bounding_box(queue, 7, 0, 64, 64)
    else:
        # Was `scale_size`, an undefined name; the parameter is `scale`.
        queue = tf.image.resize_nearest_neighbor(queue, [scale, scale])

    if data_format == 'NCHW':
        # The batch is built as NHWC; move channels ahead of height/width.
        queue = tf.transpose(queue, [0, 3, 1, 2])
    elif data_format == 'NHWC':
        pass
    else:
        raise Exception("[!] Unkown data_format: {}".format(data_format))

    return tf.to_float(queue)
Loading

0 comments on commit f515b16

Please sign in to comment.