forked from tinygrad/tinygrad
-
Notifications
You must be signed in to change notification settings - Fork 0
/
mlops.py
219 lines (173 loc) · 8.59 KB
/
mlops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
from typing import Tuple, Optional
from tinygrad.helpers import argsort, ShapeType
from tinygrad.ops import UnaryOps, BinaryOps, ReduceOps
from tinygrad.tensor import Function
from tinygrad.lazy import LazyBuffer
import math
class Contiguous(Function):
  """Call `contiguous()` on the input; the gradient passes through untouched."""

  def forward(self, x):
    contiguous_x = x.contiguous()
    return contiguous_x

  def backward(self, grad_output):
    # backward hands the incoming gradient back unchanged
    return grad_output
class Cast(Function):
  """Cast the input to `dtype`; backward casts the gradient to the input's dtype."""
  __slots__ = "input_dtype"

  def forward(self, x, dtype):
    # remember the source dtype so backward can cast the gradient back to it
    source_dtype = x.dtype
    self.input_dtype = source_dtype
    return x.cast(dtype)

  def backward(self, grad_output):
    restored = grad_output.cast(self.input_dtype)
    return restored
# ************* unary ops *************
class Sin(Function):
  """Elementwise sine; the derivative cos(x) is built as sin(pi/2 - x)."""
  __slots__ = "x"

  def forward(self, x: LazyBuffer) -> LazyBuffer:
    self.x = x  # kept for backward
    result = x.unary_op(UnaryOps.SIN)
    return result

  def backward(self, grad: LazyBuffer) -> LazyBuffer:
    saved = self.x
    half_pi = saved.const_like(math.pi / 2)
    # cos(x) == sin(pi/2 - x), so the derivative is expressed through UnaryOps.SIN
    cos_x = half_pi.binary_op(BinaryOps.SUB, saved).unary_op(UnaryOps.SIN)
    return cos_x.binary_op(BinaryOps.MUL, grad)
# NOTE: Relu is not the same as Maximum(x, 0): they differ in the gradient at x=0 (Maximum halves it on ties, Relu zeroes it)
class Relu(Function):
  """max(x, 0); the gradient is blocked wherever the saved output equals 0."""
  __slots__ = "ret"

  def forward(self, x:LazyBuffer) -> LazyBuffer:
    zero = x.const_like(0)
    self.ret = x.binary_op(BinaryOps.MAX, zero)  # output is reused by backward
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    out = self.ret
    one = out.const_like(1)
    is_zero = out.binary_op(BinaryOps.CMPEQ, out.const_like(0))
    # mask = 1 - (out == 0)
    mask = one.binary_op(BinaryOps.SUB, is_zero)
    return mask.binary_op(BinaryOps.MUL, grad_output)
class Log(Function):
  """Natural logarithm, computed as log2(x) * ln(2)."""
  __slots__ = "x"

  def forward(self, x:LazyBuffer) -> LazyBuffer:
    self.x = x  # kept for backward
    log2_x = x.unary_op(UnaryOps.LOG2)
    ln_two = x.const_like(math.log(2))
    return log2_x.binary_op(BinaryOps.MUL, ln_two)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # d/dx ln(x) = 1/x, so the gradient is grad_output / x
    saved = self.x
    return grad_output.binary_op(BinaryOps.DIV, saved)
class Exp(Function):
  """e**x, computed as exp2(x * (1/ln 2))."""
  __slots__ = "ret"

  def forward(self, x:LazyBuffer) -> LazyBuffer:
    inv_ln_two = x.const_like(1/math.log(2))
    scaled = x.binary_op(BinaryOps.MUL, inv_ln_two)
    self.ret = scaled.unary_op(UnaryOps.EXP2)  # output is reused by backward
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # the derivative of e**x is e**x, i.e. the saved output
    out = self.ret
    return out.binary_op(BinaryOps.MUL, grad_output)
class Sqrt(Function):
  """Elementwise square root; backward divides the gradient by 2*sqrt(x)."""
  __slots__ = "ret"

  def forward(self, x:LazyBuffer) -> LazyBuffer:
    root = x.unary_op(UnaryOps.SQRT)
    self.ret = root  # output is reused by backward
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    root = self.ret
    two = root.const_like(2)
    twice_root = root.binary_op(BinaryOps.MUL, two)
    return grad_output.binary_op(BinaryOps.DIV, twice_root)
