forked from torch/cutorch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
init.lua
56 lines (48 loc) · 1.98 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
require "torch"
paths.require "libcutorch"

-- CUDA tensors/storages print exactly like their CPU counterparts:
-- borrow each CPU type's __tostring__ for the matching CUDA type.
-- ('Cuda' with no suffix is the float type, hence the 'Float' pairing.)
local printPairs = {
   {'CudaByte',   'Byte'},
   {'CudaChar',   'Char'},
   {'CudaShort',  'Short'},
   {'CudaInt',    'Int'},
   {'CudaLong',   'Long'},
   {'Cuda',       'Float'},
   {'CudaDouble', 'Double'},
}
for _, pair in ipairs(printPairs) do
   local cudaName, cpuName = pair[1], pair[2]
   torch[cudaName .. 'Storage'].__tostring__ = torch[cpuName .. 'Storage'].__tostring__
   torch[cudaName .. 'Tensor'].__tostring__  = torch[cpuName .. 'Tensor'].__tostring__
end

include('Tensor.lua')
include('FFI.lua')
include('test.lua')
-- Lua 5.1 has global unpack; 5.2+ moved it into the table library.
local unpack = unpack or table.unpack

--- Runs `closure` with `newDeviceID` as the current CUDA device, then
-- restores the previously current device — even when the closure raises.
-- Returns whatever the closure returns; re-raises the closure's error
-- after the device has been restored.
function cutorch.withDevice(newDeviceID, closure)
   local previousDevice = cutorch.getDevice()
   cutorch.setDevice(newDeviceID)
   -- pcall so the device switch below happens on both success and failure.
   local results = { pcall(closure) }
   cutorch.setDevice(previousDevice)
   local ok = table.remove(results, 1)
   if not ok then
      error(unpack(results))
   end
   return unpack(results)
end
--- Creates a FloatTensor backed by the CudaHostAllocator (page-locked
-- host memory, suitable for fast host<->device copies).
-- Accepts either a LongStorage or a sequence of numbers as the size;
-- with no arguments, builds a tensor of size {0}.
function cutorch.createCudaHostTensor(...)
   local dims
   if not ... then
      -- no arguments (or a falsy first argument): single zero-length dim
      dims = torch.LongTensor{0}
   elseif torch.isStorage(...) then
      dims = torch.LongTensor(...)
   else
      dims = torch.LongTensor{...}
   end
   local pinned = torch.FloatStorage(cutorch.CudaHostAllocator, dims:prod())
   return torch.FloatTensor(pinned, 1, dims:storage())
end
-- remove this line to disable automatic cutorch heap-tracking
-- for garbage collection
cutorch.setHeapTracking(true)

-- Module value cached by require(); all public API lives on this table.
return cutorch