forked from carpedm20/BEGAN-tensorflow
Commit f515b16 (0 parents)
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Showing 14 changed files with 921 additions and 0 deletions.
.gitignore (new file, 128 lines)
# Data
data/hand
data/gaze
data/*
samples
outputs

# ipython checkpoints
.ipynb_checkpoints

# Log
logs

# ETC
paper.pdf
.DS_Store

# Created by https://www.gitignore.io/api/python,vim

### Python ###
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
env/
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*,cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# dotenv
.env

# virtualenv
.venv/
venv/
ENV/

# Spyder project settings
.spyderproject

# Rope project settings
.ropeproject


### Vim ###
# swap
[._]*.s[a-v][a-z]
[._]*.sw[a-p]
[._]s[a-v][a-z]
[._]sw[a-p]
# session
Session.vim
# temporary
.netrwhist
*~
# auto-generated tag files
tags

# End of https://www.gitignore.io/api/python,vim
README.md (new file, 57 lines)
# BEGAN in Tensorflow

Tensorflow implementation of [BEGAN: Boundary Equilibrium Generative Adversarial Networks](https://arxiv.org/abs/1703.10717).

![alt tag](./assets/model.png)


## Requirements

- Python 2.7
- [Pillow](https://pillow.readthedocs.io/en/4.0.x/)
- [tqdm](https://github.com/tqdm/tqdm)
- [TensorFlow](https://github.com/tensorflow/tensorflow)
- [requests](https://github.com/kennethreitz/requests) (only used for downloading the CelebA dataset)


## Usage

First download the [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html) dataset with:

    $ apt-get install p7zip-full # Ubuntu
    $ brew install p7zip         # Mac
    $ python download.py

or use your own dataset by placing images like:

    data
    └── YOUR_DATASET_NAME
        ├── xxx.jpg (name doesn't matter)
        ├── yyy.jpg
        └── ...

To train a model:

    $ python main.py --dataset=CelebA --num_gpu=1
    $ python main.py --dataset=YOUR_DATASET_NAME --num_gpu=4

To test a model (use your `load_path`):

    $ python main.py --dataset=CelebA --load_path=./logs/CelebA_0405_124806 --num_gpu=0 --is_train=False --split valid


## Results

- [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) can at least generate human faces, but [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) can't.
- Both [BEGAN-tensorflow](https://github.com/carpedm20/began-tensorflow) and [BEGAN-pytorch](https://github.com/carpedm20/BEGAN-pytorch) show **mode collapse**, which I suspect is caused by incorrect learning-rate scheduling (the paper mentions that *simply reducing the lr was sufficient to avoid them*; see the sketch below).
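
For reference, the learning-rate schedule interacts with the equilibrium term `k_t` described in the BEGAN paper. Below is a minimal sketch in plain Python, not this repo's actual training loop: `train_one_step` is a hypothetical stand-in that runs one D/G update and returns the two per-batch reconstruction losses, and the constants mirror the defaults in the config file later in this commit.

    gamma, lambda_k = 0.5, 0.001           # --gamma, --lambda_k
    lr, lr_update_step = 1e-4, 3000        # --d_lr / --g_lr, --lr_update_step
    k_t = 0.0
    for step in range(250000):             # --max_step
        L_real, L_fake = train_one_step(k_t, lr)   # hypothetical helper
        # k_t nudges the discriminator toward gamma * L(x) = L(G(z)),
        # the equilibrium condition from the paper; it is clamped to [0, 1].
        k_t = min(max(k_t + lambda_k * (gamma * L_real - L_fake), 0.0), 1.0)
        if (step + 1) % lr_update_step == 0:
            lr *= 0.5                      # "simply reducing the lr"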

### Results after 82,400 steps

![alt tag](./assets/82400_1.png) ![alt tag](./assets/82400_2.png)

(in progress)


## Author

Taehoon Kim / [@carpedm20](http://carpedm20.github.io)
(3 binary files in this commit are not displayed)
config.py (new file, 66 lines)
#-*- coding: utf-8 -*-
import argparse

def str2bool(v):
    return v.lower() in ('true', '1')

arg_lists = []
parser = argparse.ArgumentParser()

def add_argument_group(name):
    arg = parser.add_argument_group(name)
    arg_lists.append(arg)
    return arg

# Network
net_arg = add_argument_group('Network')
net_arg.add_argument('--input_scale_size', type=int, default=64,
                     help='input images will be resized to this width and height')
net_arg.add_argument('--conv_hidden_num', type=int, default=128, help='n in the paper')
net_arg.add_argument('--z_num', type=int, default=128)

# Data
data_arg = add_argument_group('Data')
data_arg.add_argument('--dataset', type=str, default='CelebA')
data_arg.add_argument('--split', type=str, default='train')
data_arg.add_argument('--batch_size', type=int, default=16)
data_arg.add_argument('--grayscale', type=str2bool, default=False)
data_arg.add_argument('--num_worker', type=int, default=12)

# Training / test parameters
train_arg = add_argument_group('Training')
train_arg.add_argument('--is_train', type=str2bool, default=True)
train_arg.add_argument('--optimizer', type=str, default='adam')
train_arg.add_argument('--max_step', type=int, default=250000)
train_arg.add_argument('--lr_update_step', type=int, default=3000)
train_arg.add_argument('--d_lr', type=float, default=0.0001)
train_arg.add_argument('--g_lr', type=float, default=0.0001)
train_arg.add_argument('--beta1', type=float, default=0.5)
train_arg.add_argument('--beta2', type=float, default=0.999)
train_arg.add_argument('--gamma', type=float, default=0.5)
train_arg.add_argument('--lambda_k', type=float, default=0.001)
# str2bool (rather than float) so that e.g. --use_gpu=False parses as a boolean
train_arg.add_argument('--use_gpu', type=str2bool, default=True)

# Misc
misc_arg = add_argument_group('Misc')
misc_arg.add_argument('--load_path', type=str, default='')
misc_arg.add_argument('--log_step', type=int, default=50)
misc_arg.add_argument('--save_step', type=int, default=5000)
misc_arg.add_argument('--num_log_samples', type=int, default=3)
misc_arg.add_argument('--log_level', type=str, default='INFO', choices=['INFO', 'DEBUG', 'WARN'])
misc_arg.add_argument('--log_dir', type=str, default='logs')
misc_arg.add_argument('--data_dir', type=str, default='data')
misc_arg.add_argument('--test_data_path', type=str, default=None,
                      help='directory with images which will be used in test sample generation')
misc_arg.add_argument('--sample_per_image', type=int, default=64,
                      help='# of samples per image during test sample generation')
misc_arg.add_argument('--random_seed', type=int, default=123)

def get_config():
    config, unparsed = parser.parse_known_args()
    if config.use_gpu:
        data_format = 'NCHW'   # channels-first, preferred by cuDNN on GPU
    else:
        data_format = 'NHWC'   # channels-last for CPU ops
    setattr(config, 'data_format', data_format)
    return config, unparsed
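
A hypothetical usage sketch (not part of this commit) of how a driver script would consume this config, assuming the file above is `config.py` as in the upstream repo; note that flags `parse_known_args` does not recognize are returned in `unparsed` instead of raising an error.

    from config import get_config

    config, unparsed = get_config()
    print(config.dataset, config.batch_size)   # 'CelebA', 16 with the defaults
    print(config.data_format)                  # 'NCHW' when use_gpu is set, else 'NHWC'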
data_loader.py (new file, 56 lines)
import os
from PIL import Image
from glob import glob
import tensorflow as tf

def get_loader(root, batch_size, scale, data_format, split=None, is_grayscale=False, seed=None):
    dataset_name = os.path.basename(root)
    if dataset_name in ['CelebA'] and split:
        root = os.path.join(root, 'splits', split)

    # pick the first extension that actually matches files in the directory,
    # and the matching TF decode op
    for ext in ["jpg", "png"]:
        paths = glob("{}/*.{}".format(root, ext))

        if ext == "jpg":
            tf_decode = tf.image.decode_jpeg
        elif ext == "png":
            tf_decode = tf.image.decode_png

        if len(paths) != 0:
            break

    # read one image to get the static shape for the whole dataset
    with Image.open(paths[0]) as img:
        w, h = img.size
        shape = [h, w, 3]

    filename_queue = tf.train.string_input_producer(list(paths), shuffle=False, seed=seed)
    reader = tf.WholeFileReader()
    filename, data = reader.read(filename_queue)
    image = tf_decode(data, channels=3)

    if is_grayscale:
        image = tf.image.rgb_to_grayscale(image)
        shape = [h, w, 1]   # grayscale images have a single channel
    image.set_shape(shape)

    min_after_dequeue = 5000
    capacity = min_after_dequeue + 3 * batch_size

    queue = tf.train.shuffle_batch(
        [image], batch_size=batch_size,
        num_threads=4, capacity=capacity,
        min_after_dequeue=min_after_dequeue, name='synthetic_inputs')

    if dataset_name in ['CelebA']:
        # resize to 78x64, then crop the central 64x64 face region
        queue = tf.image.resize_nearest_neighbor(queue, [78, 64])
        queue = tf.image.crop_to_bounding_box(queue, 7, 0, 64, 64)
    else:
        queue = tf.image.resize_nearest_neighbor(queue, [scale, scale])

    if data_format == 'NCHW':
        queue = tf.transpose(queue, [0, 3, 1, 2])
    elif data_format == 'NHWC':
        pass
    else:
        raise Exception("[!] Unknown data_format: {}".format(data_format))

    return tf.to_float(queue)
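
A hypothetical usage sketch (not part of this commit) of consuming the queue-based loader above. The module names (`data_loader`, `config`) and the `data/CelebA` path are assumptions; with TF1 queue runners, the batch tensor only yields data once the runners have been started.

    import tensorflow as tf
    from config import get_config
    from data_loader import get_loader

    config, _ = get_config()
    batch = get_loader('data/CelebA', config.batch_size,
                       config.input_scale_size, config.data_format, split='train')

    with tf.Session() as sess:
        coord = tf.train.Coordinator()
        threads = tf.train.start_queue_runners(sess=sess, coord=coord)
        images = sess.run(batch)   # e.g. a (16, 3, 64, 64) float batch for NCHW
        coord.request_stop()
        coord.join(threads)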
(remaining files in this commit are not shown)