# NOTE: the implicit derivative of sigmoid is not stable
# https://towardsdatascience.com/derivative-of-the-sigmoid-function-536880cf918e
# TODO: have the backend automatically find this
class Sigmoid(Function):
  """1 / (1 + e**-x), with e**-x computed as exp2(x * (-1/ln 2))."""
  __slots__ = "ret"

  def forward(self, x:LazyBuffer) -> LazyBuffer:
    numerator = x.const_like(1)
    one = x.const_like(1)
    neg_inv_ln_two = x.const_like(-1/math.log(2))
    exp_neg_x = x.binary_op(BinaryOps.MUL, neg_inv_ln_two).unary_op(UnaryOps.EXP2)
    denominator = one.binary_op(BinaryOps.ADD, exp_neg_x)
    self.ret = numerator.binary_op(BinaryOps.DIV, denominator)  # output is reused by backward
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # gradient is ret * (1 - ret) * grad_output, written in terms of the saved output
    sig = self.ret
    one_minus_sig = sig.const_like(1).binary_op(BinaryOps.SUB, sig)
    local_grad = sig.binary_op(BinaryOps.MUL, one_minus_sig)
    return local_grad.binary_op(BinaryOps.MUL, grad_output)
# ************* reduce ops *************
class Sum(Function):
  """Sum-reduce the input to `new_shape`; backward expands the gradient to the input shape."""
  __slots__ = "input_shape"

  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape  # shape to expand back to in backward
    reduced = x.reduce_op(ReduceOps.SUM, new_shape)
    return reduced

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    original_shape = self.input_shape
    return grad_output.expand(original_shape)
class Max(Function):
  """Max-reduce the input to `new_shape`; the gradient is shared equally among tied maxima."""
  __slots__ = "x", "ret"

  def forward(self, x:LazyBuffer, new_shape:ShapeType) -> LazyBuffer:
    reduced = x.reduce_op(ReduceOps.MAX, new_shape)
    self.x = x
    self.ret = reduced
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    source, reduced = self.x, self.ret
    # 1s wherever the input equals the reduced max (several positions can tie)
    hits = source.binary_op(BinaryOps.CMPEQ, reduced.expand(source.shape))
    # count the hits per reduced element, then expand the count back to the input shape
    hit_count = hits.reduce_op(ReduceOps.SUM, grad_output.shape).expand(source.shape)
    # each hit receives 1/count of the gradient
    share = hits.binary_op(BinaryOps.DIV, hit_count)
    expanded_grad = grad_output.expand(source.shape)
    return share.binary_op(BinaryOps.MUL, expanded_grad)
# ************* binary ops *************
class Equal(Function):
  """Elementwise comparison via BinaryOps.CMPEQ. No backward is defined in this class."""

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    is_equal = x.binary_op(BinaryOps.CMPEQ, y)
    return is_equal
class Maximum(Function):
  """Elementwise maximum of two buffers.

  backward routes the gradient to whichever input produced the maximum;
  where both inputs are equal, each receives half of it.
  """
  __slots__ = "x", "y", "ret"

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    self.x, self.y = x, y
    self.ret = x.binary_op(BinaryOps.MAX, y)
    return self.ret

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    """Return (grad_x, grad_y); an entry is None when that input needs no gradient."""
    # mask: 1 where y equals the output (y won, or the inputs tie)
    mask = self.y.binary_op(BinaryOps.CMPEQ, self.ret)
    # eq: 1 where the two inputs tie
    eq = self.x.binary_op(BinaryOps.CMPEQ, self.y)
    # splitter = (2 - eq) / 2: 1 normally, 0.5 on ties so the gradient is split between x and y
    splitter = eq.const_like(2).binary_op(BinaryOps.SUB, eq).binary_op(BinaryOps.DIV, eq.const_like(2))

    grad_x = None
    if self.needs_input_grad[0]:
      # (1 - mask) + eq: 1 where x equals the output (x won, or the inputs tie)
      x_selected = mask.const_like(1).binary_op(BinaryOps.SUB, mask).binary_op(BinaryOps.ADD, eq)
      grad_x = grad_output.binary_op(BinaryOps.MUL, x_selected).binary_op(BinaryOps.MUL, splitter)

    grad_y = None
    if self.needs_input_grad[1]:
      grad_y = grad_output.binary_op(BinaryOps.MUL, mask).binary_op(BinaryOps.MUL, splitter)

    return grad_x, grad_y
class Add(Function):
  """Elementwise x + y; each input that needs a gradient receives grad_output unchanged."""

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    total = x.binary_op(BinaryOps.ADD, y)
    return total

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # None marks an input that does not need a gradient
    grad_x = None
    if self.needs_input_grad[0]:
      grad_x = grad_output
    grad_y = None
    if self.needs_input_grad[1]:
      grad_y = grad_output
    return grad_x, grad_y
class Sub(Function):
  """Elementwise x - y; x receives grad_output and y receives its negation."""

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    difference = x.binary_op(BinaryOps.SUB, y)
    return difference

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # None marks an input that does not need a gradient
    grad_x = None
    if self.needs_input_grad[0]:
      grad_x = grad_output
    grad_y = None
    if self.needs_input_grad[1]:
      # negation is written as 0 - grad_output
      zero = grad_output.const_like(0)
      grad_y = zero.binary_op(BinaryOps.SUB, grad_output)
    return grad_x, grad_y
