-
Notifications
You must be signed in to change notification settings - Fork 0
/
kernel.h
176 lines (133 loc) · 3.63 KB
/
kernel.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
#ifndef _kernel_h_
#define _kernel_h_
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "ndarray.h"
#include "graph.h"
/// Base class for the per-node computation kernels.
///
/// Holds the node's value (`value_`), the input nodes it reads
/// (`inputs_`), and gradients keyed by input node (`gradients_`).
/// forward()/backward() are no-ops here; subclasses override them.
class Kernel {
public:
    virtual ~Kernel() = default;

    /// Computes the kernel's value. No-op in the base class.
    virtual void forward() { }

    /// Receives the gradient flowing into this kernel's output.
    /// No-op in the base class.
    // NOTE(review): subclasses presumably fill gradients_ here — confirm in
    // the implementation file.
    virtual void backward(const NDArray& /*output_grad*/) { }

    /// Returns a read-only reference to the stored value.
    const NDArray& get_value() const {
        return value_;
    }

    /// Replaces the input list with a copy of `inputs`.
    /// Taken by const reference so temporaries and const vectors are
    /// accepted; plain assignment is also safe if `inputs` is inputs_ itself.
    void set_inputs(const std::vector<NodeRef>& inputs) {
        inputs_ = inputs;
    }

    /// Returns a read-only reference to the input list.
    const std::vector<NodeRef>& get_inputs() const {
        return inputs_;
    }

    /// Returns the gradient recorded for `node` in gradients_.
    /// If none is recorded, returns a shared fallback constructed as
    /// NDArray({1}, {0}). A returned map entry reference is invalidated when
    /// that entry is erased (e.g. by clear_gradients()).
    const NDArray& get_gradient(const NodeRef& node) const {
        // Function-local static: built once on first use, shared by all
        // kernels, and only ever exposed as a const reference.
        static const NDArray default_grad({1}, {0});
        const auto it = gradients_.find(node);
        if (it != gradients_.end()) {
            return it->second;
        }
        return default_grad;
    }

    /// Removes every recorded gradient.
    void clear_gradients() {
        gradients_.clear();
    }

    /// Human-readable description of the kernel.
    virtual std::string str() const {
        return "kernel";
    }

protected:
    NDArray value_;
    std::vector<NodeRef> inputs_;
    std::unordered_map<NodeRef, NDArray> gradients_;
};
class ValueKernel : public Kernel {
public:
ValueKernel(const Shape& shape) {
// TODO
value_.zeros(shape.v());
}
void set_value(const NDArray& value) {
value_ = value;
}
virtual std::string str() const {
return value_.str();
}
};
class AddKernel : public Kernel {
public:
AddKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class SubKernel : public Kernel {
public:
SubKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class MulKernel : public Kernel {
public:
MulKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class DotKernel : public Kernel {
public:
DotKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class MatMulKernel : public Kernel {
public:
MatMulKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class BatchMatMulKernel : public Kernel {
public:
BatchMatMulKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
};
class SoftmaxKernel : public Kernel {
public:
SoftmaxKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
private:
NDArray derivative_;
};
class SoftmaxCrossEntropyKernel : public Kernel {
public:
SoftmaxCrossEntropyKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
private:
NDArray derivative_;
};
class ReLUKernel : public Kernel {
public:
ReLUKernel() = default;
virtual std::string str() const override;
protected:
virtual void forward() override;
virtual void backward(const NDArray& output_grad) override;
private:
NDArray derivative_;
};
#endif // _kernel_h_