
Task 15: PIRBN subdirectory 1D_sine_function #536

Merged
merged 18 commits on Oct 20, 2023
41 changes: 41 additions & 0 deletions jointContribution/PIRBN/README.md
@@ -0,0 +1,41 @@
# Physics-informed radial basis network (PIRBN)

This repository provides numerical examples of the **physics-informed radial basis network** (**PIRBN**).

Physics-informed neural networks (PINNs) have recently gained increasing interest in computational mechanics.

This work starts from studying the training dynamics of PINNs via the neural tangent kernel (NTK) theory. Based on numerical experiments, we found:

- PINNs tend to be a **local approximator** during the training
- For PINNs that fail to be a local approximator, the physics-informed loss can hardly be minimised through training

Inspired by these findings, we proposed the PIRBN, which exhibits the local property intuitively. It has been demonstrated that the NTK theory is applicable to the PIRBN. Moreover, other PINN techniques can be directly migrated to PIRBNs.
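(A note on how this is visualised: `analytical_solution.py` plots the empirical NTK Gram matrix $K_g = JJ^\top$, where row $i$ of $J$ stores the gradient of the PDE residual at sample point $x_i$ with respect to the network parameters; a near-diagonal $K_g$ reflects the local property described above.)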

Numerical examples include:

- 1D sine function (**Eq. 13** in the manuscript)

**PDE**: $\frac{\partial^2}{\partial x^2}u(x)-4\mu^2\pi^2\sin(2\mu\pi x)=0,\ x\in[0,1]$
Contributor

The previous changes are all good and meet the requirements, but there are still some unclear points in the figure results, so let me clarify.
The figure settings in the paper:

  1. Fig. 1: single-layer PINN (not PIRBN), Eq. (13), mu = 4, [tanh activation]; the paper shows it barely converges
  2. Fig. 2: single-layer PINN (not PIRBN), Eq. (13), changed to mu = 8, [tanh activation]; ill-conditioned, the paper's figure shows very poor results
  3. Fig. 3: single-layer PINN (not PIRBN), changed to Eq. (15), mu = 4, [tanh activation]; ill-conditioned, the paper's figure shows very poor results
  4. Fig. 6: single-layer PIRBN (as implemented in the source code), changed to Eq. (15), changed to mu = 8, changed to [Gaussian activation]; excellent results, demonstrating that the PIRBN is very useful

Contributor

Also, the formula here does not seem to match the source code.

Contributor Author

With tanh, the results for Fig. 1, Fig. 2, and Fig. 3 do not match the paper.

With the changed activation function, the result for Fig. 6 matches the paper.

[Images: reproduced Fig. 1, Fig. 2, Fig. 3, and Fig. 6]

The formula is consistent with the source code.

[Image]

Contributor Author

Settings for the different activation functions:

    def gaussian_function(self, temp_x):
        x0 = (
            paddle.reshape(
                paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()),
                (1, self.n_neu),
            )
            * (self.c[1] - self.c[0])
            / (self.n_neu - 1)
            + self.c[0]
        )
        x_new = temp_x - x0
        s = self.b * self.b
        return paddle.exp(-(x_new * x_new) * s)

    def tanh_function(self, temp_x):
        return paddle.tanh(temp_x)
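For context, the Gaussian branch spreads `n_neu` evenly spaced centres `x0` over the interval `[c[0], c[1]]` and treats `b` as an inverse-width parameter. A minimal standalone sketch of the centre computation (the values of `n_neu` and `c` below are illustrative, not from the repository):

```python
import paddle

# Evenly spaced RBF centres over [c[0], c[1]], as in gaussian_function above.
n_neu, c = 5, [0.0, 1.0]
x0 = paddle.reshape(
    paddle.arange(n_neu, dtype=paddle.get_default_dtype()), (1, n_neu)
) * (c[1] - c[0]) / (n_neu - 1) + c[0]
print(x0.numpy())  # [[0.   0.25 0.5  0.75 1.  ]]
```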

Contributor

The source code may have a problem. For Fig. 1, I reproduced the plot using the paper's reference implementation:
PredictiveIntelligenceLab/PINNsNTK@18ef519
[Image]
The jacobian function may have an issue; try aligning it with that repository.


**BC**: $u(0)=u(1)=0.$

- 1D sine function (**Eq. 15** in the manuscript)
**PDE**: $\frac{\partial^2}{\partial x^2}u(x-100)-4\mu^2\pi^2\sin(2\mu\pi(x-100))=0,\ x\in[100,101]$

**BC**: $u(100)=u(101)=0.$
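As a quick sanity check (a sketch that is not part of the repository; it assumes `sympy` is available), one can verify symbolically that $u(x)=\sin(2\mu\pi x)$ — the ground truth used in `analytical_solution.py` — satisfies $u''(x)+4\mu^2\pi^2\sin(2\mu\pi x)=0$ with the zero boundary conditions, which is exactly the residual `main.py` minimises through `y = [-4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * xy)]`:

```python
import sympy as sp

x = sp.symbols("x")
mu = 4  # the value used for Fig. 1
u = sp.sin(2 * mu * sp.pi * x)

# The PDE residual u''(x) + 4*mu^2*pi^2*sin(2*mu*pi*x) simplifies to zero.
residual = sp.diff(u, x, 2) + 4 * mu**2 * sp.pi**2 * sp.sin(2 * mu * sp.pi * x)
print(sp.simplify(residual))  # 0

# Boundary conditions hold since sin(0) = sin(8*pi) = 0.
print(u.subs(x, 0), u.subs(x, 1))  # 0 0
```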

For more details on the mathematical proofs and numerical examples, please refer to our paper.

# Links

<https://doi.org/10.1016/j.cma.2023.116290>

<https://github.com/JinshuaiBai/PIRBN>

<https://arxiv.org/ftp/arxiv/papers/2304/2304.06234.pdf>

# Environmental settings

```
pip install -r requirements.txt
```
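After installing the requirements, the four experiments (Figs. 1, 2, 3, and 6) can be run in one go; this usage is inferred from `main.py`, which executes all four configurations at import time and saves the figures via `analytical_solution.output_fig`:

```
python main.py
```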
86 changes: 86 additions & 0 deletions jointContribution/PIRBN/analytical_solution.py
@@ -0,0 +1,86 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import paddle
import scipy.io


def output_fig(train_obj, mu, b, right_by, activation_function):
    plt.figure(figsize=(15, 9))
    rbn = train_obj.pirbn.rbn

    target_dir = os.path.join(os.path.dirname(__file__), "../target")
    if not os.path.exists(target_dir):
        os.mkdir(target_dir)

    # Comparison between the network predictions and the ground truth.
    plt.subplot(2, 3, 1)
    ns = 1001
    dx = 1 / (ns - 1)
    xy = np.zeros((ns, 1)).astype(np.float32)
    for i in range(0, ns):
        xy[i, 0] = i * dx + right_by
    y = rbn(paddle.to_tensor(xy), activation_function=activation_function)
    y = y.numpy()
    y_true = np.sin(2 * mu * np.pi * xy)
    plt.plot(xy, y_true)
    plt.plot(xy, y, linestyle="--")
    plt.legend(["ground truth", "predict"])
    plt.xlabel("x")

    # Point-wise absolute error plot.
    plt.subplot(2, 3, 2)
    plt.plot(xy, np.abs(y_true - y))
    plt.ylim(top=8e-3)
    plt.ylabel("Absolute Error")
    plt.xlabel("x")

    # Loss history of the network during the training process.
    plt.subplot(2, 3, 3)
    loss_g = train_obj.loss_g
    x = range(len(loss_g))
    plt.yscale("log")
    plt.plot(x, loss_g)
    plt.plot(x, train_obj.loss_b)
    plt.legend(["Lg", "Lb"])
    plt.ylabel("Loss")
    plt.xlabel("Iteration")

    # Visualise the NTK after initialisation: the normalised Kg at the 0th iteration.
    plt.subplot(2, 3, 4)
    jac = train_obj.ntk_list[0]
    a = np.dot(jac, np.transpose(jac))
    plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
    plt.colorbar()
    plt.title("Kg at 0th iteration")
    plt.xlabel("Sample point index")

    # Visualise the NTK during training: the normalised Kg at the 2000th iteration.
    plt.subplot(2, 3, 5)
    if 2000 in train_obj.ntk_list:
        jac = train_obj.ntk_list[2000]
        a = np.dot(jac, np.transpose(jac))
        plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
        plt.colorbar()
        plt.title("Kg at 2000th iteration")
        plt.xlabel("Sample point index")

    # The normalised Kg at the 20000th iteration.
    plt.subplot(2, 3, 6)
    if 20000 in train_obj.ntk_list:
        jac = train_obj.ntk_list[20000]
        a = np.dot(jac, np.transpose(jac))
        plt.imshow(a / (np.max(abs(a))), cmap="bwr", vmax=1, vmin=-1)
        plt.colorbar()
        plt.title("Kg at 20000th iteration")
        plt.xlabel("Sample point index")

    plt.savefig(
        os.path.join(
            target_dir, f"sine_function_{mu}_{b}_{right_by}_{activation_function}.png"
        )
    )

    # Save data.
    scipy.io.savemat(os.path.join(target_dir, "out.mat"), {"NTK": a, "x": xy, "y": y})
56 changes: 56 additions & 0 deletions jointContribution/PIRBN/main.py
@@ -0,0 +1,56 @@
import analytical_solution
import numpy as np
import pirbn
import rbn_net
import train


# mu: see Fig. 1, Page 8
# right_by: domain shift, see Eq. (15), Page 9
def sine_function_main(mu, right_by=0, activation_function="gaussian_function"):
    # Define the number of sample points
    ns = 51

    # Define the sample points' interval
    dx = 1.0 / (ns - 1)

    # Initialise sample points' coordinates
    xy = np.zeros((ns, 1)).astype(np.float32)
    for i in range(0, ns):
        xy[i, 0] = i * dx + right_by
    xy_b = np.array([[right_by + 0.0], [right_by + 1.0]])

    x = [xy, xy_b]
    y = [-4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * xy)]

    # Set up the radial basis network
    n_in = 1
    n_out = 1
    n_neu = 61
    b = 10.0
    c = [right_by - 0.1, right_by + 1.1]

    # Set up the PIRBN
    rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c)
    train_obj = train.Trainer(
        pirbn.PIRBN(rbn),
        x,
        y,
        learning_rate=0.001,
        maxiter=20001,
        activation_function=activation_function,
    )
    train_obj.fit()

    # Visualise results
    analytical_solution.output_fig(train_obj, mu, b, right_by, activation_function)


# Fig. 1
sine_function_main(mu=4, right_by=0, activation_function="tanh")
# Fig. 2
sine_function_main(mu=8, right_by=0, activation_function="tanh")
# Fig. 3
sine_function_main(mu=4, right_by=100, activation_function="tanh")
# Fig. 6
sine_function_main(mu=8, right_by=100, activation_function="gaussian_function")
61 changes: 61 additions & 0 deletions jointContribution/PIRBN/pirbn.py
@@ -0,0 +1,61 @@
import paddle


class PIRBN(paddle.nn.Layer):
    def __init__(self, rbn):
        super().__init__()
        self.rbn = rbn

    def forward(self, input_data, activation_function="gaussian_function"):
        xy, xy_b = input_data
        # Evaluate the RBN on the boundary points
        u_b = self.rbn(xy_b, activation_function)

        # Obtain partial derivatives of u with respect to x
        xy.stop_gradient = False
        # Obtain the output from the RBN
        u = self.rbn(xy, activation_function)
        # First-order derivative of the output with respect to the input
        u_x = paddle.grad(u, xy, retain_graph=True, create_graph=True)[0]
        # Second-order derivative of the output with respect to the input
        u_xx = paddle.grad(u_x, xy, retain_graph=True, create_graph=True)[0]

        return [u_xx, u_b]

    def cal_ntk(self, x):
        # Eq. (4), Page 5: lambda factors. The accumulated squared-gradient
        # sums equal the traces (eigenvalue sums) of the NTK blocks (Kg).
        lambda_g = 0.0
        lambda_b = 0.0
        n_neu = self.rbn.n_neu

        # In-domain points
        n1 = x[0].shape[0]
        for i in range(n1):
            temp_x = [x[0][i, ...].unsqueeze(0), paddle.to_tensor([[0.0]])]
            y = self.forward(temp_x)
            l1t = paddle.grad(y[0], self.parameters())
            for j in l1t:
                lambda_g = lambda_g + paddle.sum(j**2) / n1
            temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1)
            if i == 0:
                # Kg: per-sample gradient matrix, see Fig. 1, Page 8
                Kg = temp
            else:
                Kg = paddle.concat((Kg, temp), axis=0)

        # Boundary points
        n2 = x[1].shape[0]
        for i in range(n2):
            temp_x = [paddle.to_tensor([[0.0]]), x[1][i, ...].unsqueeze(0)]
            y = self.forward(temp_x)
            l1t = paddle.grad(y[1], self.parameters())
            for j in l1t:
                lambda_b = lambda_b + paddle.sum(j**2) / n2

        # Calculate the adaptive factors
        temp = lambda_g + lambda_b
        lambda_g = temp / lambda_g
        lambda_b = temp / lambda_b

        return lambda_g, lambda_b, Kg
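`train.py` is not included in this diff, so the following is only a hypothetical sketch (all names and numbers are illustrative assumptions, not the repository's code) of how the adaptive factors returned by `cal_ntk` are typically used for NTK-based loss balancing, where the PDE and boundary losses are weighted before being summed:

```python
import paddle

# Illustrative stand-ins for the outputs of PIRBN.cal_ntk(x).
lambda_g, lambda_b = 1.6, 2.7

# Illustrative per-term losses (the Lg/Lb curves plotted in output_fig).
loss_g = paddle.to_tensor(0.25)  # PDE residual loss
loss_b = paddle.to_tensor(0.10)  # boundary-condition loss

# NTK-balanced total: cal_ntk sets each factor to (total magnitude) /
# (own magnitude), so the weaker term receives the larger weight.
total_loss = lambda_g * loss_g + lambda_b * loss_b
print(float(total_loss))  # 0.67
```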
96 changes: 96 additions & 0 deletions jointContribution/PIRBN/rbn_net.py
@@ -0,0 +1,96 @@
import math

import numpy as np
import paddle


class RBN_Net(paddle.nn.Layer):
    """This class builds a radial basis network (RBN).

    Args:
        n_in (int): Number of inputs of the RBN.
        n_out (int): Number of outputs of the RBN.
        n_neu (int): Number of neurons in the hidden layer.
        b (List[float32]|float32): Initial value for hyperparameter b.
        c (List[float32]): Initial value for hyperparameter c.
    """

    def __init__(self, n_in, n_out, n_neu, b, c):
        super().__init__()
        self.n_in = n_in
        self.n_out = n_out
        self.n_neu = n_neu
        self.b = paddle.to_tensor(b)
        self.c = paddle.to_tensor(c)

        self.layer1 = RBF_layer1(self.n_neu, self.c, n_in)
        # LeCun normal initialisation
        std = math.sqrt(1 / self.n_neu)
        self.linear = paddle.nn.Linear(
            self.n_neu,
            self.n_out,
            weight_attr=paddle.ParamAttr(
                initializer=paddle.nn.initializer.Normal(mean=0.0, std=std)
            ),
            bias_attr=False,
        )
        self.ini_ab()

    def forward(self, x, activation_function="gaussian_function"):
        temp = self.layer1(x, activation_function)
        y = self.linear(temp)
        return y

    def ini_ab(self):
        # Re-initialise the hidden layer's b parameter with the user-provided value.
        b = np.ones((1, self.n_neu)) * self.b
        self.layer1.b = self.layer1.create_parameter(
            (1, self.n_neu), default_initializer=paddle.nn.initializer.Assign(b)
        )


class RBF_layer1(paddle.nn.Layer):
    """This class creates the hidden layer of a radial basis network.

    Args:
        n_neu (int): Number of neurons in the hidden layer.
        c (List[float32]): Initial value for hyperparameter c.
        input_shape_last (int): Last item of the input shape.
    """

    def __init__(self, n_neu, c, input_shape_last):
        super().__init__()
        self.n_neu = n_neu
        self.c = c
        self.b = self.create_parameter(
            shape=(input_shape_last, self.n_neu),
            dtype=paddle.get_default_dtype(),
            # Converted from TensorFlow tf.random_normal_initializer(), whose
            # default values are mean=0.0, std=0.05
            default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05),
        )

    def forward(
        self, inputs, activation_function="gaussian_function"
    ):  # Defines the computation from inputs to outputs
        temp_x = paddle.matmul(inputs, paddle.ones((1, self.n_neu)))
        if activation_function == "gaussian_function":
            return self.gaussian_function(temp_x)
        else:
            return self.tanh_function(temp_x)

    # Gaussian activation, Eq. (19), Page 10
    def gaussian_function(self, temp_x):
        x0 = (
            paddle.reshape(
                paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()),
                (1, self.n_neu),
            )
            * (self.c[1] - self.c[0])
            / (self.n_neu - 1)
            + self.c[0]
        )
        x_new = temp_x - x0
        s = self.b * self.b
        return paddle.exp(-(x_new * x_new) * s)

    def tanh_function(self, temp_x):
        return paddle.tanh(temp_x)