diff --git a/README.md b/README.md
index 04757b6e..ef2ef4ed 100644
--- a/README.md
+++ b/README.md
@@ -1,2 +1,182 @@
 # pix2pix-tensorflow
-Tensorflow Port of Image-to-image translation using conditional adversarial nets https://phillipi.github.io/pix2pix/
+
+Based on [pix2pix](https://phillipi.github.io/pix2pix/) by Isola et al.
+
+[Article about this implementation](https://affinelayer.com/pix2pix/)
+
+Tensorflow implementation of pix2pix. Learns a mapping from input images to output images, like these examples from the original paper:
+
+
+
+This port is based directly on the torch implementation, not on an existing Tensorflow implementation. It is meant to be a faithful implementation of the original work and so does not add anything. In testing, processing speed on a GPU with cuDNN was equivalent to that of the Torch implementation.
+
+## Setup
+
+### Prerequisites
+- Tensorflow 0.12.1
+
+### Recommended
+- Linux with Tensorflow GPU edition + cuDNN
+
+### Getting Started
+
+```sh
+# Clone this repo
+git clone https://github.com/affinelayer/pix2pix-tensorflow.git
+cd pix2pix-tensorflow
+# Download the CMP Facades dataset http://cmp.felk.cvut.cz/~tylecr1/facade/
+python tools/download-dataset.py facades
+# Train the model (this may take 1-8 hours depending on GPU; on CPU you will be waiting for a bit)
+python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+# Test the model
+python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets.
+
+## Datasets
+
+The data format used by this program is the same as the original pix2pix format, which consists of images of input and desired output side by side, like:
+
+
+
+For example:
+
+
+
+Some datasets have been made available by the authors of the pix2pix paper. To download those datasets, use the included script `tools/download-dataset.py`.
+
+| dataset | image |
+| --- | --- |
+| `python tools/download-dataset.py facades` <br>400 images from the [CMP Facades dataset](http://cmp.felk.cvut.cz/~tylecr1/facade/). (31MB) |  |
+| `python tools/download-dataset.py cityscapes` <br>2,975 images from the [Cityscapes training set](https://www.cityscapes-dataset.com/). (113MB) |  |
+| `python tools/download-dataset.py maps` <br>1,096 training images scraped from Google Maps. (246MB) |  |
+| `python tools/download-dataset.py edges2shoes` <br>50k training images from the [UT Zappos50K dataset](http://vision.cs.utexas.edu/projects/finegrained/utzap50k/). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. (2.2GB) |  |
+| `python tools/download-dataset.py edges2handbags` <br>137k Amazon handbag images from the [iGAN project](https://github.com/junyanz/iGAN). Edges are computed by the [HED](https://github.com/s9xie/hed) edge detector + post-processing. (8.6GB) |  |
+
+The `facades` dataset is the smallest and easiest to get started with.
+
+### Creating your own dataset
+
+#### Example: creating images with blank centers for [inpainting](https://people.eecs.berkeley.edu/~pathak/context_encoder/)
+
+
+
+```sh
+# Resize source images
+python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+# Create images with blank centers
+python tools/process.py --input_dir photos/resized --operation blank --output_dir photos/blank
+# Combine resized images with blanked images
+python tools/process.py --input_dir photos/resized --b_dir photos/blank --operation combine --output_dir photos/combined
+# Split into train/val set
+python tools/split.py --dir photos/combined
+```
+
+The folder `photos/combined` will now have `train` and `val` subfolders that you can use for training and testing.
+
+#### Creating image pairs from existing images
+
+If you have two directories `a` and `b` with corresponding images (same name, same dimensions, different data), you can combine them with `process.py`:
+
+```sh
+python tools/process.py --input_dir a --b_dir b --operation combine --output_dir c
+```
+
+This puts the images into the side-by-side combined format that `pix2pix.py` expects.
+
+#### Colorization
+
+For colorization, your images should ideally all be the same aspect ratio. You can resize and crop them with the resize command:
+```sh
+python tools/process.py --input_dir photos/original --operation resize --output_dir photos/resized
+```
+
+No other processing is required; the colorization mode (see the Training section below) uses single images instead of image pairs.
+
+## Training
+
+### Image Pairs
+
+For normal training with image pairs, you need to specify which directory contains the training images and which direction to train on. The direction options are `AtoB` or `BtoA`:
+```sh
+python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+```
+
+### Colorization
+
+`pix2pix.py` includes special code to handle colorization with single images instead of pairs; using it looks like this:
+
+```sh
+python pix2pix.py --mode train --output_dir photos_train --max_epochs 200 --input_dir photos/train --lab_colorization
+```
+
+In this mode, image A is the black and white image (lightness only) and image B contains the color channels of that image (no lightness information).
+
+### Tips
+
+You can look at the loss and computation graph using tensorboard:
+```sh
+tensorboard --logdir=facades_train
+```
+
+
+
+If you wish to write in-progress pictures as the network is training, use `--display_freq 50`. This will update `facades_train/index.html` every 50 steps with the current training inputs and outputs.
+
+## Testing
+
+Testing is done with `--mode test`. You should specify the checkpoint to use with `--checkpoint`; this should point to the `output_dir` that you created previously with `--mode train`:
+
+```sh
+python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+The testing mode will load some of the configuration options from the provided checkpoint, so you do not need to specify `which_direction`, for instance.
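+For reference, here is a minimal sketch of that logic (the helper name `restore_options` is illustrative; in `pix2pix.py` this happens inline in `main()`). Test mode reads a small allowlist of settings back from the `options.json` file that training writes into the checkpoint directory:
+
+```python
+import json
+import os
+
+# settings restored from the training run; everything else comes from the command line
+RESTORED_OPTIONS = {"which_direction", "ngf", "ndf", "lab_colorization"}
+
+def restore_options(checkpoint_dir, args):
+    with open(os.path.join(checkpoint_dir, "options.json")) as f:
+        for key, val in json.load(f).items():
+            if key in RESTORED_OPTIONS:
+                print("loaded", key, "=", val)
+                setattr(args, key, val)  # override the command-line value
+```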
+
+The test run will output an HTML file at `facades_test/index.html` that shows input/output/target image sets:
+
+
+
+## Implementation Validation
+
+Validation of the code was performed on a Linux machine with a ~1.3 TFLOPS Nvidia GTX 750 Ti GPU. Because of limited compute power, validation is not extensive: only the `facades` dataset at 200 epochs was tested.
+
+```sh
+git clone https://github.com/affinelayer/pix2pix-tensorflow.git
+cd pix2pix-tensorflow
+python tools/download-dataset.py facades
+time nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode train --output_dir facades_train --max_epochs 200 --input_dir facades/train --which_direction BtoA
+nvidia-docker run --volume $PWD:/prj --workdir /prj --env PYTHONUNBUFFERED=x affinelayer/tensorflow:pix2pix python pix2pix.py --mode test --output_dir facades_test --input_dir facades/val --checkpoint facades_train
+```
+
+Comparison on the facades dataset:
+
+| Input | Tensorflow | Torch | Target |
+| --- | --- | --- | --- |
+|  |  |  |  |
+|  |  |  |  |
+|  |  |  |  |
+|  |  |  |  |
+
+## Unimplemented Features
+
+The following models have not been implemented:
+- defineG_encoder_decoder
+- defineG_unet_128
+- defineD_pixelGAN
+
+## Citation
+If you use this code for your research, please cite the paper this code is based on, Image-to-Image Translation with Conditional Adversarial Networks:
+
+```
+@article{pix2pix2016,
+  title={Image-to-Image Translation with Conditional Adversarial Networks},
+  author={Isola, Phillip and Zhu, Jun-Yan and Zhou, Tinghui and Efros, Alexei A},
+  journal={arxiv},
+  year={2016}
+}
+```
+
+## Acknowledgments
+This is a port of [pix2pix](https://github.com/phillipi/pix2pix) from Torch to Tensorflow. It also contains colorspace conversion code ported from Torch.
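+As a sanity check of that colorspace code, a round trip through `rgb_to_lab` and `lab_to_rgb` should approximately reproduce its input. A minimal sketch, assuming the two function definitions from `pix2pix.py` have been pasted into scope (importing `pix2pix` directly would run its argument parser):
+
+```python
+import numpy as np
+import tensorflow as tf
+
+with tf.Session() as sess:
+    # random RGB image with values in [0, 1], shape [height, width, channels]
+    rgb = tf.constant(np.random.uniform(size=[8, 8, 3]).astype(np.float32))
+    round_trip = lab_to_rgb(rgb_to_lab(rgb))
+    # the maximum error should be close to zero
+    print("max round-trip error:", sess.run(tf.reduce_max(tf.abs(round_trip - rgb))))
+```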
diff --git a/docs/1-inputs.png b/docs/1-inputs.png new file mode 100644 index 00000000..a12be3c7 Binary files /dev/null and b/docs/1-inputs.png differ diff --git a/docs/1-targets.png b/docs/1-targets.png new file mode 100644 index 00000000..f4548779 Binary files /dev/null and b/docs/1-targets.png differ diff --git a/docs/1-tensorflow.png b/docs/1-tensorflow.png new file mode 100644 index 00000000..262392a0 Binary files /dev/null and b/docs/1-tensorflow.png differ diff --git a/docs/1-torch.jpg b/docs/1-torch.jpg new file mode 100644 index 00000000..2b70d278 Binary files /dev/null and b/docs/1-torch.jpg differ diff --git a/docs/418.png b/docs/418.png new file mode 100644 index 00000000..34bfb3dc Binary files /dev/null and b/docs/418.png differ diff --git a/docs/5-inputs.png b/docs/5-inputs.png new file mode 100644 index 00000000..d58a5196 Binary files /dev/null and b/docs/5-inputs.png differ diff --git a/docs/5-targets.png b/docs/5-targets.png new file mode 100644 index 00000000..066d88db Binary files /dev/null and b/docs/5-targets.png differ diff --git a/docs/5-tensorflow.png b/docs/5-tensorflow.png new file mode 100644 index 00000000..591e1266 Binary files /dev/null and b/docs/5-tensorflow.png differ diff --git a/docs/5-torch.jpg b/docs/5-torch.jpg new file mode 100644 index 00000000..c989387a Binary files /dev/null and b/docs/5-torch.jpg differ diff --git a/docs/51-inputs.png b/docs/51-inputs.png new file mode 100644 index 00000000..1d8a5719 Binary files /dev/null and b/docs/51-inputs.png differ diff --git a/docs/51-targets.png b/docs/51-targets.png new file mode 100644 index 00000000..42012ddd Binary files /dev/null and b/docs/51-targets.png differ diff --git a/docs/51-tensorflow.png b/docs/51-tensorflow.png new file mode 100644 index 00000000..19075ce0 Binary files /dev/null and b/docs/51-tensorflow.png differ diff --git a/docs/51-torch.jpg b/docs/51-torch.jpg new file mode 100644 index 00000000..a4013e00 Binary files /dev/null and b/docs/51-torch.jpg differ diff --git a/docs/95-inputs.png b/docs/95-inputs.png new file mode 100644 index 00000000..6fc2ec26 Binary files /dev/null and b/docs/95-inputs.png differ diff --git a/docs/95-targets.png b/docs/95-targets.png new file mode 100644 index 00000000..f594d737 Binary files /dev/null and b/docs/95-targets.png differ diff --git a/docs/95-tensorflow.png b/docs/95-tensorflow.png new file mode 100644 index 00000000..e4c34d1c Binary files /dev/null and b/docs/95-tensorflow.png differ diff --git a/docs/95-torch.jpg b/docs/95-torch.jpg new file mode 100644 index 00000000..84bed739 Binary files /dev/null and b/docs/95-torch.jpg differ diff --git a/docs/ab.png b/docs/ab.png new file mode 100644 index 00000000..1dadedbd Binary files /dev/null and b/docs/ab.png differ diff --git a/docs/cityscapes.jpg b/docs/cityscapes.jpg new file mode 100755 index 00000000..dfebed73 Binary files /dev/null and b/docs/cityscapes.jpg differ diff --git a/docs/combine.png b/docs/combine.png new file mode 100644 index 00000000..72b35952 Binary files /dev/null and b/docs/combine.png differ diff --git a/docs/edges2handbags.jpg b/docs/edges2handbags.jpg new file mode 100755 index 00000000..4dbcac47 Binary files /dev/null and b/docs/edges2handbags.jpg differ diff --git a/docs/edges2shoes.jpg b/docs/edges2shoes.jpg new file mode 100755 index 00000000..55278d45 Binary files /dev/null and b/docs/edges2shoes.jpg differ diff --git a/docs/examples.jpg b/docs/examples.jpg new file mode 100644 index 00000000..b1f24d5e Binary files /dev/null and b/docs/examples.jpg differ diff --git 
a/docs/facades.jpg b/docs/facades.jpg new file mode 100755 index 00000000..b88704f6 Binary files /dev/null and b/docs/facades.jpg differ diff --git a/docs/maps.jpg b/docs/maps.jpg new file mode 100755 index 00000000..4ecdfec8 Binary files /dev/null and b/docs/maps.jpg differ diff --git a/docs/tensorboard-graph.png b/docs/tensorboard-graph.png new file mode 100644 index 00000000..fce1f62b Binary files /dev/null and b/docs/tensorboard-graph.png differ diff --git a/docs/tensorboard-image.png b/docs/tensorboard-image.png new file mode 100644 index 00000000..8a9581bf Binary files /dev/null and b/docs/tensorboard-image.png differ diff --git a/docs/tensorboard-scalar.png b/docs/tensorboard-scalar.png new file mode 100644 index 00000000..358028c9 Binary files /dev/null and b/docs/tensorboard-scalar.png differ diff --git a/docs/test-html.png b/docs/test-html.png new file mode 100644 index 00000000..aed11d0a Binary files /dev/null and b/docs/test-html.png differ diff --git a/pix2pix.py b/pix2pix.py new file mode 100644 index 00000000..3d7faa92 --- /dev/null +++ b/pix2pix.py @@ -0,0 +1,691 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import tensorflow as tf +import numpy as np +import argparse +import os +import json +import glob +import random +import collections +import math +import time + +parser = argparse.ArgumentParser() +parser.add_argument("--input_dir", required=True, help="path to folder containing images") +parser.add_argument("--mode", required=True, choices=["train", "test"]) +parser.add_argument("--output_dir", required=True, help="where to put output files") +parser.add_argument("--seed", type=int) +parser.add_argument("--checkpoint", default=None, help="directory with checkpoint to resume training from or use for testing") + +parser.add_argument("--max_steps", type=int, help="number of training steps (0 to disable)") +parser.add_argument("--max_epochs", type=int, help="number of training epochs") +parser.add_argument("--summary_freq", type=int, default=10, help="update summaries every summary_freq steps") +parser.add_argument("--progress_freq", type=int, default=50, help="display progress every progress_freq steps") +# to get tracing working on GPU, LD_LIBRARY_PATH may need to be modified: +# LD_LIBRARY_PATH=/usr/local/nvidia/lib:/usr/local/nvidia/lib64:/usr/local/cuda/extras/CUPTI/lib64 +parser.add_argument("--trace_freq", type=int, default=0, help="trace execution every trace_freq steps") +parser.add_argument("--display_freq", type=int, default=0, help="write current training images every display_freq steps") +parser.add_argument("--save_freq", type=int, default=5000, help="save model every save_freq steps, 0 to disable") + +parser.add_argument("--aspect_ratio", type=float, default=1.0, help="aspect ratio of output images (width/height)") +parser.add_argument("--lab_colorization", action="store_true", help="split A image into brightness (A) and color (B), ignore B image") +parser.add_argument("--batch_size", type=int, default=1, help="number of images in batch") +parser.add_argument("--which_direction", type=str, default="AtoB", choices=["AtoB", "BtoA"]) +parser.add_argument("--ngf", type=int, default=64, help="number of generator filters in first conv layer") +parser.add_argument("--ndf", type=int, default=64, help="number of discriminator filters in first conv layer") +parser.add_argument("--scale_size", type=int, default=286, help="scale images to this size before cropping to 256x256") 
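+# NOTE: together with CROP_SIZE = 256 below, the default scale_size of 286
+# gives the random-jitter augmentation from the pix2pix paper: each image is
+# resized to 286x286 and then randomly cropped back to 256x256 (see the
+# transform() function inside load_examples() below).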
+parser.add_argument("--flip", dest="flip", action="store_true", help="flip images horizontally") +parser.add_argument("--no_flip", dest="flip", action="store_false", help="don't flip images horizontally") +parser.set_defaults(flip=True) +parser.add_argument("--lr", type=float, default=0.0002, help="initial learning rate for adam") +parser.add_argument("--beta1", type=float, default=0.5, help="momentum term of adam") +parser.add_argument("--l1_weight", type=float, default=100.0, help="weight on L1 term for generator gradient") +parser.add_argument("--gan_weight", type=float, default=1.0, help="weight on GAN term for generator gradient") +a = parser.parse_args() + +EPS = 1e-12 +CROP_SIZE = 256 + +Examples = collections.namedtuple("Examples", "paths, inputs, targets, count, steps_per_epoch") +Model = collections.namedtuple("Model", "outputs, predict_real, predict_fake, discrim_loss, gen_loss_GAN, gen_loss_L1, train") + + +def conv(batch_input, out_channels, stride): + with tf.variable_scope("conv"): + in_channels = batch_input.get_shape()[3] + filter = tf.get_variable("filter", [4, 4, in_channels, out_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02)) + # [batch, in_height, in_width, in_channels], [filter_width, filter_height, in_channels, out_channels] + # => [batch, out_height, out_width, out_channels] + padded_input = tf.pad(batch_input, [[0, 0], [1, 1], [1, 1], [0, 0]], mode="CONSTANT") + conv = tf.nn.conv2d(padded_input, filter, [1, stride, stride, 1], padding="VALID") + return conv + + +def lrelu(x, a): + with tf.name_scope("lrelu"): + # adding these together creates the leak part and linear part + # then cancels them out by subtracting/adding an absolute value term + # leak: a*x/2 - a*abs(x)/2 + # linear: x/2 + abs(x)/2 + + # this block looks like it has 2 inputs on the graph unless we do this + x = tf.identity(x) + return (0.5 * (1 + a)) * x + (0.5 * (1 - a)) * tf.abs(x) + + +def batchnorm(input): + with tf.variable_scope("batchnorm"): + # this block looks like it has 3 inputs on the graph unless we do this + input = tf.identity(input) + + channels = input.get_shape()[3] + offset = tf.get_variable("offset", [channels], dtype=tf.float32, initializer=tf.zeros_initializer) + scale = tf.get_variable("scale", [channels], dtype=tf.float32, initializer=tf.random_normal_initializer(1.0, 0.02)) + mean, variance = tf.nn.moments(input, axes=[0, 1, 2], keep_dims=False) + variance_epsilon = 1e-5 + normalized = tf.nn.batch_normalization(input, mean, variance, offset, scale, variance_epsilon=variance_epsilon) + return normalized + + +def deconv(batch_input, out_channels): + with tf.variable_scope("deconv"): + batch, in_height, in_width, in_channels = [int(d) for d in batch_input.get_shape()] + filter = tf.get_variable("filter", [4, 4, out_channels, in_channels], dtype=tf.float32, initializer=tf.random_normal_initializer(0, 0.02)) + # [batch, in_height, in_width, in_channels], [filter_width, filter_height, out_channels, in_channels] + # => [batch, out_height, out_width, out_channels] + conv = tf.nn.conv2d_transpose(batch_input, filter, [batch, in_height * 2, in_width * 2, out_channels], [1, 2, 2, 1], padding="SAME") + return conv + + +def check_image(image): + assertion = tf.assert_equal(tf.shape(image)[-1], 3, message="image must have 3 color channels") + with tf.control_dependencies([assertion]): + image = tf.identity(image) + + if image.get_shape().ndims not in (3, 4): + raise ValueError("image must be either 3 or 4 dimensions") + + # make the last dimension 3 so that 
you can unstack the colors + shape = list(image.get_shape()) + shape[-1] = 3 + image.set_shape(shape) + return image + +# based on https://github.com/torch/image/blob/9f65c30167b2048ecbe8b7befdc6b2d6d12baee9/generic/image.c +def rgb_to_lab(srgb): + with tf.name_scope("rgb_to_lab"): + srgb = check_image(srgb) + srgb_pixels = tf.reshape(srgb, [-1, 3]) + + with tf.name_scope("srgb_to_xyz"): + linear_mask = tf.cast(srgb_pixels <= 0.04045, dtype=tf.float32) + exponential_mask = tf.cast(srgb_pixels > 0.04045, dtype=tf.float32) + rgb_pixels = (srgb_pixels / 12.92 * linear_mask) + (((srgb_pixels + 0.055) / 1.055) ** 2.4) * exponential_mask + rgb_to_xyz = tf.constant([ + # X Y Z + [0.412453, 0.212671, 0.019334], # R + [0.357580, 0.715160, 0.119193], # G + [0.180423, 0.072169, 0.950227], # B + ]) + xyz_pixels = tf.matmul(rgb_pixels, rgb_to_xyz) + + # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions + with tf.name_scope("xyz_to_cielab"): + # convert to fx = f(X/Xn), fy = f(Y/Yn), fz = f(Z/Zn) + + # normalize for D65 white point + xyz_normalized_pixels = tf.multiply(xyz_pixels, [1/0.950456, 1.0, 1/1.088754]) + + epsilon = 6/29 + linear_mask = tf.cast(xyz_normalized_pixels <= (epsilon**3), dtype=tf.float32) + exponential_mask = tf.cast(xyz_normalized_pixels > (epsilon**3), dtype=tf.float32) + fxfyfz_pixels = (xyz_normalized_pixels / (3 * epsilon**2) + 4/29) * linear_mask + (xyz_normalized_pixels ** (1/3)) * exponential_mask + + # convert to lab + fxfyfz_to_lab = tf.constant([ + # l a b + [ 0.0, 500.0, 0.0], # fx + [116.0, -500.0, 200.0], # fy + [ 0.0, 0.0, -200.0], # fz + ]) + lab_pixels = tf.matmul(fxfyfz_pixels, fxfyfz_to_lab) + tf.constant([-16.0, 0.0, 0.0]) + + return tf.reshape(lab_pixels, tf.shape(srgb)) + + +def lab_to_rgb(lab): + with tf.name_scope("lab_to_rgb"): + lab = check_image(lab) + lab_pixels = tf.reshape(lab, [-1, 3]) + + # https://en.wikipedia.org/wiki/Lab_color_space#CIELAB-CIEXYZ_conversions + with tf.name_scope("cielab_to_xyz"): + # convert to fxfyfz + lab_to_fxfyfz = tf.constant([ + # fx fy fz + [1/116.0, 1/116.0, 1/116.0], # l + [1/500.0, 0.0, 0.0], # a + [ 0.0, 0.0, -1/200.0], # b + ]) + fxfyfz_pixels = tf.matmul(lab_pixels + tf.constant([16.0, 0.0, 0.0]), lab_to_fxfyfz) + + # convert to xyz + epsilon = 6/29 + linear_mask = tf.cast(fxfyfz_pixels <= epsilon, dtype=tf.float32) + exponential_mask = tf.cast(fxfyfz_pixels > epsilon, dtype=tf.float32) + xyz_pixels = (3 * epsilon**2 * (fxfyfz_pixels - 4/29)) * linear_mask + (fxfyfz_pixels ** 3) * exponential_mask + + # denormalize for D65 white point + xyz_pixels = tf.multiply(xyz_pixels, [0.950456, 1.0, 1.088754]) + + with tf.name_scope("xyz_to_srgb"): + xyz_to_rgb = tf.constant([ + # r g b + [ 3.2404542, -0.9692660, 0.0556434], # x + [-1.5371385, 1.8760108, -0.2040259], # y + [-0.4985314, 0.0415560, 1.0572252], # z + ]) + rgb_pixels = tf.matmul(xyz_pixels, xyz_to_rgb) + # avoid a slightly negative number messing up the conversion + rgb_pixels = tf.clip_by_value(rgb_pixels, 0.0, 1.0) + linear_mask = tf.cast(rgb_pixels <= 0.0031308, dtype=tf.float32) + exponential_mask = tf.cast(rgb_pixels > 0.0031308, dtype=tf.float32) + srgb_pixels = (rgb_pixels * 12.92 * linear_mask) + ((rgb_pixels ** (1/2.4) * 1.055) - 0.055) * exponential_mask + + return tf.reshape(srgb_pixels, tf.shape(lab)) + + +def load_examples(): + input_paths = glob.glob(os.path.join(a.input_dir, "*.jpg")) + decode = tf.image.decode_jpeg + if len(input_paths) == 0: + input_paths = glob.glob(os.path.join(a.input_dir, "*.png")) + decode = 
tf.image.decode_png + + def get_name(path): + name, _ = os.path.splitext(os.path.basename(path)) + return name + + # if the image names are numbers, sort by the value rather than asciibetically + # having sorted inputs means that the outputs are sorted in test mode + if all(get_name(path).isdigit() for path in input_paths): + input_paths = sorted(input_paths, key=lambda path: int(get_name(path))) + else: + input_paths = sorted(input_paths) + + with tf.name_scope("load_images"): + path_queue = tf.train.string_input_producer(input_paths, shuffle=a.mode == "train") + reader = tf.WholeFileReader() + paths, contents = reader.read(path_queue) + raw_input = decode(contents) + raw_input = tf.image.convert_image_dtype(raw_input, dtype=tf.float32) + + assertion = tf.assert_equal(tf.shape(raw_input)[2], 3, message="image does not have 3 channels") + with tf.control_dependencies([assertion]): + raw_input = tf.identity(raw_input) + + raw_input.set_shape([None, None, 3]) + + if a.lab_colorization: + # load color and brightness from image, no B image exists here + lab = rgb_to_lab(raw_input) + L_chan, a_chan, b_chan = tf.unstack(lab, axis=2) + a_images = tf.expand_dims(L_chan, axis=2) / 50 - 1 # black and white with input range [0, 100] + b_images = tf.stack([a_chan, b_chan], axis=2) / 110 # color channels with input range ~[-110, 110], not exact + else: + # break apart image pair and move to range [-1, 1] + width = tf.shape(raw_input)[1] # [height, width, channels] + a_images = raw_input[:,:width//2,:] * 2 - 1 + b_images = raw_input[:,width//2:,:] * 2 - 1 + + if a.which_direction == "AtoB": + inputs, targets = [a_images, b_images] + elif a.which_direction == "BtoA": + inputs, targets = [b_images, a_images] + else: + raise Exception("invalid direction") + + # synchronize seed for image operations so that we do the same operations to both + # input and output images + seed = random.randint(0, 2**31 - 1) + def transform(image): + r = image + if a.flip: + r = tf.image.random_flip_left_right(r, seed=seed) + + # area produces a nice downscaling, but does nearest neighbor for upscaling + # assume we're going to be doing downscaling here + r = tf.image.resize_images(r, [a.scale_size, a.scale_size], method=tf.image.ResizeMethod.AREA) + + offset = tf.cast(tf.floor(tf.random_uniform([2], 0, a.scale_size - CROP_SIZE + 1, seed=seed)), dtype=tf.int32) + if a.scale_size > CROP_SIZE: + r = tf.image.crop_to_bounding_box(r, offset[0], offset[1], CROP_SIZE, CROP_SIZE) + elif a.scale_size < CROP_SIZE: + raise Exception("scale size cannot be less than crop size") + return r + + with tf.name_scope("input_images"): + input_images = transform(inputs) + + with tf.name_scope("target_images"): + target_images = transform(targets) + + paths, inputs, targets = tf.train.batch([paths, input_images, target_images], batch_size=a.batch_size) + steps_per_epoch = int(math.ceil(len(input_paths) / a.batch_size)) + + return Examples( + paths=paths, + inputs=inputs, + targets=targets, + count=len(input_paths), + steps_per_epoch=steps_per_epoch, + ) + + +def create_model(inputs, targets): + def create_generator(generator_inputs, generator_outputs_channels): + layers = [] + + # encoder_1: [batch, 256, 256, in_channels] => [batch, 128, 128, ngf] + with tf.variable_scope("encoder_1"): + output = conv(generator_inputs, a.ngf, stride=2) + layers.append(output) + + layer_specs = [ + a.ngf * 2, # encoder_2: [batch, 128, 128, ngf] => [batch, 64, 64, ngf * 2] + a.ngf * 4, # encoder_3: [batch, 64, 64, ngf * 2] => [batch, 32, 32, ngf * 4] + a.ngf * 8, # 
encoder_4: [batch, 32, 32, ngf * 4] => [batch, 16, 16, ngf * 8] + a.ngf * 8, # encoder_5: [batch, 16, 16, ngf * 8] => [batch, 8, 8, ngf * 8] + a.ngf * 8, # encoder_6: [batch, 8, 8, ngf * 8] => [batch, 4, 4, ngf * 8] + a.ngf * 8, # encoder_7: [batch, 4, 4, ngf * 8] => [batch, 2, 2, ngf * 8] + a.ngf * 8, # encoder_8: [batch, 2, 2, ngf * 8] => [batch, 1, 1, ngf * 8] + ] + + for out_channels in layer_specs: + with tf.variable_scope("encoder_%d" % (len(layers) + 1)): + rectified = lrelu(layers[-1], 0.2) + # [batch, in_height, in_width, in_channels] => [batch, in_height/2, in_width/2, out_channels] + convolved = conv(rectified, out_channels, stride=2) + output = batchnorm(convolved) + layers.append(output) + + layer_specs = [ + (a.ngf * 8, 0.5), # decoder_8: [batch, 1, 1, ngf * 8] => [batch, 2, 2, ngf * 8 * 2] + (a.ngf * 8, 0.5), # decoder_7: [batch, 2, 2, ngf * 8 * 2] => [batch, 4, 4, ngf * 8 * 2] + (a.ngf * 8, 0.5), # decoder_6: [batch, 4, 4, ngf * 8 * 2] => [batch, 8, 8, ngf * 8 * 2] + (a.ngf * 8, 0.0), # decoder_5: [batch, 8, 8, ngf * 8 * 2] => [batch, 16, 16, ngf * 8 * 2] + (a.ngf * 4, 0.0), # decoder_4: [batch, 16, 16, ngf * 8 * 2] => [batch, 32, 32, ngf * 4 * 2] + (a.ngf * 2, 0.0), # decoder_3: [batch, 32, 32, ngf * 4 * 2] => [batch, 64, 64, ngf * 2 * 2] + (a.ngf, 0.0), # decoder_2: [batch, 64, 64, ngf * 2 * 2] => [batch, 128, 128, ngf * 2] + ] + + num_encoder_layers = len(layers) + for decoder_layer, (out_channels, dropout) in enumerate(layer_specs): + skip_layer = num_encoder_layers - decoder_layer - 1 + with tf.variable_scope("decoder_%d" % (skip_layer + 1)): + if decoder_layer == 0: + # first decoder layer doesn't have skip connections + # since it is directly connected to the skip_layer + input = layers[-1] + else: + input = tf.concat_v2([layers[-1], layers[skip_layer]], axis=3) + + rectified = tf.nn.relu(input) + # [batch, in_height, in_width, in_channels] => [batch, in_height*2, in_width*2, out_channels] + output = deconv(rectified, out_channels) + output = batchnorm(output) + + if dropout > 0.0: + output = tf.nn.dropout(output, keep_prob=1 - dropout) + + layers.append(output) + + # decoder_1: [batch, 128, 128, ngf * 2] => [batch, 256, 256, generator_outputs_channels] + with tf.variable_scope("decoder_1"): + input = tf.concat_v2([layers[-1], layers[0]], axis=3) + rectified = tf.nn.relu(input) + output = deconv(rectified, generator_outputs_channels) + output = tf.tanh(output) + layers.append(output) + + return layers[-1] + + def create_discriminator(discrim_inputs, discrim_targets): + n_layers = 3 + layers = [] + + # 2x [batch, height, width, in_channels] => [batch, height, width, in_channels * 2] + input = tf.concat_v2([discrim_inputs, discrim_targets], axis=3) + + # layer_1: [batch, 256, 256, in_channels * 2] => [batch * 2, 128, 128, ndf] + with tf.variable_scope("layer_1"): + convolved = conv(input, a.ndf, stride=2) + rectified = lrelu(convolved, 0.2) + layers.append(rectified) + + # layer_2: [batch * 2, 128, 128, ndf] => [batch * 2, 64, 64, ndf * 2] + # layer_3: [batch * 2, 64, 64, ndf * 2] => [batch * 2, 32, 32, ndf * 4] + # layer_4: [batch * 2, 32, 32, ndf * 4] => [batch * 2, 31, 31, ndf * 8] + for i in range(n_layers): + with tf.variable_scope("layer_%d" % (len(layers) + 1)): + out_channels = a.ndf * min(2**(i+1), 8) + stride = 1 if i == n_layers - 1 else 2 # last layer here has stride 1 + convolved = conv(layers[-1], out_channels, stride=stride) + normalized = batchnorm(convolved) + rectified = lrelu(normalized, 0.2) + layers.append(rectified) + + # layer_5: [batch * 2, 31, 
31, ndf * 8] => [batch * 2, 30, 30, 1]
+        with tf.variable_scope("layer_%d" % (len(layers) + 1)):
+            convolved = conv(rectified, out_channels=1, stride=1)
+            output = tf.sigmoid(convolved)
+            layers.append(output)
+
+        return layers[-1]
+
+    with tf.variable_scope("generator") as scope:
+        out_channels = int(targets.get_shape()[-1])
+        outputs = create_generator(inputs, out_channels)
+
+    # create two copies of discriminator, one for real pairs and one for fake pairs
+    # they share the same underlying variables
+    with tf.name_scope("real_discriminator"):
+        with tf.variable_scope("discriminator"):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+            predict_real = create_discriminator(inputs, targets)
+
+    with tf.name_scope("fake_discriminator"):
+        with tf.variable_scope("discriminator", reuse=True):
+            # 2x [batch, height, width, channels] => [batch, 30, 30, 1]
+            predict_fake = create_discriminator(inputs, outputs)
+
+    with tf.name_scope("discriminator_loss"):
+        # minimizing -tf.log will try to get inputs to 1
+        # predict_real => 1
+        # predict_fake => 0
+        discrim_loss = tf.reduce_mean(-(tf.log(predict_real + EPS) + tf.log(1 - predict_fake + EPS)))
+
+    with tf.name_scope("generator_loss"):
+        # predict_fake => 1
+        # abs(targets - outputs) => 0
+        gen_loss_GAN = tf.reduce_mean(-tf.log(predict_fake + EPS))
+        gen_loss_L1 = tf.reduce_mean(tf.abs(targets - outputs))
+        gen_loss = gen_loss_GAN * a.gan_weight + gen_loss_L1 * a.l1_weight
+
+    with tf.name_scope("discriminator_train"):
+        discrim_tvars = [var for var in tf.trainable_variables() if var.name.startswith("discriminator")]
+        discrim_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+        discrim_train = discrim_optim.minimize(discrim_loss, var_list=discrim_tvars)
+
+    with tf.name_scope("generator_train"):
+        with tf.control_dependencies([discrim_train]):
+            gen_tvars = [var for var in tf.trainable_variables() if var.name.startswith("generator")]
+            gen_optim = tf.train.AdamOptimizer(a.lr, a.beta1)
+            gen_train = gen_optim.minimize(gen_loss, var_list=gen_tvars)
+
+    ema = tf.train.ExponentialMovingAverage(decay=0.99)
+    update_losses = ema.apply([discrim_loss, gen_loss_GAN, gen_loss_L1])
+
+    global_step = tf.contrib.framework.get_or_create_global_step()
+    incr_global_step = tf.assign(global_step, global_step+1)
+
+    return Model(
+        predict_real=predict_real,
+        predict_fake=predict_fake,
+        discrim_loss=ema.average(discrim_loss),
+        gen_loss_GAN=ema.average(gen_loss_GAN),
+        gen_loss_L1=ema.average(gen_loss_L1),
+        outputs=outputs,
+        train=tf.group(update_losses, incr_global_step, gen_train),
+    )
+
+
+def save_images(fetches, image_dir, step=None):
+    filesets = []
+    for i, in_path in enumerate(fetches["paths"]):
+        name, _ = os.path.splitext(os.path.basename(in_path))
+        fileset = {"name": name, "step": step}
+        for kind in ["inputs", "outputs", "targets"]:
+            filename = name + "-" + kind + ".png"
+            if step is not None:
+                filename = "%08d-%s" % (step, filename)
+            fileset[kind] = filename
+            out_path = os.path.join(image_dir, filename)
+            contents = fetches[kind][i]
+            with open(out_path, "w") as f:
+                f.write(contents)
+        filesets.append(fileset)
+    return filesets
+
+
+def append_index(filesets, step=False):
+    index_path = os.path.join(a.output_dir, "index.html")
+    if os.path.exists(index_path):
+        index = open(index_path, "a")
+    else:
+        index = open(index_path, "w")
+        index.write("<html><body><table><tr>")
+        if step:
+            index.write("<th>step</th>")
+        index.write("<th>name</th><th>input</th><th>output</th><th>target</th></tr>")
+
+    for fileset in filesets:
+        index.write("<tr>")
+
+        if step:
+            index.write("<td>%d</td>" % fileset["step"])
+        index.write("<td>%s</td>" % fileset["name"])
+
+        for kind in ["inputs", "outputs", "targets"]:
+            index.write("<td><img src='images/%s'></td>" % fileset[kind])
+
+        index.write("</tr>")
+    return index_path
+
+
+def main():
+    if a.seed is None:
+        a.seed = random.randint(0, 2**31 - 1)
+
+    tf.set_random_seed(a.seed)
+    np.random.seed(a.seed)
+    random.seed(a.seed)
+
+    if not os.path.exists(a.output_dir):
+        os.makedirs(a.output_dir)
+
+    if a.mode == "test":
+        if a.checkpoint is None:
+            raise Exception("checkpoint required for test mode")
+
+        # load some options from the checkpoint
+        options = {"which_direction", "ngf", "ndf", "lab_colorization"}
+        with open(os.path.join(a.checkpoint, "options.json")) as f:
+            for key, val in json.loads(f.read()).iteritems():
+                if key in options:
+                    print("loaded", key, "=", val)
+                    setattr(a, key, val)
+        # disable these features in test mode
+        a.scale_size = CROP_SIZE
+        a.flip = False
+
+    for k, v in a._get_kwargs():
+        print(k, "=", v)
+
+    with open(os.path.join(a.output_dir, "options.json"), "w") as f:
+        f.write(json.dumps(vars(a), sort_keys=True, indent=4))
+
+    examples = load_examples()
+
+    print("examples count = %d" % examples.count)
+
+    model = create_model(examples.inputs, examples.targets)
+
+    def deprocess(image):
+        if a.aspect_ratio != 1.0:
+            # upscale to correct aspect ratio
+            size = [CROP_SIZE, int(round(CROP_SIZE * a.aspect_ratio))]
+            image = tf.image.resize_images(image, size=size, method=tf.image.ResizeMethod.BICUBIC)
+
+        if a.lab_colorization:
+            # colorization mode images can be 1 channel (L) or 2 channels (a,b)
+            num_channels = int(image.get_shape()[-1])
+            if num_channels == 1:
+                return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
+            elif num_channels == 2:
+                # (a, b) color channels, convert to rgb
+                # a_chan and b_chan have range [-1, 1] => [-110, 110]
+                a_chan, b_chan = tf.unstack(image * 110, axis=3)
+                # get L_chan from inputs or targets
+                if a.which_direction == "AtoB":
+                    brightness = examples.inputs
+                elif a.which_direction == "BtoA":
+                    brightness = examples.targets
+                else:
+                    raise Exception("invalid direction")
+                # L_chan has range [-1, 1] => [0, 100]
+                L_chan = tf.squeeze((brightness + 1) / 2 * 100, axis=3)
+                lab = tf.stack([L_chan, a_chan, b_chan], axis=3)
+                rgb = lab_to_rgb(lab)
+                return tf.image.convert_image_dtype(rgb, dtype=tf.uint8, saturate=True)
+            else:
+                raise Exception("unexpected number of channels")
+        else:
+            return tf.image.convert_image_dtype((image + 1) / 2, dtype=tf.uint8, saturate=True)
+
+    # reverse any processing on images so they can be written to disk or displayed to user
+    with tf.name_scope("deprocess_inputs"):
+        deprocessed_inputs = deprocess(examples.inputs)
+
+    with tf.name_scope("deprocess_targets"):
+        deprocessed_targets = deprocess(examples.targets)
+
+    with tf.name_scope("deprocess_outputs"):
+        deprocessed_outputs = deprocess(model.outputs)
+
+    with tf.name_scope("encode_images"):
+        display_fetches = {
+            "paths": examples.paths,
+            "inputs": tf.map_fn(tf.image.encode_png, deprocessed_inputs, dtype=tf.string, name="input_pngs"),
+            "targets": tf.map_fn(tf.image.encode_png, deprocessed_targets, dtype=tf.string, name="target_pngs"),
+            "outputs": tf.map_fn(tf.image.encode_png, deprocessed_outputs, dtype=tf.string, name="output_pngs"),
+        }
+
+    # summaries
+    with tf.name_scope("inputs_summary"):
+        tf.summary.image("inputs", deprocessed_inputs)
+
+    with tf.name_scope("targets_summary"):
+        tf.summary.image("targets", deprocessed_targets)
+
+    with tf.name_scope("outputs_summary"):
+        tf.summary.image("outputs", deprocessed_outputs)
+
+    with tf.name_scope("predict_real_summary"):
tf.summary.image("predict_real", tf.image.convert_image_dtype(model.predict_real, dtype=tf.uint8)) + + with tf.name_scope("predict_fake_summary"): + tf.summary.image("predict_fake", tf.image.convert_image_dtype(model.predict_fake, dtype=tf.uint8)) + + tf.summary.scalar("discriminator_loss", model.discrim_loss) + tf.summary.scalar("generator_loss_GAN", model.gen_loss_GAN) + tf.summary.scalar("generator_loss_L1", model.gen_loss_L1) + + with tf.name_scope("parameter_count"): + parameter_count = tf.reduce_sum([tf.reduce_prod(tf.shape(v)) for v in tf.trainable_variables()]) + + image_dir = os.path.join(a.output_dir, "images") + if not os.path.exists(image_dir): + os.makedirs(image_dir) + + saver = tf.train.Saver(max_to_keep=1) + + logdir = a.output_dir if (a.trace_freq > 0 or a.summary_freq > 0) else None + sv = tf.train.Supervisor(logdir=logdir, save_summaries_secs=0, saver=None) + with sv.managed_session() as sess: + print("parameter_count =", sess.run(parameter_count)) + + if a.checkpoint is not None: + print("loading model from checkpoint") + checkpoint = tf.train.latest_checkpoint(a.checkpoint) + saver.restore(sess, checkpoint) + + if a.mode == "test": + # testing + # run a single epoch over all input data + for step in range(examples.steps_per_epoch): + results = sess.run(display_fetches) + filesets = save_images(results, image_dir) + for i, path in enumerate(results["paths"]): + print(step * a.batch_size + i + 1, "evaluated image", os.path.basename(path)) + index_path = append_index(filesets) + + print("wrote index at", index_path) + else: + # training + max_steps = 2**32 + if a.max_epochs is not None: + max_steps = examples.steps_per_epoch * a.max_epochs + if a.max_steps is not None: + max_steps = a.max_steps + + start_time = time.time() + for step in range(max_steps): + def should(freq): + return freq > 0 and ((step + 1) % freq == 0 or step == max_steps - 1) + + options = None + run_metadata = None + if should(a.trace_freq): + options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) + run_metadata = tf.RunMetadata() + + fetches = { + "train": model.train, + "global_step": sv.global_step, + } + + if should(a.progress_freq): + fetches["discrim_loss"] = model.discrim_loss + fetches["gen_loss_GAN"] = model.gen_loss_GAN + fetches["gen_loss_L1"] = model.gen_loss_L1 + + if should(a.summary_freq): + fetches["summary"] = sv.summary_op + + if should(a.display_freq): + fetches["display"] = display_fetches + + results = sess.run(fetches, options=options, run_metadata=run_metadata) + + if should(a.summary_freq): + sv.summary_writer.add_summary(results["summary"], results["global_step"]) + + if should(a.display_freq): + print("saving display images") + filesets = save_images(results["display"], image_dir, step=results["global_step"]) + append_index(filesets, step=True) + + if should(a.trace_freq): + print("recording trace") + sv.summary_writer.add_run_metadata(run_metadata, "step_%d" % results["global_step"]) + + if should(a.progress_freq): + global_step = results["global_step"] + print("progress epoch %d step %d image/sec %0.1f" % (global_step // examples.steps_per_epoch, global_step % examples.steps_per_epoch, global_step * a.batch_size / (time.time() - start_time))) + print("discrim_loss", results["discrim_loss"]) + print("gen_loss_GAN", results["gen_loss_GAN"]) + print("gen_loss_L1", results["gen_loss_L1"]) + + if should(a.save_freq): + print("saving model") + saver.save(sess, os.path.join(a.output_dir, "model"), global_step=sv.global_step) + + if sv.should_stop(): + break + + +main() diff 
--git a/tools/download-dataset.py b/tools/download-dataset.py new file mode 100644 index 00000000..6e90c5cf --- /dev/null +++ b/tools/download-dataset.py @@ -0,0 +1,21 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import urllib2 +import sys +import tarfile +import tempfile +import shutil + +dataset = sys.argv[1] +url = "https://people.eecs.berkeley.edu/~tinghuiz/projects/pix2pix/datasets/%s.tar.gz" % dataset +with tempfile.TemporaryFile() as tmp: + print("downloading", url) + shutil.copyfileobj(urllib2.urlopen(url), tmp) + print("extracting") + tmp.seek(0) + tar = tarfile.open(fileobj=tmp) + tar.extractall() + tar.close() + print("done") diff --git a/tools/process.py b/tools/process.py new file mode 100644 index 00000000..a111e641 --- /dev/null +++ b/tools/process.py @@ -0,0 +1,246 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import argparse +import os +import random +import tensorflow as tf +import numpy as np + + +parser = argparse.ArgumentParser() +parser.add_argument("--input_dir", required=True, help="path to folder containing images") +parser.add_argument("--output_dir", required=True, help="output path") +parser.add_argument("--operation", required=True, choices=["grayscale", "resize", "blank", "combine"]) +parser.add_argument("--pad", action="store_true", help="pad instead of crop for resize operation") +parser.add_argument("--size", type=int, default=256, help="size to use for resize operation") +parser.add_argument("--b_dir", type=str, help="path to folder containing B images for combine operation") +a = parser.parse_args() + + +def grayscale(img): + img = img / 255 + img = 0.299 * img[:,:,0] + 0.587 * img[:,:,1] + 0.114 * img[:,:,2] + return (np.expand_dims(img, axis=2) * 255).astype(np.uint8) + + +def normalize(img): + img -= img.min() + img /= img.max() + return img + + +def create_op(func, **placeholders): + op = func(**placeholders) + + def f(**kwargs): + feed_dict = {} + for argname, argvalue in kwargs.iteritems(): + placeholder = placeholders[argname] + feed_dict[placeholder] = argvalue + return op.eval(feed_dict=feed_dict) + + return f + +downscale = create_op( + func=tf.image.resize_images, + images=tf.placeholder(tf.float32, [None, None, None]), + size=tf.placeholder(tf.int32, [2]), + method=tf.image.ResizeMethod.AREA, +) + +upscale = create_op( + func=tf.image.resize_images, + images=tf.placeholder(tf.float32, [None, None, None]), + size=tf.placeholder(tf.int32, [2]), + method=tf.image.ResizeMethod.BICUBIC, +) + +decode_jpeg = create_op( + func=tf.image.decode_jpeg, + contents=tf.placeholder(tf.string), +) + +decode_png = create_op( + func=tf.image.decode_png, + contents=tf.placeholder(tf.string), +) + +rgb_to_grayscale = create_op( + func=tf.image.rgb_to_grayscale, + images=tf.placeholder(tf.float32), +) + +grayscale_to_rgb = create_op( + func=tf.image.grayscale_to_rgb, + images=tf.placeholder(tf.float32), +) + +encode_jpeg = create_op( + func=tf.image.encode_jpeg, + image=tf.placeholder(tf.uint8), +) + +encode_png = create_op( + func=tf.image.encode_png, + image=tf.placeholder(tf.uint8), +) + +crop = create_op( + func=tf.image.crop_to_bounding_box, + image=tf.placeholder(tf.float32), + offset_height=tf.placeholder(tf.int32, []), + offset_width=tf.placeholder(tf.int32, []), + target_height=tf.placeholder(tf.int32, []), + target_width=tf.placeholder(tf.int32, []), +) + +pad = create_op( + func=tf.image.pad_to_bounding_box, + 
image=tf.placeholder(tf.float32), + offset_height=tf.placeholder(tf.int32, []), + offset_width=tf.placeholder(tf.int32, []), + target_height=tf.placeholder(tf.int32, []), + target_width=tf.placeholder(tf.int32, []), +) + +to_uint8 = create_op( + func=tf.image.convert_image_dtype, + image=tf.placeholder(tf.float32), + dtype=tf.uint8, + saturate=True, +) + +to_float32 = create_op( + func=tf.image.convert_image_dtype, + image=tf.placeholder(tf.uint8), + dtype=tf.float32, +) + + +def load(path): + contents = open(path).read() + _, ext = os.path.splitext(path.lower()) + + if ext == ".jpg": + image = decode_jpeg(contents=contents) + elif ext == ".png": + image = decode_png(contents=contents) + else: + raise Exception("invalid image suffix") + + return to_float32(image=image) + + +def find(d): + result = [] + for filename in os.listdir(d): + _, ext = os.path.splitext(filename.lower()) + if ext == ".jpg" or ext == ".png": + result.append(os.path.join(d, filename)) + result.sort() + return result + + +def save(image, path): + _, ext = os.path.splitext(path.lower()) + image = to_uint8(image=image) + if ext == ".jpg": + encoded = encode_jpeg(image=image) + elif ext == ".png": + encoded = encode_png(image=image) + else: + raise Exception("invalid image suffix") + + if os.path.exists(path): + raise Exception("file already exists at " + path) + + with open(path, "w") as f: + f.write(encoded) + + +def png_path(path): + basename, _ = os.path.splitext(os.path.basename(path)) + return os.path.join(os.path.dirname(path), basename + ".png") + + +def main(): + random.seed(0) + + if not os.path.exists(a.output_dir): + os.makedirs(a.output_dir) + + with tf.Session() as sess: + for src_path in find(a.input_dir): + dst_path = png_path(os.path.join(a.output_dir, os.path.basename(src_path))) + print(src_path, "->", dst_path) + src = load(src_path) + + if a.operation == "grayscale": + dst = grayscale_to_rgb(images=rgb_to_grayscale(images=src)) + elif a.operation == "resize": + height, width, _ = src.shape + dst = src + if height != width: + if a.pad: + size = max(height, width) + # pad to correct ratio + oh = (size - height) // 2 + ow = (size - width) // 2 + dst = pad(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) + else: + # crop to correct ratio + size = min(height, width) + oh = (height - size) // 2 + ow = (width - size) // 2 + dst = crop(image=dst, offset_height=oh, offset_width=ow, target_height=size, target_width=size) + + assert(dst.shape[0] == dst.shape[1]) + + size, _, _ = dst.shape + if size > a.size: + dst = downscale(images=dst, size=[a.size, a.size]) + elif size < a.size: + dst = upscale(images=dst, size=[a.size, a.size]) + elif a.operation == "blank": + height, width, _ = src.shape + if height != width: + raise Exception("non-square image") + + image_size = width + size = int(image_size * 0.3) + offset = int(image_size / 2 - size / 2) + + dst = src + dst[offset:offset + size,offset:offset + size,:] = np.ones([size, size, 3]) + elif a.operation == "combine": + if a.b_dir is None: + raise Exception("missing b_dir") + + # find corresponding file in b_dir, could have a different extension + basename, _ = os.path.splitext(os.path.basename(src_path)) + for ext in [".png", ".jpg"]: + sibling_path = os.path.join(a.b_dir, basename + ext) + if os.path.exists(sibling_path): + sibling = load(sibling_path) + break + else: + raise Exception("could not find sibling image for " + src_path) + + # make sure that dimensions are correct + height, width, _ = src.shape + if height != 
sibling.shape[0] or width != sibling.shape[1]: + raise Exception("differing sizes") + + # remove alpha channel + src = src[:,:,:3] + sibling = sibling[:,:,:3] + dst = np.concatenate([src, sibling], axis=1) + else: + raise Exception("invalid operation") + + save(dst, dst_path) + + +main() diff --git a/tools/split.py b/tools/split.py new file mode 100644 index 00000000..ef93f2cc --- /dev/null +++ b/tools/split.py @@ -0,0 +1,40 @@ +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import random +import argparse +import glob +import os + + +parser = argparse.ArgumentParser() +parser.add_argument("--dir", type=str, required=True, help="path to folder containing images") +parser.add_argument("--train_frac", type=float, default=0.8, help="percentage of images to use for training set") +parser.add_argument("--test_frac", type=float, default=0.0, help="percentage of images to use for test set") +a = parser.parse_args() + + +def main(): + random.seed(0) + + files = glob.glob(os.path.join(a.dir, "*.png")) + assignments = [] + assignments.extend(["train"] * int(a.train_frac * len(files))) + assignments.extend(["test"] * int(a.test_frac * len(files))) + assignments.extend(["val"] * int(len(files) - len(assignments))) + random.shuffle(assignments) + + for name in ["train", "val", "test"]: + if name in assignments: + d = os.path.join(a.dir, name) + if not os.path.exists(d): + os.makedirs(d) + + print(len(files), len(assignments)) + for inpath, assignment in zip(files, assignments): + outpath = os.path.join(a.dir, assignment, os.path.basename(inpath)) + print(inpath, "->", outpath) + os.rename(inpath, outpath) + +main()
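+# Worked example of the assignment logic above: with 1000 input images and the
+# defaults (train_frac=0.8, test_frac=0.0), assignments gets int(0.8 * 1000) = 800
+# "train" entries, 0 "test" entries, and the remaining 200 "val" entries, and is
+# then shuffled before the files are moved into subfolders.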