-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathimagenet-gen.lua
125 lines (100 loc) · 3.8 KB
/
imagenet-gen.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
--
-- Copyright (c) 2016, Facebook, Inc.
-- All rights reserved.
--
-- This source code is licensed under the BSD-style license found in the
-- LICENSE file in the root directory of this source tree. An additional grant
-- of patent rights can be found in the PATENTS file in the same directory.
--
-- Script to compute list of ImageNet filenames and classes
--
-- This generates a file gen/imagenet.t7 which contains the list of all
-- ImageNet training and validation images and their classes. This script also
-- works for other datasets arranged with the same layout.
--
local sys = require 'sys'
local ffi = require 'ffi'
local M = {}
-- Collect the class names of an ImageNet-style split directory.
--
-- Every subdirectory of `dir` (other than '.' and '..') is one class. Plain
-- files at this level (e.g. a stray README or .DS_Store) are skipped: they
-- cannot contain images and would otherwise become spurious classes that
-- shift the class count and indices.
--
-- Params:
--   dir - path of the split directory (e.g. <data>/train)
-- Returns:
--   classList  - array of class names, in sorted order
--   classToIdx - map from class name to its index in classList
-- Raises an error if `dir` cannot be listed.
local function findClasses(dir)
   local dirs = paths.dir(dir)
   -- paths.dir yields nil for a missing/unreadable directory; fail clearly
   -- instead of inside table.sort.
   assert(dirs, 'cannot list directory: ' .. dir)
   table.sort(dirs)

   local classList = {}
   local classToIdx = {}
   for _, class in ipairs(dirs) do
      if class ~= '.' and class ~= '..' and not classToIdx[class]
            and paths.dirp(paths.concat(dir, class)) then
         classList[#classList + 1] = class
         classToIdx[class] = #classList
      end
   end
   -- assert(#classList == 1000, 'expected 1000 ImageNet classes')
   return classList, classToIdx
end
-- Quote a string POSIX-shell style (single quotes, with embedded single
-- quotes escaped) so it can be embedded in a command line even when it
-- contains spaces or shell metacharacters.
local function shellQuote(s)
   return "'" .. (s:gsub("'", "'\\''")) .. "'"
end

-- List every image under `dir` and pair it with its class index.
--
-- Images are expected at `dir/<className>/<file>`. The class is the name of
-- the image's immediate parent directory, looked up in `classToIdx`.
--
-- Params:
--   dir        - split directory to search (symlinks are followed)
--   classToIdx - map from class name to class index (see findClasses)
-- Returns:
--   imagePath  - CharTensor with one row per image; row i holds the
--                NUL-terminated relative path "<className>/<file>"
--   imageClass - LongTensor with the class index of each image
-- Raises an error if an image's parent directory is not a known class or if
-- the find command cannot be started.
local function findImages(dir, classToIdx)
   ----------------------------------------------------------------------
   -- Options for the GNU and BSD find command. -iname is case-insensitive,
   -- so the upper-case variants are redundant but harmless.
   local extensionList = {'jpg', 'png', 'jpeg', 'JPG', 'PNG', 'JPEG', 'ppm', 'PPM', 'bmp', 'BMP'}
   local patterns = {}
   for i, ext in ipairs(extensionList) do
      patterns[i] = '-iname "*.' .. ext .. '"'
   end
   local findOptions = ' ' .. table.concat(patterns, ' -o ')

   -- Find all the images using the find command. The directory is quoted so
   -- that paths with spaces or shell metacharacters are passed through intact.
   local f = assert(io.popen('find -L ' .. shellQuote(dir) .. findOptions))
   local maxLength = 0
   local imagePaths = {}
   local imageClasses = {}
   -- Generate a list of all the images and their class
   for line in f:lines() do
      local className = paths.basename(paths.dirname(line))
      local filename = paths.basename(line)
      local path = className .. '/' .. filename
      local classId = classToIdx[className]
      assert(classId, 'class not found: ' .. className)
      imagePaths[#imagePaths + 1] = path
      imageClasses[#imageClasses + 1] = classId
      -- +1 leaves room for the terminating NUL that ffi.copy writes
      maxLength = math.max(maxLength, #path + 1)
   end
   f:close()

   local nImages = #imagePaths
   if nImages == 0 then
      -- Return empty tensors explicitly rather than relying on how torch
      -- handles a degenerate (0, maxLength) size request.
      return torch.CharTensor(), torch.LongTensor()
   end

   -- Convert the generated list to a tensor for faster loading
   local imagePath = torch.CharTensor(nImages, maxLength):zero()
   for i, path in ipairs(imagePaths) do
      ffi.copy(imagePath[i]:data(), path)
   end
   local imageClass = torch.LongTensor(imageClasses)
   return imagePath, imageClass
end
-- Build the image index for an ImageNet-style dataset and save it to disk.
--
-- `opt.data` must contain `train/` and `val/` subdirectories laid out as
-- `<split>/<className>/<image>`. The class list is derived from `train/` and
-- used for both splits.
--
-- Params:
--   opt       - options table; only `opt.data` (dataset root) is read here
--   cacheFile - path the resulting table is written to via torch.save
-- Returns the saved table:
--   { basedir, classList,
--     train = { imagePath, imageClass },
--     val   = { imagePath, imageClass } }
-- Raises an error if either split directory is missing.
function M.exec(opt, cacheFile)
   local trainDir = paths.concat(opt.data, 'train')
   local valDir = paths.concat(opt.data, 'val')
   assert(paths.dirp(trainDir), 'train directory not found: ' .. trainDir)
   assert(paths.dirp(valDir), 'val directory not found: ' .. valDir)

   print("=> Generating list of images")
   local classList, classToIdx = findClasses(trainDir)

   print(" | finding all validation images")
   local valImagePath, valImageClass = findImages(valDir, classToIdx)

   print(" | finding all training images")
   local trainImagePath, trainImageClass = findImages(trainDir, classToIdx)

   local info = {
      basedir = opt.data,
      classList = classList,
      train = {
         imagePath = trainImagePath,
         imageClass = trainImageClass,
      },
      val = {
         imagePath = valImagePath,
         imageClass = valImageClass,
      },
   }

   print(" | saving list of images to " .. cacheFile)
   torch.save(cacheFile, info)
   return info
end
return M