-
Notifications
You must be signed in to change notification settings - Fork 12
/
main.py
77 lines (68 loc) · 3.6 KB
/
main.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import argparse
import numpy as np
import tensorflow as tf
from functools import partial
from utils import *
from generator import *
from queues import *
from YeNet import YeNet
parser = argparse.ArgumentParser(description='PyTorch implementation of YeNet')
parser.add_argument('train_cover_dir', type=str, metavar='PATH',
help='path of directory containing all ' +
'training cover images')
parser.add_argument('train_stego_dir', type=str, metavar='PATH',
help='path of directory containing all ' +
'training stego images or beta maps')
parser.add_argument('valid_cover_dir', type=str, metavar='PATH',
help='path of directory containing all ' +
'validation cover images')
parser.add_argument('valid_stego_dir', type=str, metavar='PATH',
help='path of directory containing all ' +
'validation stego images or beta maps')
parser.add_argument('--batch-size', type=int, default=32, metavar='N',
help='input batch size for training (default: 32)')
parser.add_argument('--test-batch-size', type=int, default=32, metavar='N',
help='input batch size for testing (default: 32)')
parser.add_argument('--epochs', type=int, default=1000, metavar='N',
help='number of epochs to train (default: 1000)')
parser.add_argument('--lr', type=float, default=4e-1, metavar='LR',
help='learning rate (default: 4e-1)')
parser.add_argument('--use-batch-norm', action='store_true', default=False,
help='use batch normalization after each activation,' +
' also disable pair constraint (default: False)')
parser.add_argument('--embed-otf', action='store_true', default=False,
help='use beta maps and embed on the fly instead' +
' of use stego images (default: False)')
parser.add_argument('--no-cuda', action='store_true', default=False,
help='disables CUDA training')
parser.add_argument('--gpu', type=int, default=0,
help='index of gpu used (default: 0)')
parser.add_argument('--seed', type=int, default=1, metavar='S',
help='random seed (default: 1)')
parser.add_argument('--log-interval', type=int, default=200, metavar='N',
help='how many batches to wait ' +
'before logging training status')
parser.add_argument('--log-path', type=str, default='logs/',
metavar='PATH', help='path to generated log file')
args = parser.parse_args()
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ['CUDA_VISIBLE_DEVICES'] = '' if args.no_cuda else str(args.gpu)
tf.set_random_seed(args.seed)
train_ds_size = len(glob(args.train_cover_dir + '/*')) * 2
if args.embed_otf:
train_gen = partial(gen_embedding_otf, args.train_cover_dir, \
args.train_stego_dir, args.use_batch_norm)
else:
train_gen = partial(gen_flip_and_rot, args.train_cover_dir, \
args.train_stego_dir, args.use_batch_norm)
valid_ds_size = len(glob(args.valid_cover_dir + '/*')) * 2
valid_gen = partial(gen_valid, args.valid_cover_dir, \
args.valid_stego_dir)
if valid_ds_size % 32 != 0:
raise ValueError("change batch size for validation")
optimizer = tf.train.AdadeltaOptimizer(args.lr)
train(YeNet, train_gen, valid_gen, args.batch_size, \
args.test_batch_size, valid_ds_size, \
optimizer, args.log_interval, train_ds_size, \
args.epochs * train_ds_size, train_ds_size, args.log_path, 8)