Skip to content

Commit

Permalink
Fix
Browse files Browse the repository at this point in the history
  • Loading branch information
co63oc committed Sep 26, 2023
1 parent fb5a9bf commit be7cfc2
Show file tree
Hide file tree
Showing 2 changed files with 27 additions and 5 deletions.
10 changes: 10 additions & 0 deletions jointContribution/PIRBN/analytical_solution.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,16 @@ def output_fig(train_obj, mu, b, right_by, activation_function):
plt.ylabel("Loss")
plt.xlabel("Iteration")

# plt.subplot(2, 3, 3)
# loss_g = train_obj.his_a_g
# x = range(len(loss_g))
# plt.yscale("log")
# plt.plot(x, loss_g)
# plt.plot(x, train_obj.his_a_b)
# plt.legend(["a_g", "a_b"])
# plt.ylabel("Loss")
# plt.xlabel("Iteration")

# Visualise NTK after initialisation, The normalised Kg at 0th iteration.
plt.subplot(2, 3, 4)
jac = train_obj.ntk_list[0]
Expand Down
22 changes: 17 additions & 5 deletions jointContribution/PIRBN/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,19 +19,27 @@ def __init__(self, pirbn, x_train, y_train, learning_rate=0.001, maxiter=10000):
self.iter = 0
self.a_g = paddle.to_tensor(1.0)
self.a_b = paddle.to_tensor(1.0)
self.his_a_g = []
self.his_a_b = []
self.optimizer = paddle.optimizer.Adam(
learning_rate=0.001, parameters=self.pirbn.parameters()
)
self.ntk_list = {}
# Update loss by calculate ntk
self.update_loss_by_ntk = True

# For test
# if self.pirbn.activation_function == "tanh":
# self.update_loss_by_ntk = False

def Loss(self, x, y, a_g, a_b):
tmp = self.pirbn(x)
loss_g = 0.5 * paddle.mean(paddle.square(tmp[0] - y[0]))
loss_b = 0.5 * paddle.mean(paddle.square(tmp[1]))
if self.pirbn.activation_function == "tanh":
loss = loss_g + loss_b
else:
if self.update_loss_by_ntk:
loss = loss_g * a_g + loss_b * a_b
else:
loss = loss_g + loss_b
return loss, loss_g, loss_b

def evaluate(self):
Expand All @@ -44,16 +52,20 @@ def evaluate(self):
# boundary loss
self.loss_b.append(loss_b_numpy)
if self.iter % 200 == 0:
if self.pirbn.activation_function == "gaussian_function":
if self.update_loss_by_ntk:
self.a_g, self.a_b, _ = self.pirbn.cal_ntk(self.x_train)
print("\ta_g =", float(self.a_g), "\ta_b =", float(self.a_b))
print(
"Iter: ", self.iter, "\tL1 =", loss_g_numpy, "\tL2 =", loss_b_numpy
)
if self.pirbn.activation_function == "tanh":
else:
a_g, a_b, _ = self.pirbn.cal_ntk(self.x_train)
print("\ta_g =", float(a_g), "\ta_b =", float(a_b))
print(
"Iter: ", self.iter, "\tL1 =", loss_g_numpy, "\tL2 =", loss_b_numpy
)
self.his_a_g.append(self.a_g)
self.his_a_b.append(self.a_b)

self.iter = self.iter + 1
return loss
Expand Down

0 comments on commit be7cfc2

Please sign in to comment.