forked from dalab/hessian-rank
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy patharchitectures.py
77 lines (65 loc) · 2.83 KB
/
architectures.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import jax
from jax.experimental import stax
from jax.nn import leaky_relu
from utils import DenseNoBias
def fully_connected(units, classes, activation, init, bias=False):
"""
Implements a simple fully-connected neural network using the library stax based on jax.
:param units: list, list of hidden layer sizes, excluding output
:param classes: int, number of classes, i.e. outputs of the network
:param activation: str, non-linearity , one of 'linear', 'relu', 'tanh', 'elu', 'sigmoid', 'leaky_relu'
:param init: str, initialization scheme, one of 'orthogonal', 'uniform', 'glorot'
:param bias: bool, use bias in layers or not
:return: init_fn, apply_fn
"""
if activation is None or activation == 'linear':
if bias == False:
architecture = [DenseNoBias(i, W_init=init()) for i in units]
if classes == 2:
# If only two classes, we use one output and encode labels as -1,1
architecture += [DenseNoBias(1, W_init=init())]
else:
architecture += [DenseNoBias(classes, W_init=init())]
else:
architecture = [stax.Dense(i, W_init=init()) for i in units]
if classes == 2:
architecture += [stax.Dense(1, W_init=init())]
else:
architecture += [stax.Dense(classes, W_init=init())]
elif activation == 'relu':
architecture = []
for i in range(len(units)):
architecture += [DenseNoBias(units[i], W_init=init()), stax.Relu]
if classes == 2:
architecture += [DenseNoBias(1)]
else:
architecture += [DenseNoBias(classes)]
elif activation == 'tanh':
architecture = []
for i in range(len(units)):
architecture += [DenseNoBias(units[i], W_init=init()), stax.Tanh]
architecture += [DenseNoBias(classes)]
elif activation == 'elu':
architecture = []
for i in range(len(units)):
architecture += [DenseNoBias(units[i], W_init=init()), stax.Elu]
architecture += [DenseNoBias(classes)]
elif activation == 'sigmoid':
architecture = []
for i in range(len(units)):
architecture += [DenseNoBias(units[i], W_init=init()), stax.Sigmoid]
architecture += [DenseNoBias(classes)]
elif activation == 'leaky_relu':
def leaky_relu_fixed(x):
return leaky_relu(x, negative_slope=0.01)
architecture = []
for i in range(len(units)):
architecture += [DenseNoBias(units[i]), jax.experimental.stax.elementwise(leaky_relu_fixed)]
if classes == 2:
architecture += [DenseNoBias(1)]
else:
architecture += [DenseNoBias(classes)]
init_fn, apply_fn = stax.serial(
*architecture
)
return init_fn, apply_fn