-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathcifar-dataset.lua
103 lines (91 loc) · 3.62 KB
/
cifar-dataset.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
path = require 'pl.path'
require 'image'
-- Global namespace table; torch.class below is given the dotted name
-- "Dataset.LOADER", so the class is registered under this table.
Dataset = {}
-- CIFAR is the class table that the methods in this file are attached to.
-- No parent class name is passed to torch.class, so `parent` is presumably
-- nil; it is not used anywhere in this file.
local CIFAR, parent = torch.class("Dataset.LOADER")
--- Load the train and test .t7 files into one pair of tensors.
-- The file names are built by plain concatenation
-- (`path .. dataset .. '-train.t7'` / `'-test.t7'`), so `path` must already
-- end with a directory separator. Rows 1..50000 receive the training file,
-- rows 50001..60000 the test file.
-- @param dataset       file-name prefix of the two .t7 files
-- @param path          directory prefix (inside this function it shadows the
--                      global `path` module)
-- @param do_shuffling  when truthy, the 50000 training rows are permuted with
--                      a private generator seeded to 777, so the order is the
--                      same on every run; test rows are never shuffled
-- @return data   60000x3x32x32 tensor
-- @return label  new 60000-element tensor: the stored labels plus one
--                (presumably converting 0-based class ids to 1-based; confirm
--                against the .t7 files)
function get_Data(dataset, path, do_shuffling)
  local n_train, n_total = 50000, 60000
  local n_test = n_total - n_train

  local data = torch.Tensor(n_total, 3, 32, 32)
  local label = torch.Tensor(n_total)

  -- Training split goes into the leading rows.
  local train_set = torch.load(path..dataset..'-train.t7')
  data:narrow(1, 1, n_train):copy(train_set.data)
  label:narrow(1, 1, n_train):copy(train_set.label)

  -- Test split goes into the trailing rows.
  local test_set = torch.load(path..dataset..'-test.t7')
  data:narrow(1, n_train + 1, n_test):copy(test_set.data)
  label:narrow(1, n_train + 1, n_test):copy(test_set.label)

  if do_shuffling then
    local gen = torch.Generator()
    torch.manualSeed(gen, 777)
    -- One permutation, applied to images and labels alike so they stay paired.
    local order = torch.randperm(gen, n_train):long()
    -- index() returns a fresh tensor, so copying it back over the same rows
    -- does not read from rows that were already overwritten.
    data:narrow(1, 1, n_train):copy(data:index(1, order))
    label:narrow(1, 1, n_train):copy(label:index(1, order))
  end

  return data, label + 1
end
--- Construct one split of the dataset as a slice of the full tensors.
-- Row layout assumed by the constants below: rows 1..50000 are the
-- train/validation pool, rows 50001..60000 are the test set.
-- @param data   full image tensor, indexed along its first dimension
-- @param label  full label tensor, indexed along its first dimension
-- @param mode   "train", "valid" or "test"; any other value raises an error
-- @param opt    table read for `trsize`, `vasize`, `batchSize` and (train mode
--               only) `augmentation`
function CIFAR:__init(data, label, mode, opt)
  local max_trainvalid = 50000
  local tesize = 10000
  local trsize = opt.trsize
  local vasize = opt.vasize

  if trsize + vasize > max_trainvalid then
    -- Warning only: construction continues. In this situation the "valid"
    -- range trsize+1 .. trsize+vasize reaches past row 50000, i.e. into the
    -- rows that the "test" mode serves.
    print('Too large train + validation! Not enough data left for test')
  end

  self.batchSize = opt.batchSize
  self.mode = mode

  -- Pick the row range and the log line for the requested split.
  local first, last, fmt
  if mode == "train" then
    first, last, fmt = 1, trsize, 'Train data: %d to %d'
    self.augmentation = opt.augmentation
  elseif mode == "valid" then
    first, last, fmt = trsize + 1, trsize + vasize, 'Validation data: %d to %d'
  elseif mode == "test" then
    first, last, fmt = max_trainvalid + 1, max_trainvalid + tesize, 'Test data: %d to %d'
  else
    -- Previously an unrecognised mode fell through silently and left
    -- self.data / self.label nil, so the failure only surfaced later in
    -- size(), preprocess() or sampleIndices().
    error(string.format('Unknown dataset mode %q: expected "train", "valid" or "test"',
                        tostring(mode)))
  end

  -- Range indexing yields sub-tensors of `data` / `label` rather than copies.
  self.data = data[{ {first, last} }]
  self.label = label[{ {first, last} }]
  print(string.format(fmt, first, last))
