-
Notifications
You must be signed in to change notification settings - Fork 0
/
node.py
55 lines (41 loc) · 1.63 KB
/
node.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
# Node class for graph
from ad_numpy import ndarray_
from decorators import primitive
from grad_fns import grad_fn_mapping
class Node:
    """A single node of the graph: one operation with its inputs and outputs.

    Attributes:
        inputs: dict with the keys ``"args"`` and ``"kwargs"`` holding the
            positional and keyword inputs of the operation (``None`` until
            ``make_node`` is called).
        outputs: whatever the operation produced.
        op: the operation; used as the lookup key into ``grad_fn_mapping``.
        name: label used when printing the node.
        grad_fn: entry of ``grad_fn_mapping`` registered for ``op``.
        grad: starts at ``0.0``; presumably the gradient accumulated for this
            node by code outside this class -- TODO confirm.
        grad_wrt_args, grad_wrt_kwargs, inputs_order: start as empty dicts and
            are filled by code outside this class.
    """

    def __init__(self):
        """Create an empty node; call ``make_node`` to populate it."""
        self.inputs = {"args": None, "kwargs": None}
        self.outputs = None
        self.op = None
        self.name = None
        self.grad_fn = None
        self.grad = 0.0
        self.grad_wrt_args = {}
        self.grad_wrt_kwargs = {}
        self.inputs_order = {}

    def make_node(self, *, args, kwargs, outputs, op, name):
        """Populate the node and bind the gradient function registered for ``op``.

        Args:
            args: positional inputs of the operation.
            kwargs: keyword inputs of the operation.
            outputs: result produced by the operation.
            op: the operation; must be a key of ``grad_fn_mapping``.
            name: label for the node.

        Raises:
            AssertionError: if ``grad_fn_mapping`` has no (non-``None``) entry
                for ``op``. The node's other fields are already set when this
                is raised.
        """
        self.inputs = {"args": args, "kwargs": kwargs}
        self.outputs = outputs
        self.op = op
        self.name = name

        # Assign the grad function mapping.
        grad_fn = grad_fn_mapping.get(self.op)
        if grad_fn is None:
            print("Grad function not implemented for ", self.op)
            # Raised explicitly rather than via ``assert`` so the check is not
            # stripped under ``python -O`` and the error carries a real message.
            raise AssertionError(
                "Grad function not implemented for " + str(self.op)
            )
        self.grad_fn = grad_fn

    def __str__(self):
        """Return a multi-line, human-readable dump of the node."""
        args = self.inputs["args"]
        kwargs = self.inputs["kwargs"]

        # ``str()`` keeps printing safe for nodes whose name is still None.
        lines = ["--- Node : " + str(self.name) + " --- "]
        if args is not None:
            lines.extend(" Arg : " + str(arg) for arg in args)
        if kwargs is not None:
            lines.extend(
                " KW : " + str(key) + str(value)
                for key, value in kwargs.items()
            )
        lines.append(" Outputs : " + str(self.outputs))
        # Label spelling kept as-is so the printed output does not change.
        lines.append(" Opeeration : " + str(self.op))
        lines.append(" Grad function : " + str(self.grad_fn))
        lines.append(" Grad : " + str(self.grad))
        lines.append(" Grad wrt args : " + str(self.grad_wrt_args))
        return "\n".join(lines) + "\n"