-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathdonkey.lua
74 lines (67 loc) · 2.05 KB
/
donkey.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
require 'image'
paths.dofile('DataLoader.lua')
-- Training metadata is cached in this file; when the file does not exist yet,
-- the metadata is built from scratch and written there.
local trainCacheName = 'trainCache.t7'
local trainCache = paths.concat(opt.cache, trainCacheName)
--- Load an image from disk and normalize it to a 3-channel 256x256 tensor.
-- Single-channel images (loaded as HxW or 1xHxW) are replicated to three
-- channels and a fourth (alpha) channel is dropped. Any other layout raises
-- an error; loadHook calls this through pcall.
-- @tparam string path image file path
-- @return tensor resized to 3x256x256 (aspect ratio is not preserved)
local function loadImage(path)
  local input = image.load(path)
  local nDim = input:dim()
  if nDim == 2 then
    -- single-channel image loaded as HxW: add a channel dim, then replicate
    input = input:view(1, input:size(1), input:size(2)):repeatTensor(3, 1, 1)
  elseif nDim == 3 and input:size(1) == 1 then
    -- single-channel image loaded as 1xHxW
    input = input:repeatTensor(3, 1, 1)
  elseif nDim == 3 and input:size(1) == 3 then
    -- already 3-channel: nothing to do
  elseif nDim == 3 and input:size(1) == 4 then
    -- image with alpha: keep the first three channels
    input = input[{{1,3},{},{}}]
  else
    -- report the actual shape so unsupported files can be diagnosed
    local dims = {}
    for d = 1, nDim do
      dims[d] = tostring(input:size(d))
    end
    error(string.format(
      'unsupported image layout %s in %s: expected HxW, 1xHxW, 3xHxW or 4xHxW',
      table.concat(dims, 'x'), tostring(path)))
  end
  -- both sides are forced to 256, so non-square images get stretched
  input = image.scale(input, 256, 256)
  return input
end
-- VGG preprocessing
-- Per-channel means in B, G, R order, on the 0..255 scale.
local bgr_means = {103.939,116.779,123.68}

--- Turn an RGB image tensor into VGG-style network input.
-- Reorders channels RGB -> BGR, multiplies by 255 and subtracts bgr_means
-- channel by channel. The argument is not modified; a new tensor is returned.
-- NOTE(review): assumes img is 3xHxW (loadImage always returns 3 channels)
-- with values in [0,1], hence the x255 — confirm against image.load defaults.
-- @param img RGB image tensor
-- @return preprocessed BGR tensor
local function vggPreprocess(img)
  -- the clone already carries channel 2 (green) in its final position
  local bgr = img:clone()
  -- exchange the red and blue planes
  bgr[1]:copy(img[3])
  bgr[3]:copy(img[1])
  bgr:mul(255)
  for channel, mean in ipairs(bgr_means) do
    bgr[channel]:add(-mean)
  end
  return bgr
end
--- Cut the centered oH x oW patch out of an image tensor.
-- The tensor is indexed as (channel, row, column): size(2) is the height
-- and size(3) is the width.
-- @param input image tensor
-- @tparam[opt=224] number oH output height
-- @tparam[opt=224] number oW output width
-- @return the patch returned by image.crop
local function centerCrop(input, oH, oW)
  oH = oH or 224
  oW = oW or 224
  local iH = input:size(2)
  local iW = input:size(3)
  if iH < oH or iW < oW then
    error(string.format('centerCrop: input %dx%d is smaller than crop %dx%d',
      iH, iW, oH, oW))
  end
  -- top-left corner of the centered window
  local w1 = math.ceil((iW - oW) / 2)
  local h1 = math.ceil((iH - oH) / 2)
  -- image.crop takes (x1, y1, x2, y2): the right edge is offset by the output
  -- width and the bottom edge by the output height
  return image.crop(input, w1, h1, w1 + oW, h1 + oH)
end
--- Load, VGG-preprocess and center-crop one image for the data loader.
-- Decoding runs under pcall, so a file that cannot be loaded yields a
-- failure flag instead of raising.
-- @tparam string path image file path
-- @return the preprocessed crop and true on success; 0 and false on failure
local loadHook = function(path)
  collectgarbage()
  local ok, result = pcall(loadImage, path)
  if not ok then
    -- keep the best-effort skip, but report which file failed and why
    print(string.format('warning: failed to load %s: %s',
      tostring(path), tostring(result)))
    return 0, false
  end
  -- reuse the image already decoded under pcall instead of reading the file
  -- a second time
  local out = centerCrop(vggPreprocess(result))
  return out, true
end
-- Restore the dataset metadata from the cache file, or build and cache it.
-- `loader` is deliberately assigned as a global (no `local`) — presumably it
-- is read by code outside this file; TODO confirm before localizing it.
if paths.filep(trainCache) then
  print('Loading train metadata from cache')
  loader = torch.load(trainCache)
else
  print('Creating train metadata')
  loader = DataLoader{data = opt.data, validation_data = opt.validation_data}
  -- saved before loadHook is attached, so the hook is not written to the cache
  torch.save(trainCache, loader)
end
-- the hook is never part of the cached object: attach it in both cases
loader.loadHook = loadHook
-- NOTE(review): an existing cache is reused even if opt.data changed; delete
-- the cache file to force a rebuild.
collectgarbage()