end
--- Centre and scale this split's data in place.
-- Computes `(data - mean) * (1 / std)` directly in self.data. Since self.data
-- is the slice taken in __init, the tensor it was sliced from sees the change
-- as well.
-- @param mean  optional tensor expandable to the shape of self.data; defaults
--              to the mean of self.data along its first dimension
-- @param std   optional number; defaults to the standard deviation over all
--              elements of self.data, taken BEFORE the mean is subtracted
-- @return mean, std  the values actually used, so they can be passed to the
--                    preprocess call of another split
function CIFAR:preprocess(mean, std)
  mean = mean or self.data:mean(1)
  std = std or self.data:std()
  -- add(-1, t) subtracts the expanded view in place. The former
  -- `add(-mean:expandAs(...))` negated the view first, which allocated a
  -- temporary as large as the whole split; the arithmetic result is the same.
  self.data:add(-1, mean:expandAs(self.data)):mul(1/std)
  return mean, std
end
--- Number of samples in this split.
-- @return length of self.data along its first dimension
function CIFAR:size()
  local sample_count = self.data:size(1)
  return sample_count
end
-- Destination and source {first, last} ranges along one 32-pixel axis for a
-- translation by `offs` pixels; both ranges always have the same length.
local function shifted_ranges(offs)
  local dst = {math.max(1, 1 + offs), math.min(32, 32 + offs)}
  local src = {math.max(1, 1 - offs), math.min(32, 32 - offs)}
  return dst, src
end

--- Fill a batch with the samples selected by `indices`.
-- In "train" mode with self.augmentation set, every image is translated by a
-- random offset in [-4, 4] on each axis (uncovered pixels stay zero) and then
-- mirrored horizontally with probability 1/2. In "train" mode without
-- augmentation, and in "test" / "valid" mode, the images are copied as they
-- are. For any other mode batch.inputs is left untouched. Labels are always
-- copied.
-- @param indices  1-D tensor of row numbers into self.data / self.label
-- @param batch    optional table {inputs=..., outputs=...} to reuse; when
--                 omitted, a zero-filled n x 3 x 32 x 32 / n pair is created
-- @return the filled batch table
function CIFAR:sampleIndices(indices, batch)
  if not batch then
    local count = indices:size(1)
    batch = {
      inputs = torch.zeros(count, 3, 32, 32),
      outputs = torch.zeros(count),
    }
  end

  local mode = self.mode
  if mode == "train" and self.augmentation then
    batch.inputs:zero()
    local picked = torch.totable(indices)
    for i = 1, #picked do
      local index = picked[i]
      -- Sub-tensor of the batch for sample i; cleared again before the crop
      -- is written into it.
      local target = batch.inputs[i]
      target:zero()

      -- Translation by at most 4 pixels (x offset is drawn first, then y).
      local xoffs, yoffs = torch.random(-4,4), torch.random(-4,4)
      local dst_rows, src_rows = shifted_ranges(yoffs)
      local dst_cols, src_cols = shifted_ranges(xoffs)
      local source = self.data[index]
      target[{ {}, dst_rows, dst_cols }] = source[{ {}, src_rows, src_cols }]

      -- Horizontal flip, each side with half probability.
      local flip = torch.random(1,2) == 1
      if flip then
        target:copy(image.hflip(target))
      end
    end
  elseif mode == "train" or mode == "test" or mode == "valid" then
    batch.inputs:copy(self.data:index(1, indices))
  end

  batch.outputs:copy(self.label:index(1, indices))
  return batch
end