
PaddleScience paper reproduction: Extended Physics-Informed Neural Networks (XPINNs) #535

Merged · 2 commits · Sep 13, 2023

Conversation

@co63oc (Contributor) commented Sep 9, 2023:

PR types

Others

PR changes

Others

Describe

PaddlePaddle/Paddle#55663 (comment)

Task 14: Reproduction of classic papers in the PaddleScience domain
Paper 1
Title: Extended Physics-Informed Neural Networks (XPINNs)

The network structure has been converted consistently; changing the call to self.optimizer.clear_grad() aligns the accuracy:

```python
loss = loss1_value + loss2_value + loss3_value
loss.backward()
self.optimizer.step()
self.optimizer.clear_grad()
```

Test loss information, with the iteration count set to 501 and TensorFlow 1.13.2-gpu; the loss difference:

|           | TensorFlow   | Paddle       |
| --------- | ------------ | ------------ |
| Test Loss | 1.794519e-02 | 2.138322e-01 |

TensorFlow run:
(screenshot)

Paddle run:
(screenshot)

@paddle-bot (bot) commented Sep 9, 2023:

Thanks for your contribution!

```python
loss = loss1_value + loss2_value + loss3_value
loss.backward()
self.optimizer.step()
loss.clear_grad()
```
Collaborator commented:

I see the accuracy is not aligned; it is probably this line. Tensor.clear_grad only clears that one tensor's own gradient. After the gradient update you should call self.optimizer.clear_grad().
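A minimal pure-Python sketch (a hypothetical two-weight model, not the PR's code) of why this matters: frameworks like Paddle accumulate gradients across `backward()` calls, so clearing only one tensor's gradient leaves stale gradients on every other parameter.

```python
# Toy model: loss = w0^2 + w1^2, so dL/dw_i = 2 * w_i.
# Gradients accumulate (+=) across backward() calls, as in Paddle autograd.

def train(clear_all):
    params = [1.0, 1.0]
    grads = [0.0, 0.0]
    for _ in range(3):
        for i in range(2):            # backward(): accumulate gradients
            grads[i] += 2.0 * params[i]
        for i in range(2):            # optimizer.step(): SGD update
            params[i] -= 0.1 * grads[i]
        grads[0] = 0.0                # Tensor.clear_grad(): only one tensor
        if clear_all:                 # optimizer.clear_grad(): every tensor
            grads[1] = 0.0
    return params

correct = train(clear_all=True)    # both weights decay identically to 0.512
buggy = train(clear_all=False)     # w1 drifts because its gradient is stale
print(correct, buggy)
```

With all gradients cleared the two symmetric weights stay equal; with the buggy per-tensor clear, the uncleared weight diverges within a few iterations.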

Contributor (author) commented:

Changed to self.optimizer.clear_grad(); the accuracy now aligns.

```diff
@@ -352,7 +352,7 @@ def predict(self, X_star1, X_star2, X_star3):
 layers2 = [2, 20, 20, 20, 20, 1]
 layers3 = [2, 25, 25, 25, 1]

-Max_iter = 501
+Max_iter = 701
```
Collaborator commented:

The original code seems to use 501; what is the reason for changing it to 701 here?

Contributor (author) commented:

It was done to align the accuracy; with 501 iterations the loss is still large.

Collaborator commented:

> It was done to align the accuracy; with 501 iterations the loss is still large.

Paper reproductions generally use the same training strategy by default. If, after converting to Paddle, you have to increase the number of epochs to align the accuracy, something is probably still not aligned, such as details of the parameter initialization.

@co63oc (Contributor, author) commented Sep 11, 2023:

Switched to TensorFlow version 1.13.2-gpu; the Paddle accuracy now aligns. The PR description has been updated.

Collaborator commented:

> Switched to TensorFlow version 1.13.2-gpu; the Paddle accuracy now aligns. The PR description has been updated.

OK, I'll run it and take a look.

Collaborator commented:

(screenshot) Same error: float64 is not supported.

Looking into it, this part of the code is the problem; you can append `_tensor.dtype` to the end of the `to_tensor` calls:

```python
_tensor = paddle.multiply(_tensor, paddle.to_tensor(std * math.sqrt(2.0), _tensor.dtype))
_tensor = paddle.add(_tensor, paddle.to_tensor(mean, _tensor.dtype))
```
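A NumPy analogue of the suggested fix (illustrative only; `scale_shift` is a hypothetical helper, not code from the PR): build the scalar operands with the tensor's own dtype so a strict elementwise multiply/add sees matching dtypes.

```python
import math
import numpy as np

def scale_shift(t, std, mean):
    # Construct the scalars with t's own dtype, mirroring
    # paddle.to_tensor(..., _tensor.dtype) in the suggested fix.
    s = np.asarray(std * math.sqrt(2.0), dtype=t.dtype)
    m = np.asarray(mean, dtype=t.dtype)
    return t * s + m

t = np.ones(3, dtype=np.float64)
out = scale_shift(t, std=0.5, mean=1.0)
print(out.dtype)  # float64 is preserved end to end
```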

Contributor (author) commented:

```python
def xavier_init(self, size):
    in_dim = size[0]
    out_dim = size[1]
    import ppsci

    param = paddle.empty(size, "float64")
    param = ppsci.utils.initializer.trunc_normal_(param, 0.0, 1.0)
    return paddle.nn.initializer.Assign(param)
```

`paddle.nn.initializer.Assign` does not support float64:
(screenshot)

Collaborator commented:

Try substituting a lambda for now:

```python
import numpy as np
import paddle

x_np = np.zeros([2, 2], "float64")

w = paddle.create_parameter(
    x_np.shape,
    dtype="float64",
    default_initializer=lambda p_ten, _: p_ten.set_value(x_np),
)

print(w)

# Tensor(shape=[2, 2], dtype=float64, place=Place(gpu:0), stop_gradient=False,
#        [[0., 0.],
#         [0., 0.]])
```

Contributor (author) commented:

After the change the test loss is 4.684577e-02, versus 1.794519e-02 for TensorFlow.

```python
def w_init(self, size):
    in_dim = size[0]
    out_dim = size[1]
    xavier_stddev = np.sqrt(2 / (in_dim + out_dim))
    param = paddle.empty(size, "float64")
    param = ppsci.utils.initializer.trunc_normal_(param, 0.0, xavier_stddev)
    return lambda p_ten, _: p_ten.set_value(param)
```

(screenshot)
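For reference, a NumPy sketch of what `w_init` computes: the Xavier standard deviation plus a truncated normal draw. The helper name and the truncation at ±2σ are assumptions matching the usual `trunc_normal_` convention, not code from the PR.

```python
import numpy as np

def trunc_normal(shape, mean=0.0, std=1.0, lo=-2.0, hi=2.0, seed=0):
    # Rejection-sample a normal truncated to [mean + lo*std, mean + hi*std].
    rng = np.random.default_rng(seed)
    out = np.empty(int(np.prod(shape)))
    filled = 0
    while filled < out.size:
        draw = rng.normal(mean, std, out.size - filled)
        keep = draw[(draw >= mean + lo * std) & (draw <= mean + hi * std)]
        out[filled:filled + keep.size] = keep
        filled += keep.size
    return out.reshape(shape)

in_dim, out_dim = 2, 20
xavier_stddev = np.sqrt(2.0 / (in_dim + out_dim))   # as in w_init
w = trunc_normal((in_dim, out_dim), std=xavier_stddev)
print(w.shape)
```

Every sample lands within two standard deviations of the mean, which is what makes the initialization reproducible across frameworks up to the random stream itself.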

Collaborator commented:

> After the change the test loss is 4.684577e-02, versus 1.794519e-02 for TensorFlow. [...]

Even with identical distributions, the random numbers generated by the two frameworks differ, so this diff is most likely reasonable. I'll rerun Paddle with a different random seed tomorrow and check.

@luotao1 added the HappyOpenSource Pro label (advanced happy open-source program, more challenging tasks) on Sep 11, 2023
Comment on lines 151 to 157:

```python
param = paddle.empty(size, "float64")
param = ppsci.utils.initializer.trunc_normal_(param, 0.0, xavier_stddev)
return lambda p_ten, _: p_ten.set_value(param)
```
Collaborator commented:

Please add a comment: `TODO: Truncated normal and assign support float64`

Contributor (author) commented:

1

Comment on lines 115 to 145:

```python
def initialize_NN(self, layers, name_prefix):
    weights = []
    biases = []
    A = []
    num_layers = len(layers)
    for l in range(0, num_layers - 1):
        W = self.create_parameter(
            shape=[layers[l], layers[l + 1]],
            dtype="float64",
            default_initializer=self.w_init((layers[l], layers[l + 1])),
        )
        b = self.create_parameter(
            shape=[1, layers[l + 1]],
            dtype="float64",
            is_bias=True,
            default_initializer=paddle.nn.initializer.Constant(0.0),
        )
        a = self.create_parameter(
            shape=[1],
            dtype="float64",
            is_bias=True,
            default_initializer=paddle.nn.initializer.Constant(0.05),
        )

        self.add_parameter(name_prefix + "_W_" + str(l), W)
        self.add_parameter(name_prefix + "_b_" + str(l), b)
        self.add_parameter(name_prefix + "_a_" + str(l), a)
        weights.append(W)
        biases.append(b)
        A.append(a)
    return weights, biases, A
```
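A hedged NumPy sketch of how per-layer `a` parameters like these are typically used in XPINN-style adaptive activations: `a` scales the pre-activation before `tanh`. The fixed slope factor 20 follows the original XPINNs reference code, but treat the exact form, names, and sizes here as illustrative assumptions rather than this PR's implementation.

```python
import numpy as np

rng = np.random.default_rng(0)
layers = [2, 20, 20, 1]
# Xavier-scaled weights, zero biases, and one adaptive amplitude per layer
weights = [rng.normal(0.0, np.sqrt(2.0 / (m + n)), (m, n))
           for m, n in zip(layers[:-1], layers[1:])]
biases = [np.zeros((1, n)) for n in layers[1:]]
amplitudes = [np.full((1,), 0.05) for _ in layers[1:]]

def forward(x):
    h = x
    for w, b, a in zip(weights[:-1], biases[:-1], amplitudes[:-1]):
        h = np.tanh(20.0 * a * (h @ w + b))   # adaptive slope 20 * a
    return h @ weights[-1] + biases[-1]       # linear output layer

y = forward(np.ones((4, 2)))
print(y.shape)
```

Because `a` is a trainable parameter, the effective slope of each layer's activation is learned jointly with the weights, which is the point of initializing it to a small constant such as 0.05.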
Collaborator commented:

1. Don't use uppercase names for `W` and `A`; rename them to suitable lowercase names, and document what variable `A` means (Amplitude, perhaps?)

Contributor (author) commented:

Renamed to amplitudes.

Collaborator commented:

Please delete this file.

Contributor (author) commented:

1

--------------------
Original repository: <https://github.com/AmeyaJagtap/XPINNs.git>

Install the LaTeX packages and Python dependencies
Collaborator commented:

Please explain in detail how the LaTeX packages were installed; on my side `pip install latex` still raises an error.

@co63oc (Contributor, author) commented Sep 12, 2023:

The command is given below: `apt install latex-cjk-all texlive-latex-extra cm-super dvipng -y`
(screenshot)

@HydrogenSulfate (Collaborator) left a comment:

LGTM
