-- main.lua
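--Top-level driver for training and evaluating SPENs (Structured Prediction Energy Networks).
--It loads one of three problem setups (SequenceTagging, MultiLabelClassification, or Denoise),
--optionally pretrains the unary classifier, and then trains the full energy network according
--to the per-stage configs listed in params.training_configs.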
local seed = 12345
torch.manualSeed(seed)
require 'Imports'
local cmd = GeneralOptions:get_flags()
local params = cmd:parse(arg)
if(params.profile == 1) then
require 'Pepperfish'
profiler = new_profiler()
profiler:start()
end
print(params)
local use_cuda = params.gpuid >= 0
params.use_cuda = use_cuda
if(use_cuda)then
print('USING GPU '..params.gpuid)
require 'cutorch'
require 'cunn'
cutorch.setDevice(params.gpuid + 1)
cutorch.manualSeed(seed)
if(params.cudnn == 1) then
print('using cudnn')
require 'cudnn'
end
end
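--[[
Illustrative invocation (a sketch only: the actual flag names are defined in
GeneralOptions:get_flags(), which is not shown here, and are assumed to mirror
the params fields read in this file):

  th main.lua -problem SequenceTagging -problem_config config.t7 \
    -train_list train_files.txt -test_list test_files.txt \
    -batch_size 32 -gpuid 0 -cudnn 1 -profile 0

gpuid >= 0 selects a GPU (cutorch devices are 1-indexed, hence the +1 above);
a negative gpuid runs on the CPU.
--]]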
local load_problem
if(params.problem == "SequenceTagging") then
load_problem = function(params)
local problem_config = torch.load(params.problem_config)
problem_config.batch_size = params.batch_size
local y_shape = {problem_config.batch_size,problem_config.length,problem_config.domain_size}
problem_config.y_shape = y_shape
local model = ChainSPEN(problem_config,params)
local evaluator_factory = function(batcher, soft_predictor)
local hard_predictor = RoundingPredictor(soft_predictor,y_shape)
return HammingEvaluator(batcher, function(x) return hard_predictor:predict(x) end)
end
local preprocess_func = nil
local train_batcher = BatcherFromFile(params.train_list, preprocess_func, params.batch_size, use_cuda)
local test_batcher = BatcherFromFile(params.test_list, preprocess_func, params.batch_size, use_cuda)
return model, y_shape, evaluator_factory, preprocess_func, train_batcher, test_batcher
end
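--[[
Sketch of the fields this branch reads from the serialized problem_config (any
additional fields that ChainSPEN expects are not shown; this is an illustrative
example, not the canonical config format):

  torch.save('seq_tagging_config.t7', {
    length = 10,      --sequence length
    domain_size = 5,  --number of possible labels per position
    --batch_size is overwritten from the command line above
  })

y_shape is {batch_size, length, domain_size}: per-position soft label scores
that RoundingPredictor rounds to hard predictions for the Hamming evaluator.
--]]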
elseif(params.problem == "MultiLabelClassification") then
load_problem = function(params)
local problem_config = torch.load(params.problem_config)
problem_config.batch_size = params.batch_size
local y_shape = {problem_config.batch_size,problem_config.label_dim,2}
problem_config.y_shape = y_shape
local model = MLCSPEN(problem_config,params)
local evaluator_factory = function(batcher, soft_predictor)
return MultiLabelEvaluation(batcher, soft_predictor, problem_config.prediction_thresh, params.results_file)
end
local adder = nn.AddConstant(1)
if(use_cuda) then adder:cuda() end
--The labels on disk are 0-indexed, so add one to get 1-indexed class labels. TODO: make this unnecessary by preprocessing the data differently.
local preprocess_func = function(a,b,c) return adder:forward(a):clone(), b, c end
local train_batcher = BatcherFromFile(params.train_list, preprocess_func, params.batch_size, use_cuda)
local test_batcher = BatcherFromFile(params.test_list, preprocess_func, params.batch_size, use_cuda)
return model, y_shape, evaluator_factory, preprocess_func, train_batcher, test_batcher
end
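--[[
The AddConstant preprocessing above only shifts the label tensor: the labels on
disk are 0-indexed, while torch class criteria such as ClassNLLCriterion expect
1-indexed targets. A standalone illustration with made-up labels:

  local labels = torch.Tensor{0, 1, 0, 1}                   --0-indexed labels from disk
  local shifted = nn.AddConstant(1):forward(labels):clone() --now {1, 2, 1, 2}
--]]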
elseif(params.problem == "Denoise") then
load_problem = function(params)
local problem_config = torch.load(params.problem_config)
problem_config.batch_size = params.batch_size
local y_shape, train_preprocess_func
if(problem_config.use_random_crops == 1) then
local crop_height = 96 --these are the crop sizes used in the proximalnet paper. TODO: surface command line options for these
local crop_width = 128
y_shape = {problem_config.batch_size,crop_height,crop_width}
local a_crop_contiguous, b_crop_contiguous
local y_crop_start_max = problem_config.height - crop_height
local x_crop_start_max = problem_config.width - crop_width
--This randomly crops the images in order to speed up training
--Note that it uses the same crop locations for every image in the minibatch
train_preprocess_func = function(a,b,c)
local y_start = torch.rand(1):mul(y_crop_start_max):ceil()[1]
local x_start = torch.rand(1):mul(x_crop_start_max):ceil()[1]
local a_crop = a:narrow(2,y_start,crop_height):narrow(3,x_start,crop_width)
local b_crop = b:narrow(2,y_start,crop_height):narrow(3,x_start,crop_width)
a_crop_contiguous = a_crop_contiguous or a_crop:clone()
a_crop_contiguous:copy(a_crop)
b_crop_contiguous = b_crop_contiguous or b_crop:clone()
b_crop_contiguous:copy(b_crop)
return a_crop_contiguous, b_crop_contiguous, c
end
else
y_shape = {problem_config.batch_size,problem_config.height,problem_config.width}
end
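--[[
Worked example of the crop indexing above, with a made-up image height: if
problem_config.height were 180 and crop_height is 96, then y_crop_start_max is
84 and y_start = ceil(rand * 84) lies in {1, ..., 84}, so
narrow(2, y_start, 96) (dimension 2 is height; dimension 1 is the batch) always
stays inside the image. The same offsets are reused for both tensors of the
pair so that the input and target images remain aligned.
--]]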
problem_config.y_shape = y_shape
local model = DenoiseSPEN(problem_config,params)
local evaluator_factory = function(batcher, soft_predictor)
return PSNREvaluator(batcher, function(x) return soft_predictor:forward(x) end)
end
local train_batcher = BatcherFromFile(params.train_list, train_preprocess_func, params.batch_size, use_cuda)
--NOTE: this doesn't return the actual test set score, but an approximation using random crops on the dev set.
--It will require some more engineering to be able to actually run on the full-size test images, as the network expects smaller images.
local test_batcher = BatcherFromFile(params.test_list, train_preprocess_func, params.batch_size, use_cuda)
return model, y_shape, evaluator_factory, train_preprocess_func, train_batcher, test_batcher
end
else
error('invalid problem type')
end
local model, y_shape, evaluator_factory, preprocess_func, train_batcher, test_batcher = load_problem(params)
local pretrain_train_config = {}
do
pretrain_train_config.soft_predictor = model.classifier_network
pretrain_train_config.modules_to_update = model.classifier_network
pretrain_train_config.stop_feature_backprop = false
pretrain_train_config.stop_unary_backprop = false
local criterion_name = (params.continuous_outputs == 1) and "MSECriterion" or "ClassNLLCriterion"
print(params)
pretrain_train_config.loss_wrapper = TrainingWrappers:independent_training(model.classifier_network, criterion_name, y_shape, params)
pretrain_train_config.items_to_save = {
classifier = model.classifier_network
}
end
local full_train_config = {}
do
params.return_all_iterates = params.penalize_all_iterates == 1
params.num_iterates = params.max_inference_iters
local gd_inference_config = GradientBasedInferenceConfig:get_gd_inference_config(params)
local full_gd_prediction_net = GradientBasedInference(y_shape, gd_inference_config):spen_inference(model)
local gd_prediction_net = full_gd_prediction_net
if(params.return_all_iterates) then
gd_prediction_net = nn.Sequential():add(full_gd_prediction_net):add(nn.SelectTable(1))
end
--initialize the unaries from a loaded model
if(params.init_classifier ~= "") then
assert(params.init_full_net == "", "shouldn't be initializing both the classifier and the full energy network from file")
print('initializing classifier from '..params.init_classifier)
model.classifier_network:getParameters():copy(torch.load(params.init_classifier):getParameters())
end
if(params.init_full_net ~= "") then
print('initializing parameters from '..params.init_full_net)
if(params.use_cuda) then gd_prediction_net:double() end --the fact that we have to do this is mysterious.
gd_prediction_net:getParameters():copy(torch.load(params.init_full_net):getParameters())
if(params.use_cuda) then gd_prediction_net:cuda() end
end
full_train_config.soft_predictor = gd_prediction_net
full_train_config.modules_to_update = gd_prediction_net
if(params.training_method == "E2E") then
local criterion_name = (params.continuous_outputs == 1) and "MSECriterion" or "ClassNLLCriterion"
full_train_config.loss_wrapper = TrainingWrappers:independent_training(full_gd_prediction_net, criterion_name, y_shape, params)
elseif(params.training_method == "SSVM") then
local criterion_name = 'MSECriterion' --todo: surface an option for this
assert(not params.return_all_iterates)
full_train_config.loss_wrapper = TrainingWrappers:ssvm_training(y_shape, model, criterion_name, gd_inference_config, params)
else
error('invalid training method')
end
full_train_config.items_to_save = {
predictor = gd_prediction_net,
energy_net = model:full_energy_net()
}
end
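--[[
GradientBasedInference (defined elsewhere) builds a network that unrolls
gradient-based minimization of the SPEN energy with respect to y, so the whole
inference procedure can be trained end to end ("E2E" above). A minimal,
self-contained sketch of the unrolling idea on a toy quadratic energy (an
illustration only, not the actual class):

  --toy energy E(y) = 0.5 * ||y - t||^2, so dE/dy = y - t
  local t = torch.Tensor{1.0, -2.0, 0.5}
  local y = torch.zeros(3)   --initial iterate
  local eta, T = 0.5, 10     --step size and number of unrolled steps
  for _ = 1, T do
    y:add(-eta, y - t)       --y <- y - eta * dE/dy
  end
  --in the real code each step is an nn module, so gradients flow through all T steps
--]]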
local clamp_features_train_config = Util:copyTable(full_train_config) --shallow copy: the contained modules and loss wrapper are shared by reference
clamp_features_train_config.stop_feature_backprop = true
local clamp_unaries_train_config = Util:copyTable(full_train_config) --shallow copy: the contained modules and loss wrapper are shared by reference
clamp_unaries_train_config.stop_unary_backprop = true
clamp_unaries_train_config.stop_feature_backprop = true
local train_configurations = {
pretrain = pretrain_train_config,
clamp_features = clamp_features_train_config,
clamp_unaries = clamp_unaries_train_config,
full = full_train_config
}
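--The four entries differ only in which parts of the model receive gradient:
--"pretrain" trains the classifier (unary) network alone with a per-variable loss,
--"clamp_features" freezes the feature network, "clamp_unaries" freezes both the
--feature and unary networks, and "full" updates everything through the unrolled
--inference procedure. The stop_*_backprop flags are applied in train() below.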
local function evaluate_only(config,params)
if(params.use_cuda) then config.soft_predictor:cuda() end
local evaluator = evaluator_factory(test_batcher, config.soft_predictor)
evaluator:evaluate(0)
end
local function train(config, params, name)
assert(config)
if(params.use_cuda) then
config.loss_wrapper:cuda()
config.soft_predictor:cuda()
end
model:set_feature_backprop( not config.stop_feature_backprop )
model:set_unary_backprop( not config.stop_unary_backprop )
local callbacks = {}
local evaluator = evaluator_factory(test_batcher, config.soft_predictor)
local evaluate = Callback(function(data) return evaluator:evaluate(data.epoch) end, params.evaluation_frequency)
table.insert(callbacks,evaluate)
local opt_state = {}
if(params.init_opt_state ~= "") then
print('loading opt_state from '..params.init_opt_state)
opt_state = torch.load(params.init_opt_state)
end
local optimization_config = {
opt_state = opt_state,
opt_config = {
learningRate=params.learning_rate,
learningRateDecay=0, --this gets updated by lr_start below
beta1 = params.adam_beta1,
beta2 = params.adam_beta2,
epsilon=params.adam_epsilon,
weightDecay=params.l2
},
opt_method = optim.adam,
gradient_clip = params.gradient_clip,
regularization = config.regularization,
modules_to_update = config.modules_to_update
}
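--opt_config uses the standard optim.adam hyperparameter names (learningRate,
--learningRateDecay, beta1, beta2, epsilon, weightDecay). opt_state holds Adam's
--running moment estimates, which is why it can be warm-started from
--params.init_opt_state and is saved alongside the model below.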
local general_config = {
num_epochs = params.num_epochs,
batches_per_epoch = params.batches_per_epoch,
batch_size = params.batch_size,
assert_nan = true,
}
local nonzero_learning_rate_set = false
local function set_lr(data)
if(data.epoch > params.learning_rate_decay_start and not nonzero_learning_rate_set) then
optimization_config.opt_config.learningRateDecay = params.learning_rate_decay
nonzero_learning_rate_set = true
end
end
local lr_start = Callback(set_lr, 1)
table.insert(callbacks,lr_start)
if(params.icnn == 1) then
local params_to_clamp = model.global_potentials_network:parameters()
--todo: we actually don't need to clamp the biases
local function clamp()
Util:deep_apply(params_to_clamp,function(t) t:cmax(0) end)
end
general_config.post_process_parameter_update = clamp
end
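--[[
The icnn option enforces the input-convex constraint by projection: after every
parameter update, the weights of the global potentials network are clamped to be
nonnegative in place (t:cmax(0) replaces negative entries with 0). A standalone
sketch of the same projected-update pattern on a hypothetical module:

  local lin = nn.Linear(4, 4)
  --...compute gradients and take an optimizer step on lin, then project:
  for _, w in ipairs(lin:parameters()) do w:cmax(0) end
--]]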
config.items_to_save.opt_state = optimization_config.opt_state
local saved_model_base = params.model_file.."-"..name
local saver = Saver(saved_model_base,config.items_to_save)
local save = Callback(function(data) return saver:save(data.epoch) end, params.save_frequency)
table.insert(callbacks,save)
Train(config.loss_wrapper,train_batcher, optimization_config, general_config, callbacks):train()
end
if(params.evaluate_classifier_only == 1) then
evaluate_only(train_configurations.pretrain, params)
os.exit()
end
if(params.evaluate_spen_only == 1) then
evaluate_only(train_configurations.full, params)
os.exit()
end
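--[[
params.training_configs is a text file containing one torch-serialized config
path per line. Sketch of how such a per-stage config might be produced (only
training_mode and num_epochs are read explicitly below; any other keys are
merged into params and must not collide with keys that are already set):

  torch.save('pretrain_stage.t7', {
    training_mode = 'pretrain_unaries',  --or clamp_unaries, clamp_features, update_all
    num_epochs = 10,
  })
--]]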
for params_file in io.lines(params.training_configs) do
print('loading specific training config from '..params_file)
local specific_params = torch.load(params_file)
local mode = specific_params.training_mode
if(specific_params.num_epochs > 0) then
print(specific_params)
local all_params = Util:copyTable(params)
for k,v in pairs(specific_params) do
assert(not all_params[k],'repeated key: '..k)
all_params[k] = v
end
print('starting training for mode: '..mode)
if(mode == "pretrain_unaries") then
train(train_configurations.pretrain, all_params, mode)
elseif(mode == "clamp_unaries") then
train(train_configurations.clamp_unaries, all_params, mode)
elseif(mode == "clamp_features") then
train(train_configurations.clamp_features, all_params, mode)
elseif(mode == "update_all") then
train(train_configurations.full, all_params, mode)
else
error('invalid training mode: '..mode)
end
end
end
if(params.profile == 1) then
profiler:stop()
local report = "profile.txt"
print('writing profiling report to: '..report)
local outfile = io.open( report, "w+" )
profiler:report( outfile )
outfile:close()
end