
Commit

Fix
co63oc committed Sep 19, 2023
1 parent 913dd05 commit b3c6919
Showing 6 changed files with 313 additions and 282 deletions.
7 changes: 6 additions & 1 deletion jointContribution/PIRBN/README.md
@@ -13,12 +13,17 @@ Inspired by findings, we proposed the PIRBN, which can exhibit the local propert

Numerical examples include:

-- 1D sine function (**Eq. 1** in the manuscript)
+- 1D sine function (**Eq. 13** in the manuscript)

**PDE**: $\frac{\partial^2}{\partial x^2}u(x)-4\mu^2\pi^2\sin(2\mu\pi x)=0,\ x\in[0,1]$

**BC**: $u(0)=u(1)=0.$

+- 1D sine function (**Eq. 15** in the manuscript)
+
+**PDE**: $\frac{\partial^2}{\partial x^2}u(x-100)-4\mu^2\pi^2\sin(2\mu\pi(x-100))=0,\ x\in[100,101]$
+
+**BC**: $u(100)=u(101)=0.$
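
The second example is the first one rigidly shifted by 100; in particular $\sin(2\mu\pi(x-100))=\sin(2\mu\pi x)$ whenever $100\mu$ is an integer. A minimal finite-difference check of the analytical solution (an editor's sketch, not part of this commit; the sign convention follows the source term constructed in `main.py`):

```python
# Editor's sketch, not part of this commit: finite-difference check that
# u(x) = sin(2*mu*pi*(x - 100)) matches the source term main.py constructs,
# u''(x) = -4*mu^2*pi^2*sin(2*mu*pi*(x - 100)), with u(100) = u(101) = 0.
import numpy as np

mu = 4
x = np.linspace(100.0, 101.0, 4001)
h = x[1] - x[0]
u = np.sin(2 * mu * np.pi * (x - 100))
u_xx = (u[2:] - 2 * u[1:-1] + u[:-2]) / h**2  # central second difference
f = -4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * (x[1:-1] - 100))
assert np.max(np.abs(u_xx - f)) < 1e-2  # agreement up to discretisation error
assert abs(u[0]) < 1e-12 and abs(u[-1]) < 1e-12  # boundary conditions
```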

For more details on the mathematical proofs and numerical examples, please refer to our paper.

# Link
40 changes: 27 additions & 13 deletions jointContribution/PIRBN/analytical_solution.py
@@ -6,49 +6,55 @@
import scipy.io


-def output_fig(train_obj, mu, b):
+def output_fig(train_obj, mu, b, right_by, activation_function):
    plt.figure(figsize=(15, 9))
    rbn = train_obj.pirbn.rbn

    target_dir = os.path.join(os.path.dirname(__file__), "../target")
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)

-    # Comparisons between the PINN predictions and the ground truth.
+    # Comparisons between the network predictions and the ground truth.
    plt.subplot(2, 3, 1)
    ns = 1001
    dx = 1 / (ns - 1)
    xy = np.zeros((ns, 1)).astype(np.float32)
    for i in range(0, ns):
-        xy[i, 0] = i * dx
-    y = rbn(paddle.to_tensor(xy))
+        xy[i, 0] = i * dx + right_by
+    y = rbn(paddle.to_tensor(xy), activation_function=activation_function)
    y = y.numpy()
    y_true = np.sin(2 * mu * np.pi * xy)
    plt.plot(xy, y_true)
    plt.plot(xy, y, linestyle="--")
-    plt.legend(["ground truth", "PINN"])
+    plt.legend(["ground truth", "predict"])
+    plt.xlabel("x")

    # Point-wise absolute error plot.
    plt.subplot(2, 3, 2)
    plt.plot(xy, np.abs(y_true - y))
-    plt.ylim(top=1e-3)
-    plt.legend(["Absolute Error"])
+    plt.ylim(top=8e-3)
+    plt.ylabel("Absolute Error")
+    plt.xlabel("x")

-    # Loss history of the PINN during the training process.
+    # Loss history of the network during the training process.
    plt.subplot(2, 3, 3)
-    his_l1 = train_obj.his_l1
-    x = range(len(his_l1))
+    loss_b = train_obj.loss_b
+    x = range(len(loss_b))
    plt.yscale("log")
-    plt.plot(x, his_l1)
-    plt.plot(x, train_obj.his_l2)
+    plt.plot(x, loss_b)
+    plt.plot(x, train_obj.loss_g)
    plt.legend(["Lg", "Lb"])
+    plt.ylabel("Loss")
+    plt.xlabel("Iteration")

    # Visualise NTK after initialisation: the normalised Kg at the 0th iteration.
    plt.subplot(2, 3, 4)
    jac = train_obj.ntk_list[0]
    a = np.dot(jac, np.transpose(jac))
    plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
    plt.colorbar()
+    plt.title("Kg at 0th iteration")
+    plt.xlabel("Sample point index")

    # Visualise NTK after training: the normalised Kg at the 2000th iteration.
    plt.subplot(2, 3, 5)
@@ -57,6 +63,8 @@ def output_fig(train_obj, mu, b):
    a = np.dot(jac, np.transpose(jac))
    plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
    plt.colorbar()
+    plt.title("Kg at 2000th iteration")
+    plt.xlabel("Sample point index")

    # The normalised Kg at the 20000th iteration.
    plt.subplot(2, 3, 6)
@@ -65,8 +73,14 @@ def output_fig(train_obj, mu, b):
    a = np.dot(jac, np.transpose(jac))
    plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
    plt.colorbar()
+    plt.title("Kg at 20000th iteration")
+    plt.xlabel("Sample point index")

-    plt.savefig(os.path.join(target_dir, f"sine_function_{mu}_{b}.png"))
+    plt.savefig(
+        os.path.join(
+            target_dir, f"sine_function_{mu}_{b}_{right_by}_{activation_function}.png"
+        )
+    )

    # Save data
    scipy.io.savemat(os.path.join(target_dir, "out.mat"), {"NTK": a, "x": xy, "y": y})
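
Since `output_fig` ends by writing `out.mat`, results can be read back for post-processing. A minimal sketch (editor's illustration, not part of this commit; it assumes the `../target` location and the `NTK`/`x`/`y` keys written above):

```python
# Editor's sketch: reload the arrays saved by output_fig and inspect the NTK.
import numpy as np
import scipy.io

data = scipy.io.loadmat("../target/out.mat")
ntk, x, y = data["NTK"], data["x"], data["y"]  # keys as written by output_fig
# The saved Gram matrix is symmetric positive semi-definite, so its eigenvalue
# spread indicates how ill-conditioned training was at the final iteration.
eigvals = np.linalg.eigvalsh(ntk)
print(f"largest/smallest eigenvalue: {eigvals.max():.3e} / {eigvals.min():.3e}")
```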
94 changes: 56 additions & 38 deletions jointContribution/PIRBN/main.py
@@ -1,38 +1,56 @@
-import analytical_solution
-import numpy as np
-import pirbn
-import rbn_net
-import train
-
-# Define mu
-mu = 4
-
-# Define the number of sample points
-ns = 51
-
-# Define the sample points' interval
-dx = 1.0 / (ns - 1)
-
-# Initialise sample points' coordinates
-xy = np.zeros((ns, 1)).astype(np.float32)
-for i in range(0, ns):
-    xy[i, 0] = i * dx
-xy_b = np.array([[0.0], [1.0]])
-
-x = [xy, xy_b]
-y = [-4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * xy)]
-
-# Set up radial basis network
-n_in = 1
-n_out = 1
-n_neu = 61
-b = 10.0
-c = [-0.1, 1.1]
-
-# Set up PIRBN
-rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c)
-train_obj = train.Trainer(pirbn.PIRBN(rbn), x, y, learning_rate=0.001, maxiter=20001)
-train_obj.fit()
-
-# Visualise results
-analytical_solution.output_fig(train_obj, mu, b)
+import analytical_solution
+import numpy as np
+import pirbn
+import rbn_net
+import train
+
+
+# mu: see Fig. 1, page 8 of the paper
+# right_by: the domain shift from Formula (15), page 9
+def sine_function_main(mu, right_by=0, activation_function="gaussian_function"):
+    # Define the number of sample points
+    ns = 51
+
+    # Define the sample points' interval
+    dx = 1.0 / (ns - 1)
+
+    # Initialise sample points' coordinates
+    xy = np.zeros((ns, 1)).astype(np.float32)
+    for i in range(0, ns):
+        xy[i, 0] = i * dx + right_by
+    xy_b = np.array([[right_by + 0.0], [right_by + 1.0]])
+
+    x = [xy, xy_b]
+    y = [-4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * xy)]
+
+    # Set up radial basis network
+    n_in = 1
+    n_out = 1
+    n_neu = 61
+    b = 10.0
+    c = [right_by - 0.1, right_by + 1.1]
+
+    # Set up PIRBN
+    rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c)
+    train_obj = train.Trainer(
+        pirbn.PIRBN(rbn),
+        x,
+        y,
+        learning_rate=0.001,
+        maxiter=20001,
+        activation_function=activation_function,
+    )
+    train_obj.fit()
+
+    # Visualise results
+    analytical_solution.output_fig(train_obj, mu, b, right_by, activation_function)
+
+
+# Fig.1
+sine_function_main(mu=4, right_by=0, activation_function="tanh")
+# Fig.2
+sine_function_main(mu=8, right_by=0, activation_function="tanh")
+# Fig.3
+sine_function_main(mu=4, right_by=100, activation_function="tanh")
+# Fig.6
+sine_function_main(mu=8, right_by=100, activation_function="gaussian_function")
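
For intuition about `right_by`: the collocation points, the two boundary points, and the radial-basis centre range `c` all shift by the same offset, so the problem on $[100, 101]$ is the $[0, 1]$ problem translated rigidly. A small sketch reproducing just the sampling logic above (editor's illustration; `shifted_setup` is a hypothetical helper name, not part of this commit):

```python
# Editor's sketch: how right_by shifts the domain, boundaries, and centres together.
import numpy as np

def shifted_setup(ns=51, right_by=100):
    # Mirrors the sampling logic of sine_function_main above.
    dx = 1.0 / (ns - 1)
    xy = np.array([[i * dx + right_by] for i in range(ns)], dtype=np.float32)
    xy_b = np.array([[right_by + 0.0], [right_by + 1.0]])  # boundary points
    c = [right_by - 0.1, right_by + 1.1]  # centres padded 0.1 past each boundary
    return xy, xy_b, c

xy, xy_b, c = shifted_setup()
print(xy.min(), xy.max(), xy_b.ravel(), c)  # 100.0 101.0 [100. 101.] [99.9, 101.1]
```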
146 changes: 61 additions & 85 deletions jointContribution/PIRBN/pirbn.py
@@ -1,85 +1,61 @@
-import paddle
-
-
-class Dif(paddle.nn.Layer):
-    """Initialise the differential operator.
-    Args:
-        rbn (model): The feedforward neural network.
-    """
-
-    def __init__(self, rbn, **kwargs):
-        super().__init__(**kwargs)
-        self.rbn = rbn
-
-    def forward(self, x):
-        """Calculate the differential terms.
-        Args:
-            x (Tensor): The coordinate array
-        Returns:
-            Tuple[Tensor, Tensor]: The first- and second-order derivatives of u with respect to x.
-        """
-        x.stop_gradient = False
-        # Obtain the output from the RBN
-        u = self.rbn(x)
-        # Obtain the first-order derivative of the output with respect to the input
-        u_x = paddle.grad(u, x, retain_graph=True, create_graph=True)[0]
-        # Obtain the second-order derivative of the output with respect to the input
-        u_xx = paddle.grad(u_x, x, retain_graph=True, create_graph=True)[0]
-        return u_x, u_xx
-
-
-class PIRBN(paddle.nn.Layer):
-    def __init__(self, rbn):
-        super().__init__()
-        self.rbn = rbn
-
-    def forward(self, input_data):
-        xy, xy_b = input_data
-        # initialise the differential operators
-        Dif_u = Dif(self.rbn)
-        u_b = self.rbn(xy_b)
-
-        # obtain partial derivatives of u with respect to x
-        _, u_xx = Dif_u(xy)
-
-        return [u_xx, u_b]
-
-    def cal_ntk(self, x):
-        # Formula (4), Page5, \gamma variable
-        gamma_g = 0.0
-        gamma_b = 0.0
-        n_neu = self.rbn.n_neu
-
-        # in-domain
-        n1 = x[0].shape[0]
-        for i in range(n1):
-            temp_x = [x[0][i, ...].unsqueeze(0), paddle.to_tensor([[0.0]])]
-            y = self.forward(temp_x)
-            l1t = paddle.grad(y[0], self.parameters())
-            for j in l1t:
-                gamma_g = gamma_g + paddle.sum(j**2) / n1
-            temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1)
-            if i == 0:
-                # Fig.1, Page8, Kg variable
-                Kg = temp
-            else:
-                Kg = paddle.concat((Kg, temp), axis=0)
-
-        # bound
-        n2 = x[1].shape[0]
-        for i in range(n2):
-            temp_x = [paddle.to_tensor([[0.0]]), x[1][i, ...].unsqueeze(0)]
-            y = self.forward(temp_x)
-            l1t = paddle.grad(y[1], self.parameters())
-            for j in l1t:
-                gamma_b = gamma_b + paddle.sum(j**2) / n2
-
-        # calculate adapt factors
-        temp = gamma_g + gamma_b
-        gamma_g = temp / gamma_g
-        gamma_b = temp / gamma_b
-
-        return gamma_g, gamma_b, Kg
+import paddle
+
+
+class PIRBN(paddle.nn.Layer):
+    def __init__(self, rbn):
+        super().__init__()
+        self.rbn = rbn
+
+    def forward(self, input_data, activation_function="gaussian_function"):
+        xy, xy_b = input_data
+        # evaluate the RBN at the boundary points
+        u_b = self.rbn(xy_b, activation_function)
+
+        # obtain partial derivatives of u with respect to x
+        xy.stop_gradient = False
+        # Obtain the output from the RBN
+        u = self.rbn(xy, activation_function)
+        # Obtain the first-order derivative of the output with respect to the input
+        u_x = paddle.grad(u, xy, retain_graph=True, create_graph=True)[0]
+        # Obtain the second-order derivative of the output with respect to the input
+        u_xx = paddle.grad(u_x, xy, retain_graph=True, create_graph=True)[0]
+
+        return [u_xx, u_b]
+
+    def cal_ntk(self, x):
+        # Formula (4), Page5, \lambda variable
+        # Lambda represents the eigenvalues of the matrix (Kg)
+        lambda_g = 0.0
+        lambda_b = 0.0
+        n_neu = self.rbn.n_neu
+
+        # in-domain
+        n1 = x[0].shape[0]
+        for i in range(n1):
+            temp_x = [x[0][i, ...].unsqueeze(0), paddle.to_tensor([[0.0]])]
+            y = self.forward(temp_x)
+            l1t = paddle.grad(y[0], self.parameters())
+            for j in l1t:
+                lambda_g = lambda_g + paddle.sum(j**2) / n1
+            temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1)
+            if i == 0:
+                # Fig.1, Page8, Kg variable
+                Kg = temp
+            else:
+                Kg = paddle.concat((Kg, temp), axis=0)
+
+        # bound
+        n2 = x[1].shape[0]
+        for i in range(n2):
+            temp_x = [paddle.to_tensor([[0.0]]), x[1][i, ...].unsqueeze(0)]
+            y = self.forward(temp_x)
+            l1t = paddle.grad(y[1], self.parameters())
+            for j in l1t:
+                lambda_b = lambda_b + paddle.sum(j**2) / n2
+
+        # calculate adapt factors
+        temp = lambda_g + lambda_b
+        lambda_g = temp / lambda_g
+        lambda_b = temp / lambda_b
+
+        return lambda_g, lambda_b, Kg
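
What `cal_ntk` computes, in matrix form: each loop iteration stacks the gradient of one residual with respect to the parameters into a Jacobian row, so `Kg` is the Gram (neural tangent kernel) matrix $J J^{\top}$ that `analytical_solution.py` visualises, and the adapt factors reweight the PDE and boundary losses by their relative gradient magnitudes. A toy numpy sketch (editor's illustration; the shapes are made-up placeholders, not the real parameter count):

```python
# Editor's sketch of the quantities cal_ntk assembles, with random stand-in Jacobians.
import numpy as np

rng = np.random.default_rng(0)
J_g = rng.normal(size=(51, 122))  # one row per in-domain residual (toy sizes)
J_b = rng.normal(size=(2, 122))   # one row per boundary residual

Kg = J_g @ J_g.T  # the matrix the subplots show, after normalisation

# Mean squared gradient per loss term, as accumulated by the loops in cal_ntk
lambda_g = (J_g**2).sum() / J_g.shape[0]
lambda_b = (J_b**2).sum() / J_b.shape[0]

# Adaptive loss weights: each term is scaled by total / its own magnitude
total = lambda_g + lambda_b
lambda_g, lambda_b = total / lambda_g, total / lambda_b
```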