From 913dd056bf1b86317791ba33ffca99b079801523 Mon Sep 17 00:00:00 2001 From: co63oc Date: Thu, 14 Sep 2023 07:44:19 +0800 Subject: [PATCH] Fix --- .../PIRBN/1D_nonlinear_spring/Cal_jac.py | 50 ----- .../PIRBN/1D_nonlinear_spring/Dif_op.py | 34 ---- .../PIRBN/1D_nonlinear_spring/Main.py | 59 ------ .../PIRBN/1D_nonlinear_spring/OPT.py | 120 ------------ .../PIRBN/1D_nonlinear_spring/PIRBN.py | 29 --- jointContribution/PIRBN/README.md | 39 +--- .../PIRBN/analytical_solution.py | 72 +++++++ jointContribution/PIRBN/main.py | 38 ++++ jointContribution/PIRBN/pirbn.py | 85 ++++++++ .../{1D_nonlinear_spring => }/rbn_net.py | 185 +++++++++--------- jointContribution/PIRBN/requirements.txt | 7 - jointContribution/PIRBN/train.py | 57 ++++++ 12 files changed, 343 insertions(+), 432 deletions(-) delete mode 100644 jointContribution/PIRBN/1D_nonlinear_spring/Cal_jac.py delete mode 100644 jointContribution/PIRBN/1D_nonlinear_spring/Dif_op.py delete mode 100644 jointContribution/PIRBN/1D_nonlinear_spring/Main.py delete mode 100644 jointContribution/PIRBN/1D_nonlinear_spring/OPT.py delete mode 100644 jointContribution/PIRBN/1D_nonlinear_spring/PIRBN.py create mode 100644 jointContribution/PIRBN/analytical_solution.py create mode 100644 jointContribution/PIRBN/main.py create mode 100644 jointContribution/PIRBN/pirbn.py rename jointContribution/PIRBN/{1D_nonlinear_spring => }/rbn_net.py (85%) delete mode 100644 jointContribution/PIRBN/requirements.txt create mode 100644 jointContribution/PIRBN/train.py diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/Cal_jac.py b/jointContribution/PIRBN/1D_nonlinear_spring/Cal_jac.py deleted file mode 100644 index 35556e3bb..000000000 --- a/jointContribution/PIRBN/1D_nonlinear_spring/Cal_jac.py +++ /dev/null @@ -1,50 +0,0 @@ -import paddle - - -def cal_adapt(pirbn, x): - lamda_g = 0.0 - lamda_b1 = 0.0 - lamda_b2 = 0.0 - n_neu = len(pirbn.get_weights()[1]) - - ### in-domain - n1 = x[0].shape[0] - for i in range(n1): - temp_x = [x[0][i, ...].unsqueeze(0), paddle.to_tensor([[0.0]])] - temp_x[0].stop_gradient = False - temp_x[1].stop_gradient = False - y = pirbn(temp_x) - l1t = paddle.grad( - y[0], pirbn.parameters(), retain_graph=True, create_graph=True - ) - for j in l1t: - lamda_g = lamda_g + paddle.sum(j**2) / n1 - temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1) - if i == 0: - jac = temp - else: - jac = paddle.concat((jac, temp), axis=0) - ### bound - n2 = x[1].shape[0] - for i in range(n2): - temp_x = [paddle.to_tensor([[0.0]]), x[1][i, ...].unsqueeze(0)] - temp_x[0].stop_gradient = False - temp_x[1].stop_gradient = False - y = pirbn(temp_x) - l1t = paddle.grad( - y[1], pirbn.parameters(), retain_graph=True, create_graph=True - ) - l2t = paddle.grad( - y[2], pirbn.parameters(), retain_graph=True, create_graph=True - ) - for j in l1t: - lamda_b1 = lamda_b1 + paddle.sum(j**2) / n2 - for j in l2t: - lamda_b2 = lamda_b2 + paddle.sum(j**2) / n2 - ### calculate adapt factors - temp = lamda_g + lamda_b1 + lamda_b2 - lamda_g = temp / lamda_g - lamda_b1 = temp / lamda_b1 - lamda_b2 = temp / lamda_b2 - - return lamda_g, lamda_b1, lamda_b2, jac diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/Dif_op.py b/jointContribution/PIRBN/1D_nonlinear_spring/Dif_op.py deleted file mode 100644 index d90712f99..000000000 --- a/jointContribution/PIRBN/1D_nonlinear_spring/Dif_op.py +++ /dev/null @@ -1,34 +0,0 @@ -import paddle - - -class Dif(paddle.nn.Layer): - """This function is to initialise for differential operator. 
- - Args: - rbn (model): The Feedforward Neural Network. - """ - - def __init__(self, rbn, **kwargs): - super().__init__(**kwargs) - self.rbn = rbn - - def forward(self, x): - """This function is to calculate the differential terms. - - Args: - x (Tensor): The coordinate array - - Returns: - Tuple[Tensor, Tensor]: The first-order derivative of the u with respect to the x; The second-order derivative of the u with respect to the x. - """ - - x.stop_gradient = False - ### Apply the GradientTape function - ### Obtain the output from the RBN - u = self.rbn(x) - ### Obtain the first-order derivative of the output with respect to the input - u_x = paddle.grad(u, x, retain_graph=True, create_graph=True)[0] - - ### Obtain the second-order derivative of the output with respect to the input - u_xx = paddle.grad(u_x, x, retain_graph=True, create_graph=True)[0] - return u_x, u_xx diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/Main.py b/jointContribution/PIRBN/1D_nonlinear_spring/Main.py deleted file mode 100644 index 8653bb4cd..000000000 --- a/jointContribution/PIRBN/1D_nonlinear_spring/Main.py +++ /dev/null @@ -1,59 +0,0 @@ -import os - -import matplotlib.pyplot as plt -import numpy as np -import OPT -import paddle -import PIRBN -import rbn_net -import scipy.io - -### Define the number of sample points -ns = 1001 - -### Define the sample points' interval -dx = 100.0 / (ns - 1) - -### Initialise sample points' coordinates -xy = np.zeros((ns, 1)).astype(np.float32) -for i in range(0, ns): - xy[i, 0] = i * dx -xy_b = np.array([[0.0]]) - -x = [xy, xy_b] -y = [2 * np.cos(xy) + 3 * xy * np.sin(xy) + np.sin(xy * np.sin(xy))] - -### Set up radial basis network -n_in = 1 -n_out = 1 -n_neu = 1021 -b = 1.0 -c = [-1.0, 101.0] - -rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c) - -### Set up PIRBN -pirbn = PIRBN.PIRBN(rbn) - -### Train the PIRBN -opt = OPT.Adam(pirbn, x, y, learning_rate=0.001, maxiter=401) -result = opt.fit() - -### Visualise results -ns = 1001 -dx = 100 / (ns - 1) -xy = np.zeros((ns, 1)).astype(np.float32) -for i in range(0, ns): - xy[i, 0] = i * dx -y = rbn(paddle.to_tensor(xy)) -y = y.numpy() -plt.plot(xy, y) -plt.plot(xy, xy * np.sin(xy)) -plt.legend(["predict", "ground truth"]) -target_dir = os.path.join(os.path.dirname(__file__), "/../target") -if not os.path.exists(target_dir): - os.path.mkdir(target_dir) -plt.savefig(os.path.join(target_dir, "1D_nonlinear_spring.png")) - -### Save data -scipy.io.savemat(os.path.join(target_dir, "1D_nonlinear_spring.mat"), {"x": xy, "y": y}) diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/OPT.py b/jointContribution/PIRBN/1D_nonlinear_spring/OPT.py deleted file mode 100644 index 15fcd1352..000000000 --- a/jointContribution/PIRBN/1D_nonlinear_spring/OPT.py +++ /dev/null @@ -1,120 +0,0 @@ -import Cal_jac -import numpy as np -import paddle - -paddle.framework.core.set_prim_eager_enabled(True) - - -class Adam: - def __init__(self, pirbn, x_train, y_train, learning_rate=0.001, maxiter=10000): - # set attributes - self.pirbn = pirbn - self.learning_rate = learning_rate - self.x_train = [ - paddle.to_tensor(x, dtype=paddle.get_default_dtype()) for x in x_train - ] - self.y_train = [ - paddle.to_tensor(y, dtype=paddle.get_default_dtype()) for y in y_train - ] - self.maxiter = maxiter - self.his_l1 = [] - self.his_l2 = [] - self.his_l3 = [] - self.iter = 0 - self.a_g = 1.0 - self.a_b1 = 1.0 - self.a_b2 = 1.0 - - def set_weights(self, flat_weights): - # get model weights - shapes = [w.shape for w in self.pirbn.get_weights()] - # compute splitting 
indices - split_ids = np.cumsum([np.prod(shape) for shape in [0] + shapes]) - # reshape weights - weights = [ - flat_weights[from_id:to_id] - .reshape(shape) - .astype(paddle.get_default_dtype()) - for from_id, to_id, shape in zip(split_ids[:-1], split_ids[1:], shapes) - ] - # set weights to the model - self.pirbn.set_weights(weights) - - def Loss(self, x, y, a_g, a_b1, a_b2): - tmp = self.pirbn(x) - l1 = 0.5 * paddle.mean(paddle.square(tmp[0] - y[0])) - l2 = 0.5 * paddle.mean(paddle.square(tmp[1])) - l3 = 0.5 * paddle.mean(paddle.square(tmp[2])) - loss = l1 * a_g + l2 * a_b1 + l3 * a_b2 - grads = paddle.grad( - loss, self.pirbn.parameters(), retain_graph=True, create_graph=True - ) - return loss, grads, l1, l2, l3 - - def evaluate(self, weights): - weights = paddle.to_tensor(weights) - # update weights - self.set_weights(weights) - # compute loss and gradients for weights - loss, grads, l1, l2, l3 = self.Loss( - self.x_train, self.y_train, self.a_g, self.a_b1, self.a_b2 - ) - l1_numpy = float(l1) - l2_numpy = float(l2) - l3_numpy = float(l3) - self.his_l1.append(l1_numpy) - self.his_l2.append(l2_numpy) - self.his_l3.append(l3_numpy) - if self.iter % 200 == 0: - self.a_g, self.a_b1, self.a_b2, _ = Cal_jac.cal_adapt( - self.pirbn, self.x_train - ) - print( - "\ta_g =", - float(self.a_g), - "\ta_b1 =", - float(self.a_b1), - "\ta_b2 =", - float(self.a_b2), - ) - print( - "Iter: ", - self.iter, - "\tL1 =", - l1_numpy, - "\tL2 =", - l2_numpy, - "\tL3 =", - l3_numpy, - ) - self.iter = self.iter + 1 - # convert tf.Tensor to flatten ndarray - loss = loss.astype("float64").item() - grads = np.concatenate([g.numpy().flatten() for g in grads]).astype("float64") - return loss, grads - - def fit(self): - # get initial weights as a flat vector - initial_weights = np.concatenate( - [w.numpy().flatten() for w in self.pirbn.get_weights()] - ) - print(f"Optimizer: Adam (maxiter={self.maxiter})") - beta1 = 0.9 - beta2 = 0.999 - learning_rate = self.learning_rate - eps = 1e-8 - x0 = initial_weights - x = x0 - m = np.zeros_like(x) - v = np.zeros_like(x) - b_w = 0 - - for i in range(0, self.maxiter): - loss, g = self.evaluate(x) - m = (1 - beta1) * g + beta1 * m - v = (1 - beta2) * (g**2) + beta2 * v # second moment estimate. - mhat = m / (1 - beta1 ** (i + 1)) # bias correction. 
-            vhat = v / (1 - beta2 ** (i + 1))
-            x = x - learning_rate * mhat / (np.sqrt(vhat) + eps)
-
-        return loss, [self.his_l1, self.his_l2], b_w
diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/PIRBN.py b/jointContribution/PIRBN/1D_nonlinear_spring/PIRBN.py
deleted file mode 100644
index 77e6863f3..000000000
--- a/jointContribution/PIRBN/1D_nonlinear_spring/PIRBN.py
+++ /dev/null
@@ -1,29 +0,0 @@
-import paddle
-from Dif_op import Dif
-
-
-class PIRBN(paddle.nn.Layer):
-    def __init__(self, rbn):
-        super().__init__()
-        self.rbn = rbn
-
-    def forward(self, input_data):
-        xy, xy_b = input_data
-        ### initialize the differential operators
-        Dif_u = Dif(self.rbn)
-        u = self.rbn(xy)
-        u_b = self.rbn(xy_b)
-
-        ### obtain partial derivatives of u with respect to x
-        _, u_xx = Dif_u(xy)
-        u_b_x, _ = Dif_u(xy_b)
-        t = u_xx + 4 * u + paddle.sin(u)
-
-        ### build up the PIRBN
-        return [t, u_b, u_b_x]
-
-    def get_weights(self):
-        return self.rbn.get_weights()
-
-    def set_weights(self, weights):
-        self.rbn.set_weights(weights)
diff --git a/jointContribution/PIRBN/README.md b/jointContribution/PIRBN/README.md
index 148c233e3..ccd186ba7 100644
--- a/jointContribution/PIRBN/README.md
+++ b/jointContribution/PIRBN/README.md
@@ -13,44 +13,11 @@ Inspired by findings, we proposed the PIRBN, which can exhibit the local propert
 
 Numerical examples include:
 
-  - 1D sine funtion (**Eq. 15** in the manuscript)
+  - 1D sine function (**Eq. 1** in the manuscript)
 
-    **PDE**: $\frac{\partial^2 }{\partial x^2}u(x-100)-4\mu^2\pi^2 sin(2\mu\pi(x-100))=0, x\in[100,101]$
+    **PDE**: $\frac{\partial^2 }{\partial x^2}u(x)-4\mu^2\pi^2 \sin(2\mu\pi x)=0, x\in[0,1]$
 
-    **BC**: $u(100)=u(101)=0.$
-
-  - 1D sine function coupling problem (**Eq. 30** in the manuscript)
-
-    **PDE**: $\frac{\partial^2 }{\partial x^2}u(x)=f(x), x\in[20,22]$
-
-    **BC**: $u(20)=u(22)=0.$
-
-  - 1D nonlinear spring equation (**Eq. 31** in the manuscript)
-
-    **PDE**: $\frac{\partial^2 }{\partial x^2}u(x)+4u(x)+sin[u(x)]=f(x), x\in[0,100]$
-
-    **BC**: $u(0)=\frac{\partial }{\partial x}u(0)=0.$
-
-  - 2D wave equation (**Eq. 33** in the manuscript)
-
-    **PDE**: $(\frac{\partial^2 }{\partial x^2}+4\frac{\partial^2 }{\partial y^2})u(x,y)=0, x\in[0,1], y\in[0,1]$
-
-    **BC**: $u(x,0)=u(x,1)=\frac{\partial }{\partial x}u(0,y)=0,$
-    $u(0,y)=sin(\pi y)+0.5sin(4\pi y).$
-
-  - 2D diffusion equation (**Eq. 35** in the manuscript)
-
-    **PDE**: $(\frac{\partial}{\partial t}-0.01\frac{\partial^2 }{\partial x^2})u(x,t)=g(x,t), x\in[5,10], y\in[5,10]$
-
-    **BC\IC**: $u(5,t)=b_1(t),u(10,t)=b_2(t),u(x,5)=b_3(x).$
-
-  - 2D viscoelastic Poiseuille problem (**Eq. 37** in the manuscript)
-
-    **PDEs**: $\rho\frac{\partial}{\partial t}u(y,t)=-f+\frac{\partial}{\partial y}\tau_{xy}(y,t), t\in[0,4],$
-    $\eta_0\frac{\partial}{\partial y}u(y,t)=(\lambda\frac{\partial}{\partial t}+1)\tau_{xy}(y,t), y\in[0,1],$
-
-    **BC\IC**: $u(\pm0.5,t)=u(y,0)=0,$
-    $\tau(y,0)=0.$
+    **BC**: $u(0)=u(1)=0.$
 
 For more details in terms of mathematical proofs and numerical examples, please refer to our paper.
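For reference, the refactored example below is internally consistent: the new main.py defines the training target $f(x)=-4\mu^2\pi^2\sin(2\mu\pi x)$ and trains the network so that $\frac{\partial^2 }{\partial x^2}u(x)=f(x)$ with $u(0)=u(1)=0$. Integrating twice gives $u(x)=\sin(2\mu\pi x)+Ax+B$, and both boundary conditions force $A=B=0$ for integer $\mu$, so the closed-form solution is $u(x)=\sin(2\mu\pi x)$, which is exactly the ground truth `np.sin(2 * mu * np.pi * xy)` plotted in analytical_solution.py.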
diff --git a/jointContribution/PIRBN/analytical_solution.py b/jointContribution/PIRBN/analytical_solution.py
new file mode 100644
index 000000000..7bc9ff9a7
--- /dev/null
+++ b/jointContribution/PIRBN/analytical_solution.py
@@ -0,0 +1,72 @@
+import os
+
+import matplotlib.pyplot as plt
+import numpy as np
+import paddle
+import scipy.io
+
+
+def output_fig(train_obj, mu, b):
+    plt.figure(figsize=(15, 9))
+    rbn = train_obj.pirbn.rbn
+
+    target_dir = os.path.join(os.path.dirname(__file__), "../target")
+    if not os.path.exists(target_dir):
+        os.mkdir(target_dir)
+
+    # Comparisons between the PINN predictions and the ground truth.
+    plt.subplot(2, 3, 1)
+    ns = 1001
+    dx = 1 / (ns - 1)
+    xy = np.zeros((ns, 1)).astype(np.float32)
+    for i in range(0, ns):
+        xy[i, 0] = i * dx
+    y = rbn(paddle.to_tensor(xy))
+    y = y.numpy()
+    y_true = np.sin(2 * mu * np.pi * xy)
+    plt.plot(xy, y_true)
+    plt.plot(xy, y, linestyle="--")
+    plt.legend(["ground truth", "PINN"])
+
+    # Point-wise absolute error plot.
+    plt.subplot(2, 3, 2)
+    plt.plot(xy, np.abs(y_true - y))
+    plt.ylim(top=1e-3)
+    plt.legend(["Absolute Error"])
+
+    # Loss history of the PINN during the training process.
+    plt.subplot(2, 3, 3)
+    his_l1 = train_obj.his_l1
+    x = range(len(his_l1))
+    plt.yscale("log")
+    plt.plot(x, his_l1)
+    plt.plot(x, train_obj.his_l2)
+    plt.legend(["Lg", "Lb"])
+
+    # Visualise the NTK after initialisation: the normalised Kg at the 0th iteration.
+    plt.subplot(2, 3, 4)
+    jac = train_obj.ntk_list[0]
+    a = np.dot(jac, np.transpose(jac))
+    plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
+    plt.colorbar()
+
+    # Visualise the NTK during training: the normalised Kg at the 2000th iteration.
+    plt.subplot(2, 3, 5)
+    if 2000 in train_obj.ntk_list:
+        jac = train_obj.ntk_list[2000]
+        a = np.dot(jac, np.transpose(jac))
+        plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
+        plt.colorbar()
+
+    # The normalised Kg at the 20000th iteration.
+    plt.subplot(2, 3, 6)
+    if 20000 in train_obj.ntk_list:
+        jac = train_obj.ntk_list[20000]
+        a = np.dot(jac, np.transpose(jac))
+        plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
+        plt.colorbar()
+
+    plt.savefig(os.path.join(target_dir, f"sine_function_{mu}_{b}.png"))
+
+    # Save data
+    scipy.io.savemat(os.path.join(target_dir, "out.mat"), {"NTK": a, "x": xy, "y": y})
diff --git a/jointContribution/PIRBN/main.py b/jointContribution/PIRBN/main.py
new file mode 100644
index 000000000..e453f58f3
--- /dev/null
+++ b/jointContribution/PIRBN/main.py
@@ -0,0 +1,38 @@
+import analytical_solution
+import numpy as np
+import pirbn
+import rbn_net
+import train
+
+# Define mu
+mu = 4
+
+# Define the number of sample points
+ns = 51
+
+# Define the sample points' interval
+dx = 1.0 / (ns - 1)
+
+# Initialise sample points' coordinates
+xy = np.zeros((ns, 1)).astype(np.float32)
+for i in range(0, ns):
+    xy[i, 0] = i * dx
+xy_b = np.array([[0.0], [1.0]])
+
+x = [xy, xy_b]
+y = [-4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * xy)]
+
+# Set up radial basis network
+n_in = 1
+n_out = 1
+n_neu = 61
+b = 10.0
+c = [-0.1, 1.1]
+
+# Set up PIRBN
+rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c)
+train_obj = train.Trainer(pirbn.PIRBN(rbn), x, y, learning_rate=0.001, maxiter=20001)
+train_obj.fit()
+
+# Visualise results
+analytical_solution.output_fig(train_obj, mu, b)
diff --git a/jointContribution/PIRBN/pirbn.py b/jointContribution/PIRBN/pirbn.py
new file mode 100644
index 000000000..02f5c86ee
--- /dev/null
+++ b/jointContribution/PIRBN/pirbn.py
@@ -0,0 +1,85 @@
+import paddle
+
+
+class Dif(paddle.nn.Layer):
+    """Differential operator that computes derivatives of a network's output.
+
+    Args:
+        rbn (paddle.nn.Layer): The radial basis network (RBN) to be differentiated.
+    """
+
+    def __init__(self, rbn, **kwargs):
+        super().__init__(**kwargs)
+        self.rbn = rbn
+
+    def forward(self, x):
+        """Calculate the differential terms of the network output.
+
+        Args:
+            x (Tensor): The coordinate array
+
+        Returns:
+            Tuple[Tensor, Tensor]: The first-order and second-order derivatives of u with respect to x.
+        
+ """ + x.stop_gradient = False + # Obtain the output from the RBN + u = self.rbn(x) + # Obtain the first-order derivative of the output with respect to the input + u_x = paddle.grad(u, x, retain_graph=True, create_graph=True)[0] + # Obtain the second-order derivative of the output with respect to the input + u_xx = paddle.grad(u_x, x, retain_graph=True, create_graph=True)[0] + return u_x, u_xx + + +class PIRBN(paddle.nn.Layer): + def __init__(self, rbn): + super().__init__() + self.rbn = rbn + + def forward(self, input_data): + xy, xy_b = input_data + # initialize the differential operators + Dif_u = Dif(self.rbn) + u_b = self.rbn(xy_b) + + # obtain partial derivatives of u with respect to x + _, u_xx = Dif_u(xy) + + return [u_xx, u_b] + + def cal_ntk(self, x): + # Formula (4), Page5, \gamma variable + gamma_g = 0.0 + gamma_b = 0.0 + n_neu = self.rbn.n_neu + + # in-domain + n1 = x[0].shape[0] + for i in range(n1): + temp_x = [x[0][i, ...].unsqueeze(0), paddle.to_tensor([[0.0]])] + y = self.forward(temp_x) + l1t = paddle.grad(y[0], self.parameters()) + for j in l1t: + gamma_g = gamma_g + paddle.sum(j**2) / n1 + temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1) + if i == 0: + # Fig.1, Page8, Kg variable + Kg = temp + else: + Kg = paddle.concat((Kg, temp), axis=0) + + # bound + n2 = x[1].shape[0] + for i in range(n2): + temp_x = [paddle.to_tensor([[0.0]]), x[1][i, ...].unsqueeze(0)] + y = self.forward(temp_x) + l1t = paddle.grad(y[1], self.parameters()) + for j in l1t: + gamma_b = gamma_b + paddle.sum(j**2) / n2 + + # calculate adapt factors + temp = gamma_g + gamma_b + gamma_g = temp / gamma_g + gamma_b = temp / gamma_b + + return gamma_g, gamma_b, Kg diff --git a/jointContribution/PIRBN/1D_nonlinear_spring/rbn_net.py b/jointContribution/PIRBN/rbn_net.py similarity index 85% rename from jointContribution/PIRBN/1D_nonlinear_spring/rbn_net.py rename to jointContribution/PIRBN/rbn_net.py index 15c889e78..d148c5187 100644 --- a/jointContribution/PIRBN/1D_nonlinear_spring/rbn_net.py +++ b/jointContribution/PIRBN/rbn_net.py @@ -1,97 +1,88 @@ -import math - -import numpy as np -import paddle - - -class RBN_Net(paddle.nn.Layer): - """This class is to build a radial basis network (RBN). - - Args: - n_in (int): Number of input of the RBN. - n_out (int): Number of output of the RBN. - n_neu (int): Number of neurons in the hidden layer. - b (List[float32]|float32): Initial value for hyperparameter b. - c (List[float32]): Initial value for hyperparameter c. 
- """ - - def __init__(self, n_in, n_out, n_neu, b, c): - super().__init__() - self.n_in = n_in - self.n_out = n_out - self.n_neu = n_neu - self.b = paddle.to_tensor(b) - self.c = paddle.to_tensor(c) - - self.layer1 = RBF_layer1(self.n_neu, self.c, n_in) - - # LeCun normal - std = math.sqrt(1 / self.n_neu) - self.linear = paddle.nn.Linear( - self.n_neu, - self.n_out, - weight_attr=paddle.ParamAttr( - initializer=paddle.nn.initializer.Normal(mean=0.0, std=std) - ), - bias_attr=False, - ) - self.ini_ab() - - def forward(self, x): - temp = self.layer1(x) - y = self.linear(temp) - return y - - def ini_ab(self): - b = np.ones((1, self.n_neu)) * self.b - self.layer1.b = self.layer1.create_parameter( - (1, self.n_neu), default_initializer=paddle.nn.initializer.Assign(b) - ) - - def get_weights(self): - s = self.state_dict() - ret = [s[i] for i in s] - return ret - - def set_weights(self, weights): - s = self.state_dict() - s["layer1.b"] = weights[0] - s["linear.weight"] = weights[1] - self.set_state_dict(s) - - -class RBF_layer1(paddle.nn.Layer): - """This class is to create the hidden layer of a radial basis network. - - Args: - n_neu (int): Number of neurons in the hidden layer. - c (List[float32]): Initial value for hyperparameter b. - input_shape_last (int): Last item of input shape. - """ - - def __init__(self, n_neu, c, input_shape_last): - super(RBF_layer1, self).__init__() - self.n_neu = n_neu - self.c = c - self.b = self.create_parameter( - shape=(input_shape_last, self.n_neu), - dtype=paddle.get_default_dtype(), - # Convert from tensorflow tf.random_normal_initializer(), default value mean=0.0, std=0.05 - default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05), - ) - # self.b = paddle.normal(mean=0.0, std=0.05, shape=[input_shape_last, self.n_neu]) - - def forward(self, inputs): # Defines the computation from inputs to outputs - s = self.b * self.b - temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu))) - x0 = ( - paddle.reshape( - paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()), - (1, self.n_neu), - ) - * (self.c[1] - self.c[0]) - / (self.n_neu - 1) - + self.c[0] - ) - x_new = (temp_x - x0) * (temp_x - x0) - return paddle.exp(-x_new * s) +import math + +import numpy as np +import paddle + + +class RBN_Net(paddle.nn.Layer): + """This class is to build a radial basis network (RBN). + + Args: + n_in (int): Number of input of the RBN. + n_out (int): Number of output of the RBN. + n_neu (int): Number of neurons in the hidden layer. + b (List[float32]|float32): Initial value for hyperparameter b. + c (List[float32]): Initial value for hyperparameter c. + """ + + def __init__(self, n_in, n_out, n_neu, b, c): + super().__init__() + self.n_in = n_in + self.n_out = n_out + self.n_neu = n_neu + self.b = paddle.to_tensor(b) + self.c = paddle.to_tensor(c) + + self.layer1 = RBF_layer1(self.n_neu, self.c, n_in) + # LeCun normal + std = math.sqrt(1 / self.n_neu) + self.linear = paddle.nn.Linear( + self.n_neu, + self.n_out, + weight_attr=paddle.ParamAttr( + initializer=paddle.nn.initializer.Normal(mean=0.0, std=std) + ), + bias_attr=False, + ) + self.ini_ab() + + def forward(self, x): + temp = self.layer1(x) + y = self.linear(temp) + return y + + def ini_ab(self): + b = np.ones((1, self.n_neu)) * self.b + self.layer1.b = self.layer1.create_parameter( + (1, self.n_neu), default_initializer=paddle.nn.initializer.Assign(b) + ) + + +class RBF_layer1(paddle.nn.Layer): + """This class is to create the hidden layer of a radial basis network. 
+
+    Args:
+        n_neu (int): Number of neurons in the hidden layer.
+        c (List[float32]): Initial value for hyperparameter c.
+        input_shape_last (int): Last item of input shape.
+    """
+
+    def __init__(self, n_neu, c, input_shape_last):
+        super().__init__()
+        self.n_neu = n_neu
+        self.c = c
+        self.b = self.create_parameter(
+            shape=(input_shape_last, self.n_neu),
+            dtype=paddle.get_default_dtype(),
+            # Convert from tensorflow tf.random_normal_initializer(), default value mean=0.0, std=0.05
+            default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05),
+        )
+
+    def forward(self, inputs):  # Defines the computation from inputs to outputs
+        temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu)))
+        x0 = (
+            paddle.reshape(
+                paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()),
+                (1, self.n_neu),
+            )
+            * (self.c[1] - self.c[0])
+            / (self.n_neu - 1)
+            + self.c[0]
+        )
+        x_new = temp_x - x0
+        return self.rbf_activate(x_new)
+
+    # Gaussian activation function: exp(-b^2 * x^2), with trainable width parameter b
+    def rbf_activate(self, input):
+        s = self.b * self.b
+        return paddle.exp(-(input * input) * s)
diff --git a/jointContribution/PIRBN/requirements.txt b/jointContribution/PIRBN/requirements.txt
deleted file mode 100644
index 12c0307bd..000000000
--- a/jointContribution/PIRBN/requirements.txt
+++ /dev/null
@@ -1,7 +0,0 @@
-numpy
-scipy
-matplotlib
-# GPU
-paddlepaddle-gpu
-# CPU
-# paddlepaddle
diff --git a/jointContribution/PIRBN/train.py b/jointContribution/PIRBN/train.py
new file mode 100644
index 000000000..a1e1916ae
--- /dev/null
+++ b/jointContribution/PIRBN/train.py
@@ -0,0 +1,57 @@
+import paddle
+
+# Used to calculate the second-order derivatives
+paddle.framework.core.set_prim_eager_enabled(True)
+
+
+class Trainer:
+    def __init__(self, pirbn, x_train, y_train, learning_rate=0.001, maxiter=10000):
+        # set attributes
+        self.pirbn = pirbn
+
+        self.learning_rate = learning_rate
+        self.x_train = [
+            paddle.to_tensor(x, dtype=paddle.get_default_dtype()) for x in x_train
+        ]
+        self.y_train = paddle.to_tensor(y_train, dtype=paddle.get_default_dtype())
+        self.maxiter = maxiter
+        self.his_l1 = []
+        self.his_l2 = []
+        self.iter = 0
+        self.a_g = paddle.to_tensor(1.0)
+        self.a_b = paddle.to_tensor(1.0)
+        self.optimizer = paddle.optimizer.Adam(
+            learning_rate=self.learning_rate, parameters=self.pirbn.parameters()
+        )
+        self.ntk_list = {}
+
+    def Loss(self, x, y, a_g, a_b):
+        tmp = self.pirbn(x)
+        l1 = 0.5 * paddle.mean(paddle.square(tmp[0] - y[0]))
+        l2 = 0.5 * paddle.mean(paddle.square(tmp[1]))
+        loss = l1 * a_g + l2 * a_b
+        return loss, l1, l2
+
+    def evaluate(self):
+        # compute loss
+        loss, l1, l2 = self.Loss(self.x_train, self.y_train, self.a_g, self.a_b)
+        l1_numpy = float(l1)
+        l2_numpy = float(l2)
+        self.his_l1.append(l1_numpy)
+        self.his_l2.append(l2_numpy)
+        if self.iter % 200 == 0:
+            self.a_g, self.a_b, _ = self.pirbn.cal_ntk(self.x_train)
+            print("\ta_g =", float(self.a_g), "\ta_b =", float(self.a_b))
+            print("Iter: ", self.iter, "\tL1 =", l1_numpy, "\tL2 =", l2_numpy)
+        self.iter = self.iter + 1
+        return loss
+
+    def fit(self):
+        for i in range(0, self.maxiter):
+            if i in [0, 2000, 20000]:
+                self.ntk_list[i] = self.pirbn.cal_ntk(self.x_train)[2].numpy()
+            loss = self.evaluate()
+            loss.backward()
+            self.optimizer.step()
+            self.optimizer.clear_grad()
+        return loss, [self.his_l1, self.his_l2]
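As a reading aid for the new rbn_net.py above: the hidden layer evaluates Gaussian radial basis functions $\phi_i(x)=\exp(-b_i^2(x-c_i)^2)$ with centres $c_i$ evenly spaced over `[c[0], c[1]]` and trainable widths $b_i$. The following is a minimal NumPy sketch of the same forward pass; it is a standalone illustration with a hypothetical `rbf_forward` helper, not code from the patch.

```python
import numpy as np


def rbf_forward(x, b, c0, c1, n_neu):
    # Centres evenly spaced over [c0, c1]; RBF_layer1 builds the same grid via
    # arange(n_neu) * (c1 - c0) / (n_neu - 1) + c0.
    centers = np.linspace(c0, c1, n_neu).reshape(1, n_neu)
    # RBF_layer1 holds b as a trainable (1, n_neu) parameter initialised to a
    # constant by RBN_Net.ini_ab; a scalar stands in for it here.
    diff = x.reshape(-1, 1) - centers  # shape (n, n_neu)
    return np.exp(-(diff**2) * (b**2))  # Gaussian activation, shape (n, n_neu)


# Configuration used in main.py: 61 neurons, b = 10.0, centres over [-0.1, 1.1].
x = np.linspace(0.0, 1.0, 5).astype(np.float32)
phi = rbf_forward(x, b=10.0, c0=-0.1, c1=1.1, n_neu=61)
print(phi.shape)  # (5, 61); RBN_Net's bias-free linear layer then maps 61 -> 1
```

The adaptive loss weighting works on top of this: every 200 iterations, Trainer.evaluate refreshes a_g and a_b from PIRBN.cal_ntk, which rescales the two terms to $(\gamma_g+\gamma_b)/\gamma_g$ and $(\gamma_g+\gamma_b)/\gamma_b$ so that neither the governing-equation residual nor the boundary term dominates training.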