-
Notifications
You must be signed in to change notification settings - Fork 5
/
Copy pathembedding.go
316 lines (272 loc) · 8.05 KB
/
embedding.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
package golgi
import (
"github.com/pkg/errors"
G "gorgonia.org/gorgonia"
"gorgonia.org/qol"
"gorgonia.org/tensor"
)
// AsRunner is a construction option that switches an *Embedding to runner mode: instead of
// selecting rows of its weight matrix by index, the layer fills an internal one-hot matrix
// from the values of its input (see (*Embedding).Run) and multiplies it with the weights.
//
// A Pass layer is returned unchanged; any other layer type yields an error.
func AsRunner() ConsOpt {
	return func(layer Layer) (Layer, error) {
		switch l := layer.(type) {
		case *Embedding:
			l.selectFn = runnerindices
			return layer, nil
		case Pass:
			return layer, nil
		default:
			return nil, errors.Errorf("AsRunner is a construction option that only supports embeddings so far. %T is not supported.", layer)
		}
	}
}
// WithClasses returns a construction option that sets how many classes an *Embedding has,
// i.e. the number of rows of the weight matrix that Init creates.
//
// A Pass layer is handed back untouched; every other layer type is rejected with an error.
func WithClasses(classes int) ConsOpt {
	return func(layer Layer) (Layer, error) {
		if emb, isEmbedding := layer.(*Embedding); isEmbedding {
			emb.classes = classes
			return layer, nil
		}
		if _, isPass := layer.(Pass); isPass {
			return layer, nil
		}
		return nil, errors.Errorf("WithClasses is a construction option that only supports Embedding. %T is not supported", layer)
	}
}
// WithOneHotInput returns a construction option that makes an *Embedding expect its input to
// already be a one-hot vector/matrix, which Fwd then multiplies directly with the weights.
//
// A Pass layer is handed back untouched; every other layer type is rejected with an error.
func WithOneHotInput() ConsOpt {
	return func(layer Layer) (Layer, error) {
		if emb, isEmbedding := layer.(*Embedding); isEmbedding {
			emb.selectFn = onehotindices
			return layer, nil
		}
		if _, isPass := layer.(Pass); isPass {
			return layer, nil
		}
		return nil, errors.Errorf("WithOneHotInput is a construction option that is only supported by Embedding. %T is not supported", layer)
	}
}
// selectionmethod enumerates the strategies an Embedding can use to pick rows of its weight matrix.
type selectionmethod int

