forked from AaltoVision/relativeCameraPose
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path main.lua
85 lines (64 loc) · 1.97 KB
/
main.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
require 'torch'
require 'cutorch'
require 'paths'
require 'xlua'
require 'optim'
require 'nn'
torch.setdefaulttensortype('torch.FloatTensor')
local c = require 'trepl.colorize'
local opts = paths.dofile('opts.lua')
local tnt = require 'torchnet'
-- Parse command-line options and seed the CPU RNG for reproducibility.
-- `opt` and `epoch` are globals, presumably also read by the scripts that are
-- loaded with paths.dofile below and later in this file.
opt = opts.parse(arg)
print(opt)
torch.manualSeed(opt.manualSeed)
-- Starting epoch number. This script has referred to the option both as
-- `epoch_number` and `epochNumber`. Accept either one, so that a nil here does
-- not surface much later as an arithmetic error on `epoch`.
-- TODO(review): confirm which name opts.lua actually defines and drop the
-- fallback.
epoch = opt.epoch_number or opt.epochNumber
-- Select the GPU before anything below allocates CUDA memory. Otherwise the
-- model and the criterion (moved with :cuda()) are created on the default
-- device instead of opt.GPU.
cutorch.setDevice(opt.GPU)
-- Getting the multi-gpu functions (presumably including saveDataParallel,
-- which is used when snapshotting).
paths.dofile('gpu_util.lua')
-- Initializing data provider; init_data_provider() is expected to be defined
-- by dtu_data_provider.lua.
paths.dofile('dtu_data_provider.lua')
init_data_provider()
paths.dofile('dtu_construct_minibatch.lua')
-- Loading CNN model: model.lua is expected to set the global `model`, whose
-- layers are then swapped for their cuDNN implementations.
-- NOTE(review): `cudnn` is not required in this file, so it must be made
-- global by model.lua or one of the files loaded above -- confirm.
paths.dofile('model.lua')
cudnn.convert(model, cudnn)
collectgarbage()
-- Create the training criterion: two MSE terms, one for orientation and one
-- for translation, combined by nn.ParallelCriterion with the weights below.
-- `criterion` is a global, presumably consumed by train.lua / test.lua.
local orientation_weight, translation_weight = 1, 1
local orientation_loss = nn.MSECriterion()
local translation_loss = nn.MSECriterion()
local combined = nn.ParallelCriterion()
combined:add(orientation_loss, orientation_weight)
combined:add(translation_loss, translation_weight)
criterion = combined:cuda()
collectgarbage()
-- Create running-average meters for the train and test phases. The `_q` / `_t`
-- suffixes presumably stand for the orientation and translation errors. These
-- are globals, and the training loop resets them after every epoch.
local AverageValueMeter = tnt.AverageValueMeter
meter_train_q = AverageValueMeter()
meter_train_t = AverageValueMeter()
meter_test_q = AverageValueMeter()
meter_test_t = AverageValueMeter()
-- Bring the training and testing entry points (train(), test(), and
-- presumably evaluation()) into the global scope.
paths.dofile('train.lua')
paths.dofile('test.lua')
-- Report model size: the first value returned by getParameters() is a single
-- flat tensor, so its length is the parameter count.
-- NOTE(review): in nn, getParameters() may re-flatten parameter storage. If
-- train.lua already keeps a flattened view, confirm this second call is safe.
local flat_params = model:getParameters()
local param_count = flat_params:size(1)
local marker = c.blue '==>'
print(marker .. ' Number of parameters in the model: ' .. param_count)
-- Select the configured GPU (its default comes from opts.lua).
-- NOTE(review): this runs after the model was built and the criterion was
-- moved to CUDA above -- confirm those allocations ended up on opt.GPU.
cutorch.setDevice(opt.GPU)
-- Reseed, presumably so that the training RNG stream does not depend on how
-- many random draws model construction consumed.
torch.manualSeed(opt.manualSeed)
-- Either run a one-off evaluation, or train from the starting epoch through
-- opt.max_epoch. Each epoch runs test() and saves a snapshot. train(), test(),
-- evaluation() and saveDataParallel() are globals, presumably provided by
-- train.lua, test.lua and gpu_util.lua.
if opt.do_evaluation then
  evaluation()
else
  -- `epoch` names the snapshots, while the loop bound previously came from a
  -- differently named option (opt.epochNumber). Drive both from one value,
  -- so a missing option cannot leave one of them nil.
  epoch = epoch or opt.epochNumber or opt.epoch_number
  assert(epoch, 'starting epoch not set: expected opt.epoch_number or opt.epochNumber')
  -- The numeric-for start value is evaluated once, so incrementing `epoch`
  -- inside the body does not change the number of iterations.
  for _ = epoch, opt.max_epoch do
    train()
    test()
    -- Clear the model's intermediate state first (this also keeps the snapshot
    -- small), then collect garbage so the released buffers are reclaimed now.
    model:clearState()
    collectgarbage()
    -- Saving the model
    if not paths.dirp(opt.snapshot_dir) then
      paths.mkdir(opt.snapshot_dir)
    end
    local snapshot_path = paths.concat(opt.snapshot_dir,
      'siam_hybridnet_fullsized_SPP_' .. (epoch) .. '.t7')
    saveDataParallel(snapshot_path, model)
    epoch = epoch + 1
    -- Start the next epoch with fresh running averages.
    meter_test_q:reset()
    meter_test_t:reset()
    meter_train_q:reset()
    meter_train_t:reset()
  end
end