forked from NVIDIA/DIGITS
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathlenet-fine-tune.lua
71 lines (62 loc) · 2.63 KB
/
lenet-fine-tune.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
-- Returns a function that builds a LeNet network definition for DIGITS.
-- The returned table follows the DIGITS Torch network contract:
--   model, loss, trainBatchSize, validationBatchSize, fineTuneHook.
-- The fine-tune hook adapts the pretrained 10-class MNIST model to a
-- 2-class problem (odd vs. even digit) by freezing early layers and
-- appending a 10->2 linear layer before the LogSoftMax.
return function(params)
    -- original number of classes (10, i.e. one per digit)
    local nclasses = 10
    -- number of input channels, read from external parameters when available
    local channels = 1
    -- params.inputShape may be nil during visualization
    if params.inputShape then
        channels = params.inputShape[1]
        assert(params.inputShape[2]==28 and params.inputShape[3]==28, 'Network expects 28x28 images')
    end
    -- Backend selection: prefer cuDNN, then cunn, then plain nn.
    -- FIX: these three were accidental globals (no `local`), leaking into _G;
    -- they are now declared local so both branches assign the same upvalues.
    -- NOTE(review): convLayer/convLayerName are selected but never used below
    -- (the model calls backend.SpatialConvolution directly); kept for parity
    -- with the upstream DIGITS definition — confirm before removing.
    local backend, convLayer, convLayerName
    if pcall(function() require('cudnn') end) then
        print('Using CuDNN backend')
        backend = cudnn
        convLayer = cudnn.SpatialConvolution
        convLayerName = 'cudnn.SpatialConvolution'
    else
        print('Failed to load cudnn backend (is libcudnn.so in your library path?)')
        if pcall(function() require('cunn') end) then
            print('Falling back to legacy cunn backend')
        else
            print('Failed to load cunn backend (is CUDA installed?)')
            print('Falling back to legacy nn backend')
        end
        backend = nn -- works with cunn or nn
        convLayer = nn.SpatialConvolutionMM
        convLayerName = 'nn.SpatialConvolutionMM'
    end
    -- This is a LeNet model. For more information: http://yann.lecun.com/exdb/lenet/
    local lenet = nn.Sequential()
    lenet:add(nn.MulConstant(0.00390625))                       -- scale pixels by 1/256
    lenet:add(backend.SpatialConvolution(channels,20,5,5,1,1,0)) -- channels*28*28 -> 20*24*24
    lenet:add(backend.SpatialMaxPooling(2, 2, 2, 2))            -- 20*24*24 -> 20*12*12
    lenet:add(backend.SpatialConvolution(20,50,5,5,1,1,0))      -- 20*12*12 -> 50*8*8
    lenet:add(backend.SpatialMaxPooling(2,2,2,2))               -- 50*8*8 -> 50*4*4
    lenet:add(nn.View(-1):setNumInputDims(3))                   -- 50*4*4 -> 800
    lenet:add(nn.Linear(800,500))                               -- 800 -> 500
    lenet:add(backend.ReLU())
    lenet:add(nn.Linear(500, nclasses))                         -- 500 -> nclasses
    lenet:add(nn.LogSoftMax())
    -- multi-GPU implementation needed
    assert(params.ngpus <= 1, "Multi-GPU implementation needed")
    local model = lenet

    -- Fine-tune hook called by DIGITS after loading pretrained weights.
    -- Freezes the two convolutions (indices 2, 4) and the first Linear
    -- (index 7) by replacing accGradParameters with a no-op, then inserts
    -- a fresh 10->2 Linear at position 10, i.e. between the pretrained
    -- Linear(500,10) and the LogSoftMax. These indices are tied to the
    -- layer order built above — update them together if the model changes.
    local function lenetMnistOddOrEvenFineTune(net)
        -- fix weights of existing layers
        local function dummyAccGradParameters() end
        net:get(2).accGradParameters = dummyAccGradParameters
        net:get(4).accGradParameters = dummyAccGradParameters
        net:get(7).accGradParameters = dummyAccGradParameters
        -- insert 10->2 linear layer
        local l = nn.Linear(10, 2)
        net:insert(l, 10)
        return net
    end

    return {
        model = model,
        loss = nn.ClassNLLCriterion(),
        trainBatchSize = 1,
        validationBatchSize = 100,
        fineTuneHook = lenetMnistOddOrEvenFineTune,
    }
end