Commit

Fix
co63oc committed Sep 11, 2023
1 parent e27f86a commit 3af0d63
Showing 2 changed files with 19 additions and 8 deletions.
21 changes: 15 additions & 6 deletions jointContribution/XPINNs/XPINN_2D_PoissonsEqn.py
@@ -10,6 +10,8 @@
 from matplotlib import patches
 from matplotlib import tri
 
+import ppsci
+
 # For the use of the second derivative: paddle.cos, paddle.exp
 paddle.framework.core.set_prim_eager_enabled(True)
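
Note: the added `import ppsci` is used by the new `w_init` further down; the pre-existing comment ties `set_prim_eager_enabled(True)` to second derivatives of ops such as `paddle.cos` / `paddle.exp`. A minimal sketch, not part of this commit, of taking such a second derivative with `paddle.grad` (values and names are illustrative only):

import paddle

paddle.framework.core.set_prim_eager_enabled(True)

# Second derivative of u(x) = cos(x) * exp(x) via two calls to paddle.grad
x = paddle.to_tensor([[0.3], [0.7]], dtype="float64", stop_gradient=False)
u = paddle.cos(x) * paddle.exp(x)
du_dx = paddle.grad(u, x, create_graph=True)[0]
d2u_dx2 = paddle.grad(du_dx, x, create_graph=True)[0]
print(d2u_dx2)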

@@ -32,7 +34,7 @@ def __init__(self, layer_list):
             layer_list[2], "layers3"
         )
 
-    def convert_tensor(self, dataset):
+    def preprocess_data(self, dataset):
         X_ub, ub, X_f1, X_f2, X_f3, X_fi1, X_fi2 = dataset
         self.x_ub = paddle.to_tensor(X_ub[:, 0:1], dtype=paddle.float64)
         self.y_ub = paddle.to_tensor(X_ub[:, 1:2], dtype=paddle.float64)
@@ -49,7 +51,7 @@ def convert_tensor(self, dataset):
         self.y_fi2 = paddle.to_tensor(X_fi2[:, 1:2], dtype=paddle.float64)
 
     def forward(self, dataset):
-        self.convert_tensor(dataset)
+        self.preprocess_data(dataset)
         self.ub1_pred = self.net_u1(self.x_ub, self.y_ub)
         self.ub2_pred = self.net_u2(self.x_f2, self.y_f2)
         self.ub3_pred = self.net_u3(self.x_f3, self.y_f3)
@@ -119,9 +121,7 @@ def initialize_NN(self, layers, name_prefix):
             W = self.create_parameter(
                 shape=[layers[l], layers[l + 1]],
                 dtype="float64",
-                default_initializer=paddle.nn.initializer.XavierNormal(
-                    layers[l], layers[l + 1]
-                ),
+                default_initializer=self.w_init((layers[l], layers[l + 1])),
             )
             b = self.create_parameter(
                 shape=[1, layers[l + 1]],
@@ -135,6 +135,7 @@ def initialize_NN(self, layers, name_prefix):
                 is_bias=True,
                 default_initializer=paddle.nn.initializer.Constant(0.05),
             )
+
             self.add_parameter(name_prefix + "_W_" + str(l), W)
             self.add_parameter(name_prefix + "_b_" + str(l), b)
             self.add_parameter(name_prefix + "_a_" + str(l), a)
@@ -143,6 +144,14 @@ def initialize_NN(self, layers, name_prefix):
             A.append(a)
         return weights, biases, A
 
+    def w_init(self, size):
+        in_dim = size[0]
+        out_dim = size[1]
+        xavier_stddev = np.sqrt(2 / (in_dim + out_dim))
+        param = paddle.empty(size, "float64")
+        param = ppsci.utils.initializer.trunc_normal_(param, 0.0, xavier_stddev)
+        return lambda p_ten, _: p_ten.set_value(param)
+
     def neural_net_tanh(self, X, weights, biases, A):
         num_layers = len(weights) + 1
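
For reference, a standalone sketch, not part of the commit, of the truncated-normal Xavier initialization that the new `w_init` builds with `ppsci.utils.initializer.trunc_normal_` in place of `paddle.nn.initializer.XavierNormal`; the helper name and shapes below are illustrative:

import numpy as np
import paddle
import ppsci

def xavier_trunc_normal(size):
    # Glorot/Xavier std from fan-in and fan-out, sampled from a truncated normal
    in_dim, out_dim = size
    xavier_stddev = np.sqrt(2 / (in_dim + out_dim))
    param = paddle.empty(size, "float64")
    param = ppsci.utils.initializer.trunc_normal_(param, 0.0, xavier_stddev)
    return param

w = xavier_trunc_normal((2, 20))
print(w.shape, float(w.std()))  # std should be close to sqrt(2 / (2 + 20))

The lambda returned by `w_init` in the diff appears to adapt this to the `default_initializer` callback of `create_parameter`, copying the precomputed values into the parameter with `set_value`.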

@@ -303,7 +312,7 @@ def train(self, nIter, X_star1, X_star2, X_star3, u_exact2, u_exact3):
                 loss = loss1_value + loss2_value + loss3_value
                 loss.backward()
                 self.optimizer.step()
-                loss.clear_grad()
+                self.optimizer.clear_grad()
 
                 if it % 20 == 0:
                     # Predicted solution
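
The last hunk above moves gradient clearing from the loss tensor to the optimizer. A minimal self-contained sketch of the usual Paddle dygraph update order; the layer, optimizer, and data below are placeholders, not from the XPINN code:

import paddle

layer = paddle.nn.Linear(3, 1)
opt = paddle.optimizer.Adam(parameters=layer.parameters())
x, y = paddle.rand([8, 3]), paddle.rand([8, 1])

for _ in range(2):
    loss = paddle.nn.functional.mse_loss(layer(x), y)
    loss.backward()
    opt.step()
    # clear_grad on the optimizer resets every parameter's gradient;
    # calling clear_grad on the loss tensor would not touch the parameters.
    opt.clear_grad()
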
6 changes: 4 additions & 2 deletions ppsci/utils/initializer.py
@@ -91,8 +91,10 @@ def norm_cdf(x):
         _tensor.erfinv_()
 
         # Transform to proper mean, std
-        _tensor = paddle.multiply(_tensor, paddle.to_tensor(std * math.sqrt(2.0)))
-        _tensor = paddle.add(_tensor, paddle.to_tensor(mean))
+        _tensor = paddle.multiply(
+            _tensor, paddle.to_tensor(std * math.sqrt(2.0), _tensor.dtype)
+        )
+        _tensor = paddle.add(_tensor, paddle.to_tensor(mean, _tensor.dtype))
 
         # Clamp to ensure it's in the proper range
         _tensor = paddle.clip(_tensor, min=a, max=b)
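
This change passes `_tensor.dtype` to `paddle.to_tensor` so the scalar operands match the tensor's dtype: a plain Python float typically becomes a float32 tensor, which can clash with the float64 tensors the XPINN code uses. A small sketch of the corrected pattern, with illustrative values:

import math
import paddle

_tensor = paddle.rand([4], dtype="float64")
std, mean = 0.05, 0.0

scale = paddle.to_tensor(std * math.sqrt(2.0), _tensor.dtype)  # stays float64
shift = paddle.to_tensor(mean, _tensor.dtype)
out = paddle.add(paddle.multiply(_tensor, scale), shift)
print(out.dtype)  # paddle.float64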