class Mul(Function):
  """Elementwise x * y; each input's gradient is the other input times grad_output."""
  __slots__ = 'x', 'y'

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    # both operands are kept for backward
    self.x, self.y = x, y
    product = x.binary_op(BinaryOps.MUL, y)
    return product

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # None marks an input that does not need a gradient
    grad_x = None
    if self.needs_input_grad[0]:
      grad_x = self.y.binary_op(BinaryOps.MUL, grad_output)
    grad_y = None
    if self.needs_input_grad[1]:
      grad_y = self.x.binary_op(BinaryOps.MUL, grad_output)
    return grad_x, grad_y
class Div(Function):
  """Elementwise x / y; gradients are grad/y for x and -grad*x/(y*y) for y."""
  __slots__ = 'x', 'y'

  def forward(self, x:LazyBuffer, y:LazyBuffer) -> LazyBuffer:
    # both operands are kept for backward
    self.x, self.y = x, y
    quotient = x.binary_op(BinaryOps.DIV, y)
    return quotient

  def backward(self, grad_output:LazyBuffer) -> Tuple[Optional[LazyBuffer], Optional[LazyBuffer]]:
    # None marks an input that does not need a gradient
    grad_x = None
    if self.needs_input_grad[0]:
      grad_x = grad_output.binary_op(BinaryOps.DIV, self.y)
    grad_y = None
    if self.needs_input_grad[1]:
      # negation is written as 0 - grad_output
      neg_grad = grad_output.const_like(0).binary_op(BinaryOps.SUB, grad_output)
      numerator = neg_grad.binary_op(BinaryOps.MUL, self.x)
      y_squared = self.y.binary_op(BinaryOps.MUL, self.y)
      grad_y = numerator.binary_op(BinaryOps.DIV, y_squared)
    return grad_x, grad_y
# ************* movement ops *************
# NOTE: Expand is Sum in reverse: forward expands, backward sum-reduces the gradient back to the input shape
class Expand(Function):
  """Expand the input to `shape`; backward sum-reduces the gradient to the input shape."""
  __slots__ = 'input_shape'

  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape  # shape to reduce back to in backward
    expanded = x.expand(shape)
    return expanded

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    original_shape = self.input_shape
    return grad_output.reduce_op(ReduceOps.SUM, original_shape)
class Reshape(Function):
  """Reshape the input to `shape`; backward reshapes the gradient to the input shape."""
  __slots__ = 'input_shape'

  def forward(self, x:LazyBuffer, shape:ShapeType) -> LazyBuffer:
    self.input_shape = x.shape  # shape to restore in backward
    return x.reshape(shape)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    return grad_output.reshape(self.input_shape)
class Permute(Function):
  """Permute the input's axes by `order`; backward permutes by argsort(order)."""
  __slots__ = 'input_order'

  def forward(self, x:LazyBuffer, order:Tuple[int, ...]) -> LazyBuffer:
    self.input_order = order  # kept so backward can derive its own permutation
    permuted = x.permute(order)
    return permuted

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # presumably argsort(order) is the inverse permutation -- confirm against helpers.argsort
    undo_order = argsort(self.input_order)
    return grad_output.permute(undo_order)
class Pad(Function):
  """Pad the input by `arg`; backward shrinks the gradient with the stored `narg`."""
  __slots__ = 'narg'

  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    # per dimension: (pad[0], size + pad[0]), computed from the unpadded shape
    self.narg = tuple((pad[0], dim + pad[0]) for dim, pad in zip(x.shape, arg))
    padded = x.pad(arg)
    return padded

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    shrink_arg = self.narg
    return grad_output.shrink(shrink_arg)
class Shrink(Function):
  """Shrink the input by `arg`; backward pads the gradient with the stored `narg`."""
  __slots__ = 'narg'

  def forward(self, x:LazyBuffer, arg:Tuple[Tuple[int, int], ...]) -> LazyBuffer:
    # per dimension: (bounds[0], size - bounds[1]), computed from the unshrunk shape
    self.narg = tuple((bounds[0], dim - bounds[1]) for dim, bounds in zip(x.shape, arg))
    shrunk = x.shrink(arg)
    return shrunk

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    pad_arg = self.narg
    return grad_output.pad(pad_arg)
class Flip(Function):
  """Apply a stride of -1 on the listed axes and 1 elsewhere; backward applies the same strides."""
  __slots__ = 'arg'

  def forward(self, x:LazyBuffer, axis:Tuple[int, ...]) -> LazyBuffer:
    # build the lookup set once instead of once per dimension
    flipped_axes = set(axis)
    # one stride per dimension: -1 for axes listed in `axis`, 1 for the rest
    self.arg = tuple(-1 if i in flipped_axes else 1 for i in range(len(x.shape)))
    return x.stride(self.arg)

  def backward(self, grad_output:LazyBuffer) -> LazyBuffer:
    # the gradient gets the same per-dimension strides as the forward pass
    return grad_output.stride(self.arg)