Skip to content

Commit

Permalink
init
Browse files Browse the repository at this point in the history
  • Loading branch information
PatWie committed Jul 18, 2017
0 parents commit e3172b6
Show file tree
Hide file tree
Showing 23 changed files with 2,728 additions and 0 deletions.
136 changes: 136 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
*.egg-info/
.installed.cfg
*.egg

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
target/

# Jupyter Notebook
.ipynb_checkpoints

# pyenv
.python-version

# celery beat schedule file
celerybeat-schedule

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/

# Prerequisites
*.d

# Compiled Object files
*.slo
*.lo
*.o
*.obj

# Precompiled Headers
*.gch
*.pch

# Compiled Dynamic libraries
*.so
*.dylib
*.dll

# Fortran module files
*.mod
*.smod

# Compiled Static libraries
*.lai
*.la
*.a
*.lib

# Executables
*.exe
*.out
*.app

local/
train_log*/
results/
81 changes: 81 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
# Learning Blind Motion Deblurring

TensorFlow implementation of multi-frame blind deconvolution:

**Learning Blind Motion Deblurring**<br>
Patrick Wieschollek, Michael Hirsch, Bernhard Schölkopf, Hendrik P.A. Lensch<br>
*ICCV 2017*

![results](https://user-images.githubusercontent.com/6756603/28306964-93f64ce2-6ba1-11e7-8cdc-4f112d9d6059.jpg)


## Prerequisites
### 1. Get YouTube videos

The first step is to gather videos from some arbitrary sources. We use YouTube to get some videos with diverse content and recording equipment. To download these videos, we use the python-tool `youtube-dl`.

```bash
pip install youtube-dl --user
```

Some examples are given in `download_videos.sh`. Note, you can use whatever mp4 video you want to use for this task. In fact, for this re-implementation we use some other videos, which also work well.

### 2. Generate Synthetic Motion Blur

Now, we use optical flow to synthetically add motion blur. We used the most simple OpticalFlow method, wich provides reasonable results (we average frames anyway):

```bash
cd synthblur
mkdir build && cd build
cmake ..
make all
```

To convert a video `input.mp4` into a blurry version, run

```bash
./synthblur/build/convert "input.mp4"
```

This gives you multiple outputs:
- 'input.mp4_blurry.mp4'
- 'input.mp4_sharp.mp4'
- 'input.mp4_flow.mp4'

Adding blur from synthetic camera shake is done on-the-fly (see `psf.py`).

### 3. Building a Database
For performance reasons we randomly sample frames from all videos beforehand and store 5+5 consecutive frames (sharp+blurry) into an LMDB file (for training/validation/testing).

I use

```bash
#!/bin/bash
for i in `seq 1 30`; do
python data_sampler.py --pattern '/graphics/scratch/wieschol/YouTubeDataset/train/*_blurry.mp4' --lmdb /graphics/scratch/wieschol/YouTubeDataset/train$i.lmdb --num 5000
done

for i in `seq 1 10`; do
python data_sampler.py --pattern '/graphics/scratch/wieschol/YouTubeDataset/val/*_blurry.mp4' --lmdb /graphics/scratch/wieschol/YouTubeDataset/val$i.lmdb --num 5000
done

```

To visualize the training examples just run

```bash
python data_provider.py --lmdb /graphics/scratch/wieschol/YouTubeDataset/train1.lmdb --show --num 5000
```


## Training

This re-imlementation uses [TensorPack](https://github.com/ppwwyyxx/tensorpack) instead of a custom library. Starting training is done by

```bash
python learning_blind_motion_deblurring_singlescale.py --gpu 0,1 --data path/to/lmdb-files/
```


## Further experiments
We further tried a convLSTM/convGRU and a multi-scale approach (instead of the simple test from the paper). These script are available in `additional_scripts`.
165 changes: 165 additions & 0 deletions additional_scripts/convrnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,165 @@
from abc import ABCMeta, abstractmethod, abstractproperty
import tensorflow as tf
from tensorpack import *
from tensorpack.tfutils.scope_utils import auto_reuse_variable_scope

"""
References:
http://colah.github.io/posts/2015-08-Understanding-LSTMs/
http://www.wildml.com/2015/10/recurrent-neural-network-tutorial-part-4-implementing-a-grulstm-rnn-with-python-and-theano/
"""


class ConvRNNCell(object):
__metaclass__ = ABCMeta

def __init__(self, tensor_shape, out_channel, kernel_shape, nl=tf.nn.tanh, normalize_fn=None):
"""Abstract representation for 2D recurrent cells.
Args:
tensor_shape: shape of inputs (must be fully specified)
out_channel: number of output channels
kernel_shape: size of filters
nl (TYPE, optional): non-linearity (default: tf.nn.tanh)
normalize_fn (None, optional): normalization steps (e.g. tf.contrib.layers.layer_norm)
"""
super(ConvRNNCell, self).__init__()
self.state_tensor = None

assert len(tensor_shape), "tensor_shape should have 4 dims [BHWC]"

self.input_shape = tensor_shape
self.out_channel = out_channel
self.kernel_shape = kernel_shape

self.nl = nl
self.normalize_fn = normalize_fn

@abstractproperty
def default_state(self):
pass

def state(self):
if self.state_tensor is None:
self.state_tensor = self.default_state()
return self.state_tensor

@abstractmethod
def _calc(self, tensor):
pass

def __call__(self, tensor):
return self._calc(tensor)


class ConvLSTMCell(ConvRNNCell):
"""Represent LSTM-layer using convolutions.
conv_gates:
i = sigma(x*U1 + s*W1) input gate
f = sigma(x*U2 + s*W2) forget gate
o = sigma(x*U3 + s*W3) output gate
g = tanh(x*U4 + s*W4) candidate hidden state
memory update:
c = c * f + g * i internal memory
s = tanh(c) * o output hiden state
"""
def default_state(self):
b, h, w, c = self.input_shape
return (tf.zeros([b, h, w, self.out_channel]), tf.zeros([b, h, w, self.out_channel]))

@auto_reuse_variable_scope
def __call__(self, x):
c, s = self.state()

xs = tf.concat(axis=3, values=[x, s])
igfo = Conv2D('conv_gates', xs, 4 * self.out_channel, self.kernel_shape,
nl=tf.identity, use_bias=(self.normalize_fn is None))
# i = input_gate, g = hidden state, f = forget_gate, o = output_gate
i, g, f, o = tf.split(axis=3, num_or_size_splits=4, value=igfo)

if self.normalize_fn is not None:
i, g = self.normalize_fn(i), self.normalize_fn(g)
f, o = self.normalize_fn(f), self.normalize_fn(o)

i, g = tf.nn.sigmoid(i), self.nl(g)
f, o = tf.nn.sigmoid(f), tf.nn.sigmoid(o)

# memory update
c = c * f + g * i
if self.normalize_fn is not None:
c = self.normalize_fn(c)

# output
s = self.nl(c) * tf.nn.sigmoid(o)
self.state_tensor = (c, s)

return s


class ConvGRUCell(ConvRNNCell):
"""Represent GRU-layer using convolutions.
z = sigma(x*U1 + s*W1) update gate
r = sigma(x*U2 + s*W2) reset gate
h = tanh(x*U3 + (s*r)*W3)
s = (1-z)*h + z*s
"""
def default_state(self):
"""GRU just uses the output as the state for the next computation.
"""
b, h, w, c = self.input_shape
return tf.zeros([b, h, w, self.out_channel])

@auto_reuse_variable_scope
def _calc(self, x):
s = self.state()

# we concat x and s to reduce the number of conv-calls
xs = tf.concat(axis=3, values=[x, s])
zr = Conv2D('conv_zr', xs, 2 * self.out_channel, self.kernel_shape,
nl=tf.identity, use_bias=(self.normalize_fn is None))

# z (update gate), r (reset gate)
z, r = tf.split(axis=3, num_or_size_splits=2, value=zr)

if self.normalize_fn is not None:
r, z = self.normalize_fn(r), self.normalize_fn(z)

r, z = tf.sigmoid(r), tf.sigmoid(z)

h = tf.concat(axis=3, values=[x, s * r])
h = Conv2D('conv_h', h, self.out_channel, self.kernel_shape,
nl=tf.identity, use_bias=(self.normalize_fn is None))

if self.normalize_fn is not None:
h = self.normalize_fn(h)

h = self.nl(h)
s = (1 - z) * h + z * s

self.state_tensor = s

return s


@layer_register()
def ConvRNN(x, cell):
assert len(x.get_shape().as_list()) == 4, "input in ConvRNN should be B,H,W,C"
return cell(x)


@layer_register()
def ConvRNN_unroll(x, cell):
assert len(x.get_shape().as_list()) == 5, "input in ConvRNN should be B,T,H,W,C"
time_dim = x.get_shape().as_list()[1]

outputs = []
for t in range(time_dim):
outputs.append(cell(x[:, t, :, :, :]))

return tf.stack(outputs, axis=1)
Loading

0 comments on commit e3172b6

Please sign in to comment.