Skip to content

Commit

Permalink
Backend paddle: Support deeponet and other examples
Browse files Browse the repository at this point in the history
  • Loading branch information
lijialin03 committed Aug 4, 2023
1 parent 197f298 commit 9f59959
Show file tree
Hide file tree
Showing 11 changed files with 84 additions and 37 deletions.
4 changes: 2 additions & 2 deletions deepxde/icbc/boundary_conditions.py
Original file line number Diff line number Diff line change
Expand Up @@ -193,9 +193,9 @@ def __init__(
self.batch_size = batch_size

if batch_size is not None: # batch iterator and state
if backend_name != "pytorch":
if backend_name not in ["pytorch","paddle"] :
raise RuntimeError(
"batch_size only implemented for pytorch backend"
"batch_size only implemented for pytorch and paddle backend"
)
self.batch_sampler = data.sampler.BatchSampler(
len(self), shuffle=shuffle
Expand Down
14 changes: 11 additions & 3 deletions deepxde/nn/paddle/deeponet.py
Original file line number Diff line number Diff line change
Expand Up @@ -75,8 +75,12 @@ def forward(self, inputs):
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ]
x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1]
# Use the following formula to temporarily replace paddle.einsum()
# Because no higher(>=2) orderderivatives for the op now
x = paddle.sum(x_func*x_loc,axis = 1,keepdim=True)
# TODO:
# x = paddle.einsum("bi,bi->b", x_func, x_loc) # [batch_size, ]
# x = paddle.reshape(x, [-1, 1]) # reshape [batch_size, ] to [batch_size, 1]
# Add bias
if self.use_bias:
x += self.b
Expand Down Expand Up @@ -143,7 +147,11 @@ def forward(self, inputs):
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
x = paddle.einsum("bi,ni->bn", x_func, x_loc)
# Use the following formula to temporarily replace paddle.einsum()
# Because no higher(>=2) orderderivatives for the op now
x = x_func@x_loc.T
# TODO:
# x = paddle.einsum("bi,ni->bn", x_func, x_loc)
# Add bias
x += self.b

Expand Down
21 changes: 15 additions & 6 deletions examples/operator/advection_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
dim_x = 5
sin = dde.backend.paddle.sin
cos = dde.backend.paddle.cos
concat = dde.backend.paddle.concat
else:
dim_x = 2
sin = dde.backend.tf.sin
cos = dde.backend.tf.cos
concat = dde.backend.tf.concat


# PDE
def pde(x, y, v):
Expand Down Expand Up @@ -36,7 +47,7 @@ def func_ic(x, v):
# Net
net = dde.nn.DeepONetCartesianProd(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
Expand All @@ -45,9 +56,7 @@ def func_ic(x, v):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
Expand Down
22 changes: 15 additions & 7 deletions examples/operator/advection_aligned_pideeponet_2d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
dim_x = 5
sin = dde.backend.paddle.sin
cos = dde.backend.paddle.cos
concat = dde.backend.paddle.concat
else:
dim_x = 2
sin = dde.backend.tf.sin
cos = dde.backend.tf.cos
concat = dde.backend.tf.concat


# PDE
def pde(x, y, v):
Expand Down Expand Up @@ -41,18 +52,15 @@ def boundary(x, on_boundary):
# Net
net = dde.nn.DeepONetCartesianProd(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)


def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
Expand Down
21 changes: 15 additions & 6 deletions examples/operator/advection_unaligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,19 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
dim_x = 5
sin = dde.backend.paddle.sin
cos = dde.backend.paddle.cos
concat = dde.backend.paddle.concat
else:
dim_x = 2
sin = dde.backend.tf.sin
cos = dde.backend.tf.cos
concat = dde.backend.tf.concat


# PDE
def pde(x, y, v):
Expand Down Expand Up @@ -36,7 +47,7 @@ def func_ic(x, v):
# Net
net = dde.nn.DeepONet(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
Expand All @@ -45,9 +56,7 @@ def func_ic(x, v):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
Expand Down
20 changes: 14 additions & 6 deletions examples/operator/advection_unaligned_pideeponet_2d.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,18 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
dim_x = 5
sin = dde.backend.paddle.sin
cos = dde.backend.paddle.cos
concat = dde.backend.paddle.concat
else:
dim_x = 2
sin = dde.backend.tf.sin
cos = dde.backend.tf.cos
concat = dde.backend.tf.concat

# PDE
def pde(x, y, v):
Expand Down Expand Up @@ -39,7 +49,7 @@ def boundary(x, on_boundary):
# Net
net = dde.nn.DeepONet(
[50, 128, 128, 128],
[2, 128, 128, 128],
[dim_x, 128, 128, 128],
"tanh",
"Glorot normal",
)
Expand All @@ -48,9 +58,7 @@ def boundary(x, on_boundary):
def periodic(x):
x, t = x[:, :1], x[:, 1:]
x *= 2 * np.pi
return tf.concat(
[tf.math.cos(x), tf.math.sin(x), tf.math.cos(2 * x), tf.math.sin(2 * x), t], 1
)
return concat([cos(x), sin(x), cos(2 * x), sin(2 * x), t], 1)


net.apply_feature_transform(periodic)
Expand Down
11 changes: 8 additions & 3 deletions examples/operator/antiderivative_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
from deepxde.backend import tf

if dde.backend.backend_name == "paddle":
transpose = dde.backend.paddle.transpose
else:
transpose = dde.backend.tf.transpose



dde.config.disable_xla_jit()
Expand Down Expand Up @@ -37,7 +42,7 @@ def pde(x, u, v):

# Hard constraint zero IC
def zero_ic(inputs, outputs):
return outputs * tf.transpose(inputs[1])
return outputs * transpose(inputs[1],[1,0])


net.apply_output_transform(zero_ic)
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/antiderivative_unaligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/diff_rec_aligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, tensorflow"""
"""Backend supported: tensorflow.compat.v1, tensorflow, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/operator/diff_rec_unaligned_pideeponet.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1"""
"""Backend supported: tensorflow.compat.v1, paddle"""
import deepxde as dde
import matplotlib.pyplot as plt
import numpy as np
Expand Down
2 changes: 1 addition & 1 deletion examples/pinn_forward/Burgers.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch"""
"""Backend supported: tensorflow.compat.v1, tensorflow, pytorch, paddle"""
import deepxde as dde
import numpy as np

Expand Down

0 comments on commit 9f59959

Please sign in to comment.