Skip to content

Commit

Permalink
赛题十五:PIRBN 子目录 1D_sine_function (PaddlePaddle#536)
Browse files Browse the repository at this point in the history
* Add PIRBN

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix
  • Loading branch information
co63oc authored Oct 20, 2023
1 parent bb8e7f6 commit c889eb7
Show file tree
Hide file tree
Showing 9 changed files with 759 additions and 0 deletions.
148 changes: 148 additions & 0 deletions docs/zh/examples/pirbn.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,148 @@
# PIRBN

## 1. 背景简介

我们最近发现经过训练,物理信息神经网络(PINN)往往会成为局部近似函数。这一观察结果促使我们开发了一种新型的物理-信息径向基网络(PIRBN),该网络在整个训练过程中都能够维持局部近似性质。与深度神经网络不同,PIRBN 仅包含一个隐藏层和一个径向基“激活”函数。在适当的条件下,我们证明了使用梯度下降方法训练 PIRBN 可以收敛到高斯过程。此外,我们还通过神经邻近核(NTK)理论研究了 PIRBN 的训练动态。此外,我们还对 PIRBN 的初始化策略进行了全面调查。基于数值示例,我们发现 PIRBN 在解决具有高频特征和病态计算域的非线性偏微分方程方面比PINN更有效。此外,现有的 PINN 数值技术,如自适应学习、分解和不同类型的损失函数,也适用于 PIRBN。

<figure markdown>
![介绍](https://paddle-org.bj.bcebos.com/paddlescience/docs/PIRBN/PIRBN_1.png){ loading=lazy }
<figcaption>网络的结构</figcaption>
</figure>
图片左侧为常见神经网络结构的输入层,隐藏层,输出层,隐藏层包含激活层,a 中为单层隐藏层,b 中为多层隐藏层,图片右侧为 PIRBN 网络的激活函数,计算网络的损失 Loss 并反向传递。图片说明当使用 PIRBN 时,每个 RBF 神经元仅在输入接近神经元中心时被激活。直观地说,PIRBN 具有局部逼近特性。通过梯度下降算法训练一个 PIRBN 也可以通过 NTK 理论进行分析。

<figure markdown>
![gaussian](https://paddle-org.bj.bcebos.com/paddlescience/docs/PIRBN/PIRBN_2.png){ loading=lazy }
<figcaption>不同阶数的高斯激活函数</figcaption>
</figure>
(a) 0, 1, 2 阶高斯激活函数
(b) 设置不同 b 值
(c) 设置不同 c 值

当使用高斯函数作为激活函数时,输入与输出之间的映射关系可以数学上表示为高斯函数的某种形式。RBF 网络是一种常用于模式识别、数据插值和函数逼近的神经网络,其关键特征是使用径向基函数作为激活函数,使得网络具有更好的全局逼近能力和灵活性。

## 2. 问题定义

在 NTK 和基于 NTK 的适应性训练方法的帮助下,PINN 在处理具有高频特征的问题时的性能可以得到显著提升。例如,考虑一个偏微分方程及其边界条件:

$$
\begin{aligned}
& \frac{\mathrm{d}^2}{\mathrm{~d} x^2} u(x)-4 \mu^2 \pi^2 \sin (2 \mu \pi x)=0, \text { for } x \in[0,1] \\
& u(0)=u(1)=0
\end{aligned}
$$

其中μ是一个控制PDE解的频率特征的常数。

## 3. 问题求解

接下来开始讲解如何将问题一步一步地转化为 PaddlePaddle 代码,用深度学习的方法求解该问题。
为了快速理解 PaddlePaddle,接下来仅对模型构建、方程构建、计算域构建等关键步骤进行阐述,而其余细节请参考 [API文档](../api/arch.md)

### 3.1 模型构建

在 PIRBN 问题中,建立网络,用 PaddlePaddle 代码表示如下

``` py linenums="40"
--8<--
jointContribution/PIRBN/main.py:40:42
--8<--
```

### 3.2 数据构建

本案例涉及读取数据构建,如下所示

``` py linenums="18"
--8<--
jointContribution/PIRBN/main.py:18:38
--8<--
```

### 3.3 训练和评估构建

训练和评估构建,设置损失计算函数,返回字段,代码如下所示:

``` py linenums="52"
--8<--
jointContribution/PIRBN/train.py:52:90
--8<--
```

### 3.4 超参数设定

接下来我们需要指定训练轮数,此处我们按实验经验,使用 20001 轮训练轮数。

``` py linenums="43"
--8<--
jointContribution/PIRBN/main.py:43:43
--8<--
```

### 3.5 优化器构建

训练过程会调用优化器来更新模型参数,此处选择 `Adam` 优化器并设定 `learning_rate` 为 1e-3。

``` py linenums="33"
--8<--
jointContribution/PIRBN/train.py:33:35
--8<--
```

### 3.6 模型训练与评估

模型训练与评估

``` py linenums="92"
--8<--
jointContribution/PIRBN/train.py:92:99
--8<--
```

## 4. 完整代码

``` py linenums="1" title="main.py"
--8<--
jointContribution/PIRBN/main.py
--8<--
```

## 5. 结果展示

PINN 案例针对 epoch=20001 和 learning\_rate=1e-3 的参数配置进行了实验,结果返回Loss为 0.13567。

PIRBN 案例针对 epoch=20001 和 learning\_rate=1e-3 的参数配置进行了实验,结果返回Loss为 0.59471。

<figure markdown>
![PINN](https://paddle-org.bj.bcebos.com/paddlescience/docs/PIRBN/PIRBN_3.png){ loading=lazy }
<figcaption>PINN 结果图</figcaption>
</figure>
图为使用双曲正切函数(tanh)作为激活函数(activation function),并且使用 LuCun 初始化方法来初始化神经网络中的所有参数。

- 图中子图 1 为预测值和真实值的曲线比较
- 图中子图 2 为误差值
- 图中子图 3 为损失值
- 图中子图 4 为训练 1 次的 Kg 图
- 图中子图 5 为训练 2000 次的 Kg 图
- 图中子图 6 为训练 20000 次的 Kg 图

可以看到预测值和真实值可以匹配,误差值逐渐升高然后逐渐减少,Loss 历史降低后波动,Kg 图随训练次数增加而逐渐收敛。

<figure markdown>
![PIRBN](https://paddle-org.bj.bcebos.com/paddlescience/docs/PIRBN/PIRBN_4.png){ loading=lazy }
<figcaption>PIRBN 结果图</figcaption>
</figure>
图为使用高斯函数(gaussian function)作为激活函数(activation function)生成的数据,并且使用 LuCun 初始化方法来初始化神经网络中的所有参数。

- 图中子图 1 为预测值和真实值的曲线比较
- 图中子图 2 为误差值
- 图中子图 3 为损失值
- 图中子图 4 为训练 1 次的 Kg 图
- 图中子图 5 为训练 2000 次的 Kg 图
- 图中子图 6 为训练 20000 次的 Kg 图

可以看到预测值和真实值可以匹配,误差值逐渐升高然后逐渐减少再升高,Loss 历史降低后波动,Kg 图随训练次数增加而逐渐收敛。

## 6. 参考资料

- [Physics-informed radial basis network (PIRBN): A local approximating neural network for solving nonlinear PDEs](https://arxiv.org/abs/2304.06234)
- <https://github.com/JinshuaiBai/PIRBN>
45 changes: 45 additions & 0 deletions jointContribution/PIRBN/README.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
# Physics-informed radial basis network (PIRBN)

This repository provides numerical examples of the **physics-informed radial basis network** (**PIRBN**).

Physics-informed neural network (PINN) has recently gained increasing interest in computational mechanics.

This work starts from studying the training dynamics of PINNs via the nerual tangent kernel (NTK) theory. Based on numerical experiments, we found:

- PINNs tend to be a **local approximator** during the training
- For PINNs who fail to be a local apprixmator, the physics-informed loss can be hardly minimised through training

Inspired by findings, we proposed the PIRBN, which can exhibit the local property intuitively. It has been demonstrated that the NTK theory is applicable for PIRBN. Besides, other PINN techniques can be directly migrated to PIRBNs.

Numerical examples include:

- 1D sine funtion (**Eq. 13** in the manuscript)

**PDE**: $\frac{\partial^2 }{\partial x^2}u(x)-4\mu^2\pi^2 sin(2\mu\pi(x))=0, x\in[0,1]$

**BC**: $u(0)=u(1)=0.$

- 1D sine funtion (**Eq. 15** in the manuscript)
**PDE**: $\frac{\partial^2 }{\partial x^2}u(x-100)-4\mu^2\pi^2 sin(2\mu\pi(x-100))=0, x\in[100,101]$

**BC**: $u(100)=u(101)=0.$

For more details in terms of mathematical proofs and numerical examples, please refer to our paper.

## Link

<https://doi.org/10.1016/j.cma.2023.116290>

<https://github.com/JinshuaiBai/PIRBN>

## Enviornmental settings

``` shell
pip install -r requirements.txt
```

## Train

``` python
python main.py
```
84 changes: 84 additions & 0 deletions jointContribution/PIRBN/analytical_solution.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
import os

import matplotlib.pyplot as plt
import numpy as np
import paddle


def output_fig(train_obj, mu, b, right_by, activation_function, output_Kgg):
plt.figure(figsize=(15, 9))
rbn = train_obj.pirbn.rbn

output_dir = os.path.join(os.path.dirname(__file__), "output")
if not os.path.exists(output_dir):
os.mkdir(output_dir)

# Comparisons between the network predictions and the ground truth.
plt.subplot(2, 3, 1)
ns = 1001
dx = 1 / (ns - 1)
xy = np.zeros((ns, 1)).astype(np.float32)
for i in range(0, ns):
xy[i, 0] = i * dx + right_by
y = rbn(paddle.to_tensor(xy))
y = y.numpy()
y_true = np.sin(2 * mu * np.pi * xy)
plt.plot(xy, y_true)
plt.plot(xy, y, linestyle="--")
plt.legend(["ground truth", "predict"])
plt.xlabel("x")

# Point-wise absolute error plot.
plt.subplot(2, 3, 2)
xy_y = np.abs(y_true - y)
plt.plot(xy, xy_y)
plt.ylim(top=np.max(xy_y))
plt.ylabel("Absolute Error")
plt.xlabel("x")

# Loss history of the network during the training process.
plt.subplot(2, 3, 3)
loss_g = train_obj.loss_g
x = range(len(loss_g))
plt.yscale("log")
plt.plot(x, loss_g)
plt.plot(x, train_obj.loss_b)
plt.legend(["Lg", "Lb"])
plt.ylabel("Loss")
plt.xlabel("Iteration")

# Visualise NTK after initialisation, The normalised Kg at 0th iteration.
plt.subplot(2, 3, 4)
index = str(output_Kgg[0])
K = train_obj.ntk_list[index].numpy()
plt.imshow(K / (np.max(abs(K))), cmap="bwr", vmax=1, vmin=-1)
plt.colorbar()
plt.title(f"Kg at {index}-th iteration")
plt.xlabel("Sample point index")

# Visualise NTK after training, The normalised Kg at 2000th iteration.
plt.subplot(2, 3, 5)
index = str(output_Kgg[1])
K = train_obj.ntk_list[index].numpy()
plt.imshow(K / (np.max(abs(K))), cmap="bwr", vmax=1, vmin=-1)
plt.colorbar()
plt.title(f"Kg at {index}-th iteration")
plt.xlabel("Sample point index")

# The normalised Kg at 20000th iteration.
plt.subplot(2, 3, 6)
index = str(output_Kgg[2])
K = train_obj.ntk_list[index].numpy()
plt.imshow(K / (np.max(abs(K))), cmap="bwr", vmax=1, vmin=-1)
plt.colorbar()
plt.title(f"Kg at {index}-th iteration")
plt.xlabel("Sample point index")

plt.savefig(
os.path.join(
output_dir, f"sine_function_{mu}_{b}_{right_by}_{activation_function}.png"
)
)

# Save data
# scipy.io.savemat(os.path.join(output_dir, "out.mat"), {"NTK": a, "x": xy, "y": y})
36 changes: 36 additions & 0 deletions jointContribution/PIRBN/jacobian_function.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import paddle


def flat(x, start_axis=0, stop_axis=None):
# TODO Error if use paddle.flatten -> The Op flatten_grad doesn't have any gradop
stop_axis = None if stop_axis is None else stop_axis + 1
shape = x.shape

# [3, 1] --flat--> [3]
# [2, 2] --flat--> [4]
temp = shape[start_axis:stop_axis]
temp = [0 if x == 1 else x for x in temp] # kill invalid axis
flat_sum = sum(temp)
head = shape[0:start_axis]
body = [flat_sum]
tail = [] if stop_axis is None else shape[stop_axis:]
new_shape = head + body + tail
x_flat = x.reshape(new_shape)
return x_flat


def jacobian(y, x):
J_shape = y.shape + x.shape
J = paddle.zeros(J_shape)
y_flat = flat(y)
J_flat = flat(
J, start_axis=0, stop_axis=len(y.shape) - 1
) # partialy flatten as y_flat
for i, y_i in enumerate(y_flat):
grad = paddle.grad(y_i, x, allow_unused=True)[
0
] # grad[i] == sum by j (dy[j] / dx[i])
if grad is None:
grad = paddle.zeros_like(x)
J_flat[i] = grad
return J_flat.reshape(J_shape)
68 changes: 68 additions & 0 deletions jointContribution/PIRBN/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,68 @@
import analytical_solution
import numpy as np
import pirbn
import rbn_net
import train

import ppsci

# set random seed for reproducibility
SEED = 2023
ppsci.utils.misc.set_random_seed(SEED)

# mu, Fig.1, Page5
# right_by, Formula (15) Page5
def sine_function_main(
mu, adaptive_weights=True, right_by=0, activation_function="gaussian"
):
# Define the number of sample points
ns = 50

# Define the sample points' interval
dx = 1.0 / (ns - 1)

# Initialise sample points' coordinates
x_eq = np.linspace(0.0, 1.0, ns)[:, None]

for i in range(0, ns):
x_eq[i, 0] = i * dx + right_by
x_bc = np.array([[right_by + 0.0], [right_by + 1.0]])
x = [x_eq, x_bc]
y = -4 * mu**2 * np.pi**2 * np.sin(2 * mu * np.pi * x_eq)

# Set up radial basis network
n_in = 1
n_out = 1
n_neu = 61
b = 10.0
c = [right_by - 0.1, right_by + 1.1]

# Set up PIRBN
rbn = rbn_net.RBN_Net(n_in, n_out, n_neu, b, c, activation_function)
rbn_loss = pirbn.PIRBN(rbn, activation_function)
maxiter = 20001
output_Kgg = [0, int(0.1 * maxiter), maxiter - 1]
train_obj = train.Trainer(
rbn_loss,
x,
y,
learning_rate=0.001,
maxiter=maxiter,
adaptive_weights=adaptive_weights,
)
train_obj.fit(output_Kgg)

# Visualise results
analytical_solution.output_fig(
train_obj, mu, b, right_by, activation_function, output_Kgg
)


# Fig.1
sine_function_main(mu=4, right_by=0, activation_function="tanh")
# Fig.2
sine_function_main(mu=8, right_by=0, activation_function="tanh")
# Fig.3
sine_function_main(mu=4, right_by=100, activation_function="tanh")
# Fig.6
sine_function_main(mu=8, right_by=100, activation_function="gaussian")
Loading

0 comments on commit c889eb7

Please sign in to comment.