diff --git a/jointContribution/PIRBN/rbn_net.py b/jointContribution/PIRBN/rbn_net.py index d03e27f88..d148c5187 100644 --- a/jointContribution/PIRBN/rbn_net.py +++ b/jointContribution/PIRBN/rbn_net.py @@ -69,7 +69,6 @@ def __init__(self, n_neu, c, input_shape_last): ) def forward(self, inputs): # Defines the computation from inputs to outputs - s = self.b * self.b temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu))) x0 = ( paddle.reshape( @@ -80,5 +79,10 @@ def forward(self, inputs): # Defines the computation from inputs to outputs / (self.n_neu - 1) + self.c[0] ) - x_new = (temp_x - x0) * (temp_x - x0) - return paddle.exp(-x_new * s) + x_new = temp_x - x0 + return self.rbf_activate(x_new) + + # activation function + def rbf_activate(self, input): + s = self.b * self.b + return paddle.exp(-(input * input) * s)