// Selection strategies understood by (*Embedding).Fwd.
const (
	// byindices gathers rows of the weights directly by index. It is the zero value and hence the default.
	byindices selectionmethod = 0
	// onehotindices treats the input as an already one-hot encoded matrix and multiplies it with the weights.
	onehotindices selectionmethod = 1
	// runnerindices fills an internal one-hot matrix from the input's values (via Run) before multiplying.
	runnerindices selectionmethod = 2
)
// Embedding is a layer that represents an embedding layer.
//
// An embedding layer is essentially a matrix of the shape (classes, dimensions).
// While the Embedding layer can be done by means of using FC, this provides some ease of use.
//
// Specifically the Embedding layer supports forwarding of an input that is a slice of classes.
//
// Let us look at a word-based example, consider the vocab size to be the number of classes. The classical
// word embedding then is simply a (vocab, dims) matrix. Let's set the dims to be 50. For simplicity, let's set the vocab to 10.
// So that the embedding matrix W is a (10, 50) matrix, with rows labelled 1 to 10 below:
//
//	W := ⎡w1_1  ... w1_50 ⎤
//	     ⎢  ⋮         ⋮   ⎥
//	     ⎣w10_1 ... w10_50⎦
//
// To select a word vector, we simply slice the matrix. Word IDs are 0-based, so to get the vector of
// word ID 2 we slice W[2], which is the third row. This gives us a 50-dimension vector:
//
//	W[2] = [w3_1 ... w3_50]
//
// We can equally do this by multiplying the matrix with a one-hot vector. A vector O given as
//
//	O := [0 0 1 0 0 0 0 0 0 0]
//
// when multiplied against W, will yield the same result as W[2].
//
// The usual way of selecting from an embedding matrix with a one-hot vector is quite cumbersome. This struct makes it easy.
//
// You can pass in a *tensor.Dense of qol.Class (each class must be in [0, vocab)):
//
//	wv := tensor.New(tensor.WithBacking([]qol.Class{4, 9, 0, 0, 0, 0, 0, 0, 0, 0}))
//	words := gorgonia.NewVector(g, gorgonia.WithShape(10), gorgonia.WithValue(wv))
//	layer.Fwd(words)
//
// By default Fwd selects rows of W with gorgonia.ByIndices. An Embedding built with AsRunner instead
// transforms the slice of classes into a one-hot matrix (see Run) and multiplies it with W, while one
// built with WithOneHotInput expects the input to already be one-hot encoded.
type Embedding struct {
	// w is the weight matrix. When no option supplies it, Init creates it with shape (classes, dims).
	w *G.Node

	// internal computation stuff

	// oh is a one hot vector/matrix used to "select" from w. Init only creates it for runner
	// embeddings (shape (bs, classes)), and Run fills it from the input's values.
	oh *G.Node

	// config

	// selectFn is the kind of selection function Fwd uses.
	selectFn selectionmethod
	// bs is the batch size: the number of rows of oh. NewEmbedding raises values below 1 to 1.
	bs int
	// dims is the size of each embedding vector (the number of columns of w).
	dims int
	// classes is the number of classes (the number of rows of w and of columns of oh).
	classes int
	// name is the layer's name; Init also uses it as the name of the w node.
	name string
	// initialized records whether the layer's nodes are ready; Fwd calls Init lazily while it is false.
	initialized bool
	// of is the dtype of the nodes Init creates (NewEmbedding defaults it to tensor.Float64).
	of tensor.Dtype
	// computeFLOPs is whether to compute FLOPs. NOTE(review): not read by the code visible in this file.
	computeFLOPs bool
	// flops holds the computed FLOPs. NOTE(review): not written by the code visible in this file.
	flops int
}
// NewEmbedding creates a new embedding layer configured by opts.
//
// Defaults are dtype tensor.Float64 and batch size 1; a batch size below 1 is raised to 1.
// The layer is marked initialized only if the weight node was supplied by an option and,
// for a runner embedding (AsRunner), the one-hot node is present too. Otherwise Fwd
// initializes it lazily.
//
// NewEmbedding panics if an option returns an error, or returns a layer that is not an
// *Embedding.
func NewEmbedding(opts ...ConsOpt) *Embedding {
	retVal := &Embedding{
		of: tensor.Float64, // default
		bs: 1,
	}
	for _, opt := range opts {
		l, err := opt(retVal)
		if err != nil {
			panic(err)
		}
		// Previously a failed assertion silently left retVal nil and panicked later with a nil
		// dereference; fail immediately with a message that names the offending type instead.
		e, ok := l.(*Embedding)
		if !ok {
			panic(errors.Errorf("Construction option for an embedding layer returned a non *Embedding. Got %T instead", l))
		}
		retVal = e
	}
	// A runner embedding additionally needs its one-hot node before it can be used.
	if retVal.w != nil && (retVal.selectFn != runnerindices || retVal.oh != nil) {
		retVal.initialized = true
	}
	if retVal.bs < 1 {
		retVal.bs = 1
	}
	return retVal
}
// Model returns the gorgonia.Nodes that make up the embedding layer's model: only its weight matrix.
func (l *Embedding) Model() G.Nodes {
	nodes := G.Nodes{l.w}
	return nodes
}
// Fwd runs the forward pass of the embedding layer on a.
//
// Inputs with more than two dimensions are rejected. The layer is lazily initialized
// (see Init) on the first call. Rows of the weight matrix are selected according to
// the selection method:
//   - default: rows of w are gathered along axis 0 with G.ByIndices, using a as the indices;
//   - WithOneHotInput: a is multiplied with w directly;
//   - AsRunner: Run fills the internal one-hot matrix from a's values, and that matrix is
//     multiplied with w. For a 2-D input of shape (d0, d1), the product is reshaped to
//     (d0, d1, dims).
func (l *Embedding) Fwd(a G.Input) G.Result {
	if err := G.CheckOne(a); err != nil {
		return G.Err(errors.Wrapf(err, "Fwd of Embedding %v", l.name))
	}

	// Validate the shape before lazy initialization so a rejected input adds no nodes to the graph.
	shp := a.Node().Shape()
	if shp.Dims() > 2 {
		// error or reshape?
		// error for now:
		return G.Err(errors.Errorf("Cannot accept input of shape %v in Embedding", shp))
	}

	if !l.initialized {
		if err := l.Init(a.Node()); err != nil {
			return G.Err(errors.Wrapf(err, "lazy initialization of %v failed", l.name))
		}
		l.initialized = true
	}

	var oh *G.Node
	switch l.selectFn {
	case onehotindices:
		oh = a.Node()
	case runnerindices:
		if err := l.Run(a.Node()); err != nil {
			return G.Err(err)
		}
		oh = l.oh
	default:
		return G.LiftResult(G.ByIndices(l.w, a.Node(), 0))
	}

	retVal, err := G.Mul(oh, l.w)
	if err != nil {
		return G.Err(errors.Wrapf(err, "Fwd of Embedding %v - Mul error", l.name))
	}
	if l.selectFn == runnerindices && shp.Dims() == 2 {
		// The one-hot matrix flattens both input axes into rows; restore them.
		target := tensor.Shape{shp[0], shp[1], l.dims}
		if retVal, err = G.Reshape(retVal, target); err != nil {
			return G.Err(errors.Wrapf(err, "Failed to reshape retVal to %v", target))
		}
	}
	return retVal
}
// Name returns the name the embedding layer was configured with.
func (l *Embedding) Name() string {
	return l.name
}

