
Commit

Merge pull request #5 from zhmeishi/train
Train
fhshi authored Oct 4, 2021
2 parents df5ec60 + 2c77fa0 commit 67e6d01
Showing 8 changed files with 385 additions and 21 deletions.
20 changes: 17 additions & 3 deletions README.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@



# Deep Online Fused Video Stabilization

@@ -18,7 +19,6 @@ pip install -r requirements.txt --ignore-installed
## Data Preparation
Download sample video [here](https://drive.google.com/file/d/1nju9H8ohYZh6dGsdrQjQXFgfgkrFtkRi/view?usp=sharing).
Uncompress the archive and place the *video* folder under the *dvs* folder.

```
python load_frame_sensor_data.py
```
@@ -52,7 +52,21 @@ In *s_114_outdoor_running_trail_daytime.jpg*, the blue curve is the output of ou
*s_114_outdoor_running_trail_daytime_stab_crop.mp4* is the cropped stabilized video. Note that the cropped video is generated after running the metrics code.

## Training
TBA
Download the dataset for training and testing [here](https://storage.googleapis.com/dataset_release/all.zip).
Uncompress *all.zip* and move the *dataset_release* folder under the *dvs* folder.

Follow the FlowNet2 Preparation section first.
```
python warp/read_write.py --dir_path ./dataset_release # video2frames
cd flownet2
bash run_release.sh # generate optical flow file for dataset
```

Run the training code.
```
python train.py
```
The model is saved in *checkpoint/stabilzation_train*.

## Citation
If you use this code or dataset for your research, please cite our paper.
@@ -63,4 +77,4 @@ If you use this code or dataset for your research, please cite our paper.
journal={arXiv preprint arXiv:2102.01279},
year={2021}
}
```
```
57 changes: 57 additions & 0 deletions dvs/conf/stabilzation_train.yaml
@@ -0,0 +1,57 @@
data:
exp: 'stabilzation_train'
checkpoints_dir: './checkpoint'
log: './log'
data_dir: './dataset_release'
use_cuda: true
batch_size: 16
resize_ratio: 0.25
number_real: 10
number_virtual: 2
time_train: 2000 # ms
sample_freq: 40 # ms
channel_size: 1
num_workers: 16 # num_workers for data_loader
model:
load_model: null
cnn:
activate_function: relu # sigmoid, relu, tanh, quadratic
batch_norm: true
gap: false
layers:
rnn:
layers:
- - 512
- true
- - 512
- true
fc:
activate_function: relu
batch_norm: false # (batch_norm and drop_out) is False
layers:
- - 256
- true
- - 4 # last layer should be equal to nr_class
- true
drop_out: 0
train:
optimizer: "adam" # adam or sgd
momentum: 0.9 # for sgd
decay_epoch: null
epoch: 400
snapshot: 2
init_lr: 0.0001
lr_decay: 0.5
lr_step: 200 # if > 0 decay_epoch should be null
seed: 1
weight_decay: 0.0001
clip_norm: False
init: "xavier_uniform" # xavier_uniform or xavier_normal
loss:
follow: 10
angle: 1
smooth: 10 #10
c2_smooth: 200 #20
undefine: 2.0
opt: 0.1
stay: 0
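The `init_lr`, `lr_decay`, and `lr_step` options above describe a step-decay schedule: the learning rate starts at `init_lr` and is multiplied by `lr_decay` every `lr_step` epochs. A minimal sketch of that rule under the usual convention (plain Python; the repo's actual `train.py` may apply it differently):

```python
def stepped_lr(epoch, init_lr=0.0001, lr_decay=0.5, lr_step=200):
    """Learning rate after `epoch` epochs under a step-decay schedule.

    Mirrors the config above: the rate is halved every 200 epochs.
    This is a sketch of the common convention, not the repo's code.
    """
    return init_lr * (lr_decay ** (epoch // lr_step))

print(stepped_lr(0))    # 0.0001
print(stepped_lr(200))  # 5e-05
```

With `epoch: 400` and `lr_step: 200`, the rate decays exactly twice over a full run; the comment `# if > 0 decay_epoch should be null` suggests `lr_step` and `decay_epoch` are mutually exclusive ways to trigger the same decay.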
2 changes: 1 addition & 1 deletion dvs/dataset.py
@@ -34,7 +34,7 @@ def get_data_loader(cf, no_flo = False):
def get_dataset(cf, no_flo = False):
resize_ratio = cf["data"]["resize_ratio"]
train_transform, test_transform = _data_transforms()
train_path = os.path.join(cf["data"]["data_dir"], "train")
train_path = os.path.join(cf["data"]["data_dir"], "training")
test_path = os.path.join(cf["data"]["data_dir"], "test")
if not os.path.exists(train_path):
train_path = cf["data"]["data_dir"]
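The one-line change above renames the expected train split directory from *train* to *training*, with a fallback to the dataset root when the subfolder is absent. A standalone sketch of that resolution logic (hypothetical helper, not taken from the repo):

```python
import os
import tempfile

def resolve_train_path(data_dir):
    """Mimic dataset.py after this commit: prefer a 'training'
    subfolder of data_dir, else fall back to data_dir itself."""
    train_path = os.path.join(data_dir, "training")
    return train_path if os.path.exists(train_path) else data_dir

# With no 'training' subfolder, the root itself is used:
with tempfile.TemporaryDirectory() as d:
    assert resolve_train_path(d) == d
    os.mkdir(os.path.join(d, "training"))
    assert resolve_train_path(d) == os.path.join(d, "training")
```

The fallback lets a flat, unsplit dataset directory work with the same loader; the diff is truncated before showing whether `test_path` gets the same treatment.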
10 changes: 10 additions & 0 deletions dvs/flownet2/run_release.sh
@@ -0,0 +1,10 @@
#!/bin/bash
python main.py --inference --model FlowNet2 --save_flow --inference_dataset Google \
--inference_dataset_root ./../dataset_release/test \
--resume ./FlowNet2_checkpoint.pth.tar \
--inference_visualize

python main.py --inference --model FlowNet2 --save_flow --inference_dataset Google \
--inference_dataset_root ./../dataset_release/training \
--resume ./FlowNet2_checkpoint.pth.tar \
--inference_visualize
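FlowNet2's `--save_flow` flag writes optical flow in the Middlebury *.flo* format. A minimal reader for that format (assuming NumPy is available; the float magic number 202021.25 is part of the format specification):

```python
import struct

import numpy as np

def read_flo(path):
    """Read a Middlebury .flo optical-flow file: a float32 magic
    number (202021.25), int32 width and height, then h*w*2 float32
    values interleaved as (u, v) per pixel."""
    with open(path, "rb") as f:
        magic, = struct.unpack("f", f.read(4))
        assert abs(magic - 202021.25) < 1e-3, "invalid .flo file"
        w, = struct.unpack("i", f.read(4))
        h, = struct.unpack("i", f.read(4))
        data = np.frombuffer(f.read(w * h * 2 * 4), dtype=np.float32)
    return data.reshape(h, w, 2)
```

The returned array has shape `(height, width, 2)`, with horizontal displacement in channel 0 and vertical in channel 1, which is how downstream code typically consumes these files.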
16 changes: 1 addition & 15 deletions dvs/load_frame_sensor_data.py
@@ -91,21 +91,7 @@ def inference(cf, data_path, USE_CUDA):
rotations_real, lens_offsets_real = get_rotations(data.frame[:data.length], data.gyro, data.ois, data.length)
fig_path = os.path.join(data_path, video_name+"_real.jpg")
visual_rotation(rotations_real, lens_offsets_real, None, None, None, None, fig_path)

# print("------Start Warping Video--------")
# grid = get_grid(test_loader.dataset.static_options, \
# data.frame[:data.length], data.gyro, data.ois, virtual_queue[:data.length,1:], no_shutter = False)

# grid_rm_shutter = get_grid(test_loader.dataset.static_options, \
# data.frame[:data.length], data.gyro, np.zeros(data.ois.shape), virtual_queue[:data.length,1:], no_shutter = False)

# video_path = os.path.join(data_path, video_name+".mp4")
# data_name = data_path.split("/")[-1]
# save_path = os.path.join(data_path, video_name+"_no_ois.mp4")
# warp_video(grid, video_path, save_path, losses = None)

# save_path = os.path.join(data_path, video_name+"_no_shutter.mp4")
# warp_video(grid_rm_shutter, video_path, save_path, losses = None)

return

def main(args = None):