-
Notifications
You must be signed in to change notification settings - Fork 17
/
init.lua
121 lines (99 loc) · 2.69 KB
/
init.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
require 'torchx'
local _ = require 'moses'
require 'nn'
-- Probe for CUDA support: pcall's first return value is true only when the
-- 'cunn' package loads. Only that status is kept, so the `_` (moses) local
-- declared above is not shadowed by pcall's second return value.
local _cuda = pcall(require, 'cunn')
-- Create the global rnn table. It is deliberately global: the rest of the
-- package and its users refer to it by the name `rnn`.
rnn = {
   cuda = _cuda,
   version = 2.7, -- better support for bidirection RNNs
}
--- Require a package, printing an install hint if it cannot be loaded.
-- On failure, prints a message suggesting `luarocks install <packagename>`
-- and re-raises the original error from `require`.
-- @tparam string packagename name passed to `require`
-- @return the value returned by `require(packagename)` (the loaded module)
-- @raise if `packagename` is not a string, or if loading the package fails
function nn.require(packagename)
   assert(type(packagename) == 'string', 'nn.require: packagename must be a string')
   local ok, result = pcall(require, packagename)
   if not ok then
      print(string.format("missing package %s: run 'luarocks install %s'",
         packagename, packagename))
      -- Level 2 attributes the error to the caller of nn.require rather than
      -- to this helper's own line.
      error(result, 2)
   end
   return result
end
-- C library: best-effort load of librnn through the `paths` package.
-- Both attempts always run, and any failure is silently swallowed by pcall.
-- The original author noted "Not sure why this works..." about the duplicated
-- attempt, so exactly two calls are kept, in the same order.
require "paths"
for _attempt = 1, 2 do
   pcall(function() paths.require 'librnn' end)
end
-- Lua 5.2+ compat: fall back to table.unpack where the 5.1 global is absent.
unpack = unpack or table.unpack
require('rnn.utils')
-- extensions to existing nn.Module
require('rnn.Module')
require('rnn.Container')
require('rnn.Sequential')
require('rnn.ParallelTable')
require('rnn.LookupTable')
require('rnn.Dropout')
require('rnn.BatchNormalization')
-- extensions to existing nn.Criterion
require('rnn.Criterion')
-- modules
require('rnn.LookupTableMaskZero')
require('rnn.MaskZero')
require('rnn.ReverseSequence')
require('rnn.SpatialGlimpse')
require('rnn.ArgMax')
require('rnn.CategoricalEntropy')
require('rnn.TotalDropout')
require('rnn.SAdd')
require('rnn.CopyGrad')
require('rnn.VariableLength')
require('rnn.StepLSTM')
require('rnn.StepGRU')
require('rnn.ReverseUnreverse')
-- Noise Contrastive Estimation
require('rnn.NCEModule')
require('rnn.NCECriterion')
-- REINFORCE
require('rnn.Reinforce')
require('rnn.ReinforceGamma')
require('rnn.ReinforceBernoulli')
require('rnn.ReinforceNormal')
require('rnn.ReinforceCategorical')
-- REINFORCE criterions
require('rnn.VRClassReward')
require('rnn.BinaryClassReward')
-- for testing:
require('rnn.test')
require('rnn.bigtest')
-- recurrent modules
require('rnn.AbstractRecurrent')
require('rnn.Recursor')
require('rnn.Recurrence')
require('rnn.LinearRNN')
require('rnn.LookupRNN')
require('rnn.RecLSTM')
require('rnn.RecGRU')
require('rnn.GRU')
require('rnn.Mufuru')
require('rnn.NormStabilizer')
-- sequencer modules
require('rnn.AbstractSequencer')
require('rnn.Repeater')
require('rnn.Sequencer')
require('rnn.BiSequencer')
require('rnn.RecurrentAttention')
-- sequencer + recurrent modules
require('rnn.SeqLSTM')
require('rnn.SeqGRU')
require('rnn.SeqBLSTM')
require('rnn.SeqBGRU')
-- recurrent criterions:
require('rnn.AbstractSequencerCriterion')
require('rnn.SequencerCriterion')
require('rnn.RepeaterCriterion')
require('rnn.MaskZeroCriterion')
-- deprecated modules
require('rnn.LSTM')
require('rnn.FastLSTM')
require('rnn.SeqLSTMP')
require('rnn.SeqReverseSequence')
require('rnn.BiSequencerLM')
require('rnn.measure')
-- Also publish the package under the nn namespace as nn.rnn, so it can be
-- reached through nn (original intent: prevent likely name conflicts).
local pkg = rnn
nn.rnn = pkg
return pkg