// Describe is a no-op for Embedding.
func (l *Embedding) Describe() {
}

// IsInitialized reports whether the embedding layer has been marked initialized.
func (l *Embedding) IsInitialized() bool {
	return l.initialized
}
// Init initializes the embedding layer by creating, in the graph of xs[0], the nodes it still lacks:
//   - the weight matrix w, of shape (classes, dims), initialized with G.GlorotN(1);
//   - for a runner embedding (AsRunner), the zero-initialized one-hot matrix of shape (bs, classes).
//
// Nodes that already exist are left untouched, so a repeated call does not add duplicate nodes.
// Only xs[0] is used. An error is returned if no input node is given.
func (l *Embedding) Init(xs ...*G.Node) (err error) {
	if len(xs) == 0 || xs[0] == nil {
		return errors.Errorf("Init of Embedding %v requires an input node", l.name)
	}
	g := xs[0].Graph()
	of := l.of
	if l.w == nil {
		l.w = G.NewMatrix(g, of, G.WithShape(l.classes, l.dims), G.WithInit(G.GlorotN(1)), G.WithName(l.name))
	}
	if l.selectFn == runnerindices && l.oh == nil {
		// we need to construct the one-hot matrix as well
		l.oh = G.NewMatrix(g, of, G.WithShape(l.bs, l.classes), G.WithInit(G.Zeroes()), G.WithName(l.name+"dummy-1hot"))
	}
	return nil
}
// ConsEmbedding is a construction function to construct a *Embedding. This is typically used in a L() construction manner.
//
// It starts from the same defaults as NewEmbedding: dtype tensor.Float64 and a batch size of
// at least 1. It then applies opts and eagerly initializes the layer in the graph of in, so
// the returned layer is already marked initialized.
func ConsEmbedding(in G.Input, opts ...ConsOpt) (retVal Layer, err error) {
	// A bare new(Embedding) would leave a zero Dtype and a batch size of 0, neither of which Init can use.
	l := &Embedding{
		of: tensor.Float64,
		bs: 1,
	}
	for _, opt := range opts {
		var o Layer
		var ok bool
		if o, err = opt(l); err != nil {
			return nil, err
		}
		if l, ok = o.(*Embedding); !ok {
			return nil, errors.Errorf("Construction option for an embedding layer returned a non *Embedding. Got %T instead", o)
		}
	}
	if l.bs < 1 {
		l.bs = 1
	}
	if err := G.CheckOne(in); err != nil {
		return nil, errors.Wrapf(err, "Cons of an embedding layer %v", l.name)
	}
	if err = l.Init(in.Node()); err != nil {
		return nil, err
	}
	// Mark the layer initialized so Fwd does not run Init a second time.
	l.initialized = true
	return l, nil
}
// Graph returns the underlying computation graph. Embedding implements Grapher.
//
// The graph is taken from the weight node, so Graph returns nil while the layer has no
// weights yet (before Init), instead of panicking on a nil node.
func (l *Embedding) Graph() *G.ExprGraph {
	if l.w == nil {
		return nil
	}
	return l.w.Graph()
}
// Run is a function that sets the internal one hot vector/matrix from the class indices
// held in input's value.
//
// It is only valid for a runner embedding (AsRunner), and the layer must already be
// initialized so that the one-hot node exists. The input's value must be a *tensor.Dense
// backed by []qol.Class, []uint, []int, []int32, []int64, []float32 or []float64. Every
// element is converted to a qol.Class (floats are truncated) and must lie in [0, classes).
// Otherwise an error is returned and the one-hot matrix is left unchanged.
func (l *Embedding) Run(input G.Input) (err error) {
	if l.selectFn != runnerindices {
		return errors.Errorf("Cannot call Run on Embedding. The selection function is not a runnerindices.")
	}
	if err := G.CheckOne(input); err != nil {
		return errors.Wrapf(err, "Failed to run Embedding %v", l.name)
	}
	if l.oh == nil {
		return errors.Errorf("Cannot run Embedding %v: the layer has not been initialized", l.name)
	}

	a := input.Node()
	T, ok := a.Value().(*tensor.Dense)
	if !ok || T == nil {
		return errors.Errorf("Cannot run Embedding %v: expected the input's value to be a *tensor.Dense, got %T", l.name, a.Value())
	}
	classes, err := embeddingClasses(T.Data(), l.classes)
	if err != nil {
		return errors.Wrapf(err, "Failed to run Embedding %v", l.name)
	}

	// NOTE(review): if l.oh does not hold a *tensor.Dense yet, oh is nil and is passed on to qol
	// exactly as before; presumably qol then allocates a fresh matrix — confirm.
	oh, _ := l.oh.Value().(*tensor.Dense)
	return G.Let(l.oh, qol.UnsafeToOneHotMatrix(classes, uint(l.classes), oh))
}

