Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

赛题十五:PIRBN 子目录 1D_sine_function #536

Merged
merged 18 commits into from
Oct 20, 2023

Conversation

co63oc
Copy link
Contributor

@co63oc co63oc commented Sep 12, 2023

PR types

Others

PR changes

Others

Describe

PaddlePaddle/Paddle#55663

赛题十五:PaddleScience 领域前沿论文复现
实现 子目录 1D_nonlinear_spring

测试精度a_g

- Tensorflow Paddle
a_g 1.0004051 1.0005087

Tensorflow 运行
图片

Paddle运行
图片

@paddle-bot
Copy link

paddle-bot bot commented Sep 12, 2023

Thanks for your contribution!

Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NTK plot needs to be implemented.

paddle.framework.core.set_prim_eager_enabled(True)


class Adam:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已修改Adam

import paddle


def cal_adapt(pirbn, x):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

参考:
源论文Page4,Theorem 3.1.4
论文Page8 https://www.sciencedirect.com/science/article/pii/S002199912100663X

  1. 此函数用于计算NTK矩阵,对应论文变量名称为Kg,请补充下相关函数注释(引用论文公式xx,页数xx)
  2. 对关键变量如 lambda_g lambda_b1 lambda_b2 进行注释说明(特征值,计算雅可比矩阵)
  3. 对NTK矩阵进行可视化,尝试复现论文Fig. 1,可以参考https://github.com/PredictiveIntelligenceLab/PINNsNTK/blob/master/PINNsNTK_Poisson1D.ipynb

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改为 pirbn.cal_ntk ,变量为 Kg
已增加引用公式页
可视化为 analytical_solution.output_fig subplot(2, 3, 4)

使用原tensorflow代码https://github.com/JinshuaiBai/PIRBN/tree/main/1D_sine_function,显示图和论文图Fig.1不同,设置参数mu=4, b=10.0

### Define mu
mu = 4

### Define the number of sample points
ns = 51

### Define the sample points' interval
dx = 1./(ns-1)

### Initialise sample points' coordinates
xy = np.zeros((ns, 1)).astype(np.float32)
for i in range(0, ns):
    xy[i, 0] = i * dx
xy_b = np.array([[0.], [1.]])

x = [xy, xy_b]
y = [-4*mu**2*np.pi**2*np.sin(2*mu*np.pi*xy)]

### Set up raidal basis network
n_in = 1
n_out = 1
n_neu = 61
b=10.
c = [-0.1, 1.1]

图片
原论文图
图片

y = rbn(paddle.to_tensor(xy))
y = y.numpy()
plt.plot(xy, y)
plt.plot(xy, xy * np.sin(xy))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

解析解请单独命名为analytical_solution, 并进行注释

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

target_dir = os.path.join(os.path.dirname(__file__), "../target")
if not os.path.exists(target_dir):
os.mkdir(target_dir)
plt.savefig(os.path.join(target_dir, "1D_nonlinear_spring.png"))
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

请加一张absolute error计算结果的图片

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

analytical_solution 子图 plt.subplot(2, 3, 2) 计算absolute error

paddle.framework.core.set_prim_eager_enabled(True)


class Adam:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

class 可以命名为Trainer更恰当

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

default_initializer=paddle.nn.initializer.Normal(mean=0.0, std=0.05),
)

def forward(self, inputs): # Defines the computation from inputs to outputs
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

此处应该是用了类似论文Page13的激活函数
尝试单独抽象出激活函数 def activate()

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

设置激活函数 RBF_layer1.rbf_activate

@@ -0,0 +1,7 @@
numpy
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个requirements在PaddleScience主目录下已经有了
如果有新的库需要添加,请加到主目录下的requirement.txt

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

@luotao1 luotao1 added the HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务 label Sep 13, 2023
@luotao1 luotao1 self-assigned this Sep 13, 2023
@co63oc co63oc changed the title 赛题十五:PIRBN 子目录 1D_nonlinear_spring 赛题十五:PIRBN 子目录 1D_sine_function Sep 13, 2023
Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

当前复现代码的公式有点问题,需要复现的论文指标为:Fig 1 Fig 2 Fig 3 Fig 6,具体内容见review


- 1D sine funtion (**Eq. 1** in the manuscript)

**PDE**: $\frac{\partial^2 }{\partial x^2}u(x)-4\mu^2\pi^2 sin(2\mu\pi(x))=0, x\in[0,1]$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

之前的修改都很好,都是符合要求的,但是图片结果上还有一些不清楚的地方,特此说明下:
论文图片指标:

  1. Fig 1 :单层PINN方法(不是PIRBN),公式(13),mu = 4, 【tanh激活函数】,论文显示勉强收敛
  2. Fig. 2 :单层PINN方法(不是PIRBN),公式(13),改变 mu = 8, 【tanh激活函数】,病态了,论文的图显示结果很差
  3. Fig. 3 :单层PINN方法(不是PIRBN),改变 公式(15), mu = 4, 【tanh激活函数】,病态了,论文的图显示结果很差
  4. Fig. 6: 单层PIRBN(源代码实现),改变 公式(15),改变 mu = 8, 改变【高斯激活函数】,结果非常优秀,说明PIRBN很有用

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

