
Fix
co63oc committed Sep 26, 2023
1 parent fe2a39d commit a4e5982
Showing 5 changed files with 82 additions and 39 deletions.
19 changes: 15 additions & 4 deletions jointContribution/PIRBN/analytical_solution.py
@@ -10,7 +10,7 @@ def output_fig(train_obj, mu, b, right_by, activation_function):
plt.figure(figsize=(15, 9))
rbn = train_obj.pirbn.rbn

target_dir = os.path.join(os.path.dirname(__file__), "../target")
target_dir = os.path.join(os.path.dirname(__file__), "target")
if not os.path.exists(target_dir):
os.mkdir(target_dir)

@@ -21,7 +21,7 @@ def output_fig(train_obj, mu, b, right_by, activation_function):
xy = np.zeros((ns, 1)).astype(np.float32)
for i in range(0, ns):
xy[i, 0] = i * dx + right_by
y = rbn(paddle.to_tensor(xy), activation_function=activation_function)
y = rbn(paddle.to_tensor(xy))
y = y.numpy()
y_true = np.sin(2 * mu * np.pi * xy)
plt.plot(xy, y_true)
@@ -31,8 +31,9 @@ def output_fig(train_obj, mu, b, right_by, activation_function):

# Point-wise absolute error plot.
plt.subplot(2, 3, 2)
plt.plot(xy, np.abs(y_true - y))
plt.ylim(top=8e-3)
xy_y = np.abs(y_true - y)
plt.plot(xy, xy_y)
plt.ylim(top=np.max(xy_y))
plt.ylabel("Absolute Error")
plt.xlabel("x")

@@ -47,6 +48,16 @@ def output_fig(train_obj, mu, b, right_by, activation_function):
plt.ylabel("Loss")
plt.xlabel("Iteration")

# plt.subplot(2, 3, 3)
# loss_g = train_obj.his_a_g
# x = range(len(loss_g))
# plt.yscale("log")
# plt.plot(x, loss_g)
# plt.plot(x, train_obj.his_a_b)
# plt.legend(["a_g", "a_b"])
# plt.ylabel("Loss")
# plt.xlabel("Iteration")

# Visualise the NTK after initialisation: the normalised Kg at the 0th iteration.
plt.subplot(2, 3, 4)
jac = train_obj.ntk_list[0]
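
The commented-out subplot above points at a planned plot of the adaptive-factor history that train.py now records. A minimal sketch of what that panel would do once enabled, assuming train_obj.his_a_g and train_obj.his_a_b are lists of scalar factors (both attributes appear in the train.py diff below):

import matplotlib.pyplot as plt

def plot_adapt_factors(train_obj):
    # Plot the recorded adaptive factors a_g and a_b on a log scale,
    # mirroring the commented-out subplot in analytical_solution.py.
    a_g = [float(v) for v in train_obj.his_a_g]
    a_b = [float(v) for v in train_obj.his_a_b]
    steps = range(len(a_g))
    plt.yscale("log")
    plt.plot(steps, a_g)
    plt.plot(steps, a_b)
    plt.legend(["a_g", "a_b"])
    plt.ylabel("Adaptive factor")
    plt.xlabel("Record index")
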
10 changes: 5 additions & 5 deletions jointContribution/PIRBN/main.py
@@ -7,7 +7,7 @@

# mu, Fig.1, Page5
# right_by, Formula (15) Page5
def sine_function_main(mu, right_by=0, activation_function="gaussian_function"):
def sine_function_main(mu, right_by=0, activation_function="gaussian"):
# Define the number of sample points
ns = 51

@@ -31,7 +31,7 @@ def sine_function_main(mu, right_by=0, activation_function="gaussian_function"):
c = [right_by - 0.1, right_by + 1.1]

# Set up PIRBN
rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c)
rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c, activation_function)
train_obj = train.Trainer(
pirbn.PIRBN(rbn, activation_function), x, y, learning_rate=0.001, maxiter=20001
)
@@ -41,11 +41,11 @@ def sine_function_main(mu, right_by=0, activation_function="gaussian_function"):
analytical_solution.output_fig(train_obj, mu, b, right_by, activation_function)


# # Fig.1
# sine_function_main(mu=4, right_by=0, activation_function="tanh")
# Fig.1
sine_function_main(mu=4, right_by=0, activation_function="tanh")
# # Fig.2
# sine_function_main(mu=8, right_by=0, activation_function="tanh")
# # Fig.3
# sine_function_main(mu=4, right_by=100, activation_function="tanh")
# Fig.6
sine_function_main(mu=8, right_by=100, activation_function="gaussian_function")
sine_function_main(mu=8, right_by=100, activation_function="gaussian")
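
With this change the activation is chosen once, by name, when the network is built ("gaussian_function" is renamed to "gaussian"), and the forward passes no longer take an activation argument. A minimal sketch of the updated call pattern, assuming the constructor signatures shown in these diffs; the n_neu, b and c values are placeholders for illustration, since main.py derives them from mu and right_by in lines collapsed above:

import pirbn
import rbn_net

# Placeholder hyperparameters for illustration only.
n_in, n_out, n_neu = 1, 1, 61
b = 10.0
c = [-0.1, 1.1]  # matches c = [right_by - 0.1, right_by + 1.1] with right_by = 0

rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c, activation_function="tanh")
model = pirbn.PIRBN(rbn, activation_function="tanh")
# The trainer is then constructed as in sine_function_main:
# train_obj = train.Trainer(model, x, y, learning_rate=0.001, maxiter=20001)
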
19 changes: 9 additions & 10 deletions jointContribution/PIRBN/pirbn.py
@@ -2,20 +2,20 @@


class PIRBN(paddle.nn.Layer):
def __init__(self, rbn, activation_function="gaussian_function"):
def __init__(self, rbn, activation_function="gaussian"):
super().__init__()
self.rbn = rbn
self.activation_function = activation_function

def forward(self, input_data):
xy, xy_b = input_data
# initialize the differential operators
u_b = self.rbn(xy_b, self.activation_function)
u_b = self.rbn(xy_b)

# obtain partial derivatives of u with respect to x
xy.stop_gradient = False
# Obtain the output from the RBN
u = self.rbn(xy, self.activation_function)
u = self.rbn(xy)
# Obtain the first-order derivative of the output with respect to the input
u_x = paddle.grad(u, xy, retain_graph=True, create_graph=True)[0]
# Obtain the second-order derivative of the output with respect to the input
@@ -37,11 +37,12 @@ def cal_ntk(self, x):
y = self.forward(temp_x)
l1t = paddle.grad(y[0], self.parameters(), allow_unused=True)
for j in l1t:
if j is not None:
lambda_g = lambda_g + paddle.sum(j**2) / n1
lambda_g = lambda_g + paddle.sum(j**2) / n1
# When use tanh activation function, the value may be None
if l1t[0] is None and l1t[1] is not None:
temp = l1t[1].reshape((1, n_neu))
if self.activation_function == "tanh":
temp = paddle.concat(
(l1t[0], l1t[1], l1t[2].reshape((1, n_neu))), axis=1
)
else:
temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1)
if i == 0:
@@ -57,9 +58,7 @@ def cal_ntk(self, x):
y = self.forward(temp_x)
l1t = paddle.grad(y[1], self.rbn.parameters(), allow_unused=True)
for j in l1t:
# When use tanh activation function, the value may be None
if j is not None:
lambda_b = lambda_b + paddle.sum(j**2) / n2
lambda_b = lambda_b + paddle.sum(j**2) / n2

# calculate adapt factors
temp = lambda_g + lambda_b
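
The visible part of cal_ntk accumulates the squared parameter gradients of the governing-equation output into lambda_g and those of the boundary output into lambda_b, then begins combining them (temp = lambda_g + lambda_b) before the hunk ends. A hedged sketch of the NTK-style balancing this pattern usually implements; the exact expressions in the collapsed lines may differ:

def adapt_factors(lambda_g, lambda_b):
    # A common NTK-based choice: scale each loss term by the ratio of the total
    # gradient magnitude to its own, so both terms train at comparable rates.
    # This is an assumption about the collapsed lines, not the verified code.
    total = lambda_g + lambda_b
    a_g = total / lambda_g
    a_b = total / lambda_b
    return a_g, a_b
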
42 changes: 26 additions & 16 deletions jointContribution/PIRBN/rbn_net.py
@@ -15,15 +15,16 @@ class RBN_Net(paddle.nn.Layer):
c (List[float32]): Initial value for hyperparameter c.
"""

def __init__(self, n_in, n_out, n_neu, b, c):
def __init__(self, n_in, n_out, n_neu, b, c, activation_function="gaussian"):
super().__init__()
self.n_in = n_in
self.n_out = n_out
self.n_neu = n_neu
self.b = paddle.to_tensor(b)
self.c = paddle.to_tensor(c)
self.activation_function = activation_function

self.layer1 = RBF_layer1(self.n_neu, self.c, n_in)
self.layer1 = RBF_layer1(self.n_neu, self.c, n_in, activation_function)
# LeCun normal
std = math.sqrt(1 / self.n_neu)
self.linear = paddle.nn.Linear(
@@ -34,10 +35,12 @@ def __init__(self, n_in, n_out, n_neu, b, c):
),
bias_attr=False,
)
self.ini_ab()
# The gaussian activation_function requires self.b to be initialised
if self.activation_function == "gaussian":
self.ini_ab()

def forward(self, x, activation_function="gaussian_function"):
temp = self.layer1(x, activation_function)
def forward(self, x):
temp = self.layer1(x)
y = self.linear(temp)
return y

@@ -57,28 +60,34 @@ class RBF_layer1(paddle.nn.Layer):
input_shape_last (int): Last item of input shape.
"""

def __init__(self, n_neu, c, input_shape_last):
def __init__(self, n_neu, c, input_shape_last, activation_function="gaussian"):
super(RBF_layer1, self).__init__()
self.n_neu = n_neu
self.c = c
self.activation_function = activation_function
if self.activation_function == "tanh":
self.w = self.create_parameter(
shape=(input_shape_last, self.n_neu),
dtype=paddle.get_default_dtype(),
# Convert from tensorflow tf.random_normal_initializer(), default value mean=0.0, std=0.05
default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05),
)
self.b = self.create_parameter(
shape=(input_shape_last, self.n_neu),
dtype=paddle.get_default_dtype(),
# Convert from tensorflow tf.random_normal_initializer(), default value mean=0.0, std=0.05
default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05),
)

def forward(
self, inputs, activation_function="gaussian_function"
): # Defines the computation from inputs to outputs
temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu)))
if activation_function == "gaussian_function":
return self.gaussian_function(temp_x)
def forward(self, inputs):
if self.activation_function == "gaussian":
return self.gaussian_function(inputs)
else:
return self.tanh_function(temp_x)
return self.tanh_function(inputs)

# Gaussian function, Formula (19), Page 7
def gaussian_function(self, temp_x):
def gaussian_function(self, inputs):
temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu)))
x0 = (
paddle.reshape(
paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()),
@@ -92,5 +101,6 @@ def gaussian_function(self, temp_x):
s = self.b * self.b
return paddle.exp(-(x_new * x_new) * s)

def tanh_function(self, temp_x):
return paddle.tanh(temp_x)
def tanh_function(self, inputs):
outputs = paddle.add(paddle.matmul(inputs, self.w), self.b)
return paddle.tanh(outputs)
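
Read together with the collapsed middle of gaussian_function, each neuron applies a Gaussian radial basis centred on an equally spaced point between c[0] and c[1], with a width set by b; this is Formula (19) on page 7 of the paper. A reconstruction of that activation, hedged because the collapsed lines hide the exact centre spacing:

\phi_i(x) = \exp\!\left(- b^{2} \,(x - x_{0,i})^{2}\right),
\qquad
x_{0,i} = c_{0} + \frac{i\,(c_{1} - c_{0})}{n_{\mathrm{neu}} - 1},
\quad i = 0, \dots, n_{\mathrm{neu}} - 1

The tanh variant added in this commit is instead tanh(x W + b) with trainable W and b, as the new tanh_function shows.
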
31 changes: 27 additions & 4 deletions jointContribution/PIRBN/train.py
@@ -19,16 +19,27 @@ def __init__(self, pirbn, x_train, y_train, learning_rate=0.001, maxiter=10000):
self.iter = 0
self.a_g = paddle.to_tensor(1.0)
self.a_b = paddle.to_tensor(1.0)
self.his_a_g = []
self.his_a_b = []
self.optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=self.pirbn.parameters()
)
self.ntk_list = {}
# Whether to weight the loss using the NTK-based adaptive factors
self.update_loss_by_ntk = True

# For test
# if self.pirbn.activation_function == "tanh":
# self.update_loss_by_ntk = False

def Loss(self, x, y, a_g, a_b):
tmp = self.pirbn(x)
loss_g = 0.5 * paddle.mean(paddle.square(tmp[0] - y[0]))
loss_b = 0.5 * paddle.mean(paddle.square(tmp[1]))
loss = loss_g * a_g + loss_b * a_b
if self.update_loss_by_ntk:
loss = loss_g * a_g + loss_b * a_b
else:
loss = loss_g + loss_b
return loss, loss_g, loss_b

def evaluate(self):
@@ -41,9 +52,21 @@ def evaluate(self):
# boundary loss
self.loss_b.append(loss_b_numpy)
if self.iter % 200 == 0:
self.a_g, self.a_b, _ = self.pirbn.cal_ntk(self.x_train)
print("\ta_g =", float(self.a_g), "\ta_b =", float(self.a_b))
print("Iter: ", self.iter, "\tL1 =", loss_g_numpy, "\tL2 =", loss_b_numpy)
if self.update_loss_by_ntk:
self.a_g, self.a_b, _ = self.pirbn.cal_ntk(self.x_train)
print("\ta_g =", float(self.a_g), "\ta_b =", float(self.a_b))
print(
"Iter: ", self.iter, "\tL1 =", loss_g_numpy, "\tL2 =", loss_b_numpy
)
else:
a_g, a_b, _ = self.pirbn.cal_ntk(self.x_train)
print("\ta_g =", float(a_g), "\ta_b =", float(a_b))
print(
"Iter: ", self.iter, "\tL1 =", loss_g_numpy, "\tL2 =", loss_b_numpy
)
self.his_a_g.append(self.a_g)
self.his_a_b.append(self.a_b)

self.iter = self.iter + 1
return loss

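
Taken together, Trainer.evaluate() now refreshes a_g and a_b from cal_ntk every 200 iterations when update_loss_by_ntk is set (otherwise it only computes them for printing) and records the current factors in his_a_g / his_a_b. A condensed sketch of that cycle using only the attributes visible in this diff; the Adam step that drives it lives in the collapsed part of the file:

def evaluate_step(trainer, x, y):
    # Condensed view of Trainer.evaluate(): compute the weighted loss, and every
    # 200 iterations refresh and record the adaptive factors a_g and a_b.
    loss, loss_g, loss_b = trainer.Loss(x, y, trainer.a_g, trainer.a_b)
    if trainer.iter % 200 == 0:
        if trainer.update_loss_by_ntk:
            trainer.a_g, trainer.a_b, _ = trainer.pirbn.cal_ntk(x)
        trainer.his_a_g.append(trainer.a_g)
        trainer.his_a_b.append(trainer.a_b)
    trainer.iter += 1
    return loss
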