// embeddingClasses converts the backing slice of an index tensor into qol.Class values,
// rejecting unsupported element types, negative or NaN values, and classes >= numClasses.
func embeddingClasses(data interface{}, numClasses int) ([]qol.Class, error) {
	var classes []qol.Class
	switch v := data.(type) {
	case []qol.Class:
		classes = v
	case []uint:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			classes[i] = qol.Class(x)
		}
	case []int:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			if x < 0 {
				return nil, errors.Errorf("negative class %d at position %d", x, i)
			}
			classes[i] = qol.Class(x)
		}
	case []int32:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			if x < 0 {
				return nil, errors.Errorf("negative class %d at position %d", x, i)
			}
			classes[i] = qol.Class(x)
		}
	case []int64:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			if x < 0 {
				return nil, errors.Errorf("negative class %d at position %d", x, i)
			}
			classes[i] = qol.Class(x)
		}
	case []float32:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			// !(x >= 0) also rejects NaN; converting such values to an unsigned class is implementation-defined.
			if !(x >= 0) {
				return nil, errors.Errorf("invalid class %v at position %d", x, i)
			}
			classes[i] = qol.Class(x)
		}
	case []float64:
		classes = make([]qol.Class, len(v))
		for i, x := range v {
			if !(x >= 0) {
				return nil, errors.Errorf("invalid class %v at position %d", x, i)
			}
			classes[i] = qol.Class(x)
		}
	default:
		return nil, errors.Errorf("unsupported class index type %T", data)
	}
	// A class outside [0, numClasses) has no column in the one-hot matrix.
	for i, c := range classes {
		if uint(c) >= uint(numClasses) {
			return nil, errors.Errorf("class %d at position %d is out of range for %d classes", c, i, numClasses)
		}
	}
	return classes, nil
}
// Runners returns the runners held by this layer. An Embedding is its own and only Runner,
// since its Run method is what refreshes the internal one-hot matrix.
func (l *Embedding) Runners() []Runner {
	runners := []Runner{l}
	return runners
}