另外,此处公式好像和源代码对不上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fig1, Fig2, Fig3 使用tanh和论文不一致

Fig6改变激活函数和论文一致
Fig1
图片
Fig2
图片
Fig3
图片
Fig6
图片

公式和源码是一致
图片

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不同激活函数设置

    def gaussian_function(self, temp_x):
        x0 = (
            paddle.reshape(
                paddle.arange(self.n_neu, dtype=paddle.get_default_dtype()),
                (1, self.n_neu),
            )
            * (self.c[1] - self.c[0])
            / (self.n_neu - 1)
            + self.c[0]
        )
        x_new = temp_x - x0
        s = self.b * self.b
        return paddle.exp(-(x_new * x_new) * s)

    def tanh_function(self, temp_x):
        return paddle.tanh(temp_x)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可能源代码有点问题Fig 1我用了这篇文章的Reference源码做了一下图
PredictiveIntelligenceLab/PINNsNTK@18ef519
image
可能jacobian那个函数有点问题,看看和这个仓库对齐一下

import paddle


class Dif(paddle.nn.Layer):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个类有些多余,请去掉并合并

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

return [u_xx, u_b]

def cal_ntk(self, x):
# Formula (4), Page5, \gamma variable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是指这个公式吗?
Uploading image.png…

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个论文是旧版的,作者8月3号更新了论文新地址和版本,辛苦更新一下页码
image

gamma_g = gamma_g + paddle.sum(j**2) / n1
temp = paddle.concat((l1t[0], l1t[1].reshape((1, n_neu))), axis=1)
if i == 0:
# Fig.1, Page8, Kg variable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

页码好像是5?
image

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
return loss, [self.his_l1, self.his_l2]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

返回值没用上, 可以去掉了

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

def cal_ntk(self, x):
# Formula (4), Page5, \gamma variable
gamma_g = 0.0
gamma_b = 0.0
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

]
self.y_train = paddle.to_tensor(y_train, dtype=paddle.get_default_dtype())
self.maxiter = maxiter
self.his_l1 = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为loss_b 并注释是boundary loss

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

self.y_train = paddle.to_tensor(y_train, dtype=paddle.get_default_dtype())
self.maxiter = maxiter
self.his_l1 = []
self.his_l2 = []
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

改为loss_g 并注释是eq loss

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1


def evaluate(self):
# compute loss
loss, l1, l2 = self.Loss(self.x_train, self.y_train, self.a_g, self.a_b)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上,增强可读性

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

1

+ self.c[0]
)
x_new = temp_x - x0
return self.rbf_activate(x_new)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

如果和PINN对比,可能会出现两种激活函数 tanh 和 这个rbf_activate在论文里名称应该叫Gaussian function,改一下激活函数名

注释 #Formula (19), Page7

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need debug for fig 1


- 1D sine funtion (**Eq. 1** in the manuscript)

**PDE**: $\frac{\partial^2 }{\partial x^2}u(x)-4\mu^2\pi^2 sin(2\mu\pi(x))=0, x\in[0,1]$
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

可能源代码有点问题Fig 1我用了这篇文章的Reference源码做了一下图
PredictiveIntelligenceLab/PINNsNTK@18ef519
image
可能jacobian那个函数有点问题,看看和这个仓库对齐一下

return [u_xx, u_b]

def cal_ntk(self, x):
# Formula (4), Page5, \gamma variable
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个论文是旧版的,作者8月3号更新了论文新地址和版本,辛苦更新一下页码
image

+ self.c[0]
)
x_new = temp_x - x0
return self.rbf_activate(x_new)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

@@ -0,0 +1,67 @@
import paddle

# Used to calculate the second-order derivatives
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

删掉

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

已增加 PINN 目录运行PINN.py, Kg结果图
图片

Page页码已修改

注释已删除
图片

Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need to fix

jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
docs/zh/examples/pirbn.md Outdated Show resolved Hide resolved
jointContribution/PIRBN/main.py Outdated Show resolved Hide resolved
docs/zh/examples/pirbn.md Show resolved Hide resolved
jointContribution/PIRBN/main.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/main.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/main.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

need fix

docs/zh/examples/pirbn.md Outdated Show resolved Hide resolved
Copy link
Contributor

@wangguan1995 wangguan1995 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Pass

jointContribution/PIRBN/jacobian_test.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/rbn_net.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/train.py Outdated Show resolved Hide resolved
jointContribution/PIRBN/README.md Show resolved Hide resolved
Copy link
Collaborator

@HydrogenSulfate HydrogenSulfate left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@HydrogenSulfate HydrogenSulfate merged commit eab263b into PaddlePaddle:develop Oct 20, 2023
4 checks passed
@co63oc co63oc deleted the pirbn branch October 31, 2023 06:50
huohuohuohuohuo123 pushed a commit to huohuohuohuohuo123/PaddleScience that referenced this pull request Aug 12, 2024
* Add PIRBN

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix

* Fix
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
contributor HappyOpenSource Pro 进阶版快乐开源活动,更具挑战性的任务
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants