Tensorflow 1.x backend: multiple outputs extension of DeepONet #1410

Closed
wants to merge 45 commits into from

Changes from 7 commits

Commits (45)
9a7e38b
Tensorflow 1.x backend: multiple outputs extension of DeepONet
vl-dud Jul 31, 2023
ae48af0
Codacy Pylint fix
vl-dud Aug 3, 2023
76b7964
move vanilla deeponet building into a separate method
vl-dud Aug 4, 2023
c0e06a5
Remove unwanted method
vl-dud Aug 23, 2023
8338b81
Change `output_count` to `num_outputs`; format via Black
vl-dud Aug 26, 2023
1515c4b
add DeepONet building strategies
vl-dud Sep 15, 2023
ba8d2a0
Add docs for the strategy argument
vl-dud Sep 18, 2023
5087fc2
Format comments
vl-dud Sep 18, 2023
44dae05
Use maximum 88 characters per line
vl-dud Sep 18, 2023
4f23bf8
rename merge to merge_branch_trunk
vl-dud Sep 20, 2023
6f75d99
rename merge to merge_branch_trunk
vl-dud Oct 3, 2023
367905f
Change default deeponet strategy
vl-dud Oct 4, 2023
1d233c8
Change strategy to multi_output_strategy
vl-dud Oct 9, 2023
9fe3572
Codacy Pylint fix
vl-dud Oct 9, 2023
cb2b3fc
Update deeponet.py for tf2 multiple outputs
mitchelldaneker Oct 9, 2023
e7b7e5d
Update deeponet.py
mitchelldaneker Oct 10, 2023
9f41776
Update deeponet.py
mitchelldaneker Oct 10, 2023
97c3641
Add files via upload
mitchelldaneker Oct 10, 2023
85e5984
Update triple.py
mitchelldaneker Oct 10, 2023
b33f812
Merge remote-tracking branch 'origin/master' into deeponet-multiple-o…
vl-dud Oct 13, 2023
25bf219
Add DeepONet strategy classes to __init__.py
vl-dud Oct 13, 2023
44bfd0a
Update __init__.py
mitchelldaneker Oct 13, 2023
d11ab3a
Update deeponet.py
mitchelldaneker Oct 13, 2023
10ed010
Update __init__.py
mitchelldaneker Oct 13, 2023
3ccd772
Update deeponet.py
mitchelldaneker Oct 13, 2023
64ab358
Update antiderivative_aligned_UQ.py
mitchelldaneker Oct 13, 2023
68b2733
Update deeponet.py
mitchelldaneker Oct 13, 2023
569f94e
Revert "Add DeepONet strategy classes to __init__.py"
vl-dud Oct 16, 2023
7c4f750
Hide deeponet strategy classes
vl-dud Oct 16, 2023
ee3eccc
Update triple.py
mitchelldaneker Oct 16, 2023
91a07e9
Update deeponet.py
mitchelldaneker Oct 16, 2023
1eda936
Merge pull request #3 from mitchelldaneker/multiple-outputs-deeponet-tf2
vl-dud Oct 19, 2023
4c8c40e
Format a code with Black
vl-dud Oct 19, 2023
bed66e0
Codacy Pylint fix
vl-dud Oct 19, 2023
7d938c5
Codacy Pylint fix
vl-dud Oct 19, 2023
509f42c
Update triple.py
mitchelldaneker Oct 19, 2023
5f67bdd
Update deeponet.py
mitchelldaneker Oct 19, 2023
b9bf993
Update triple.py
mitchelldaneker Oct 19, 2023
5d66929
Update deeponet.py
mitchelldaneker Oct 20, 2023
f3bb8d2
Update deeponet.py
mitchelldaneker Oct 20, 2023
8102950
Update deeponet.py
mitchelldaneker Oct 20, 2023
54662db
Merge pull request #6 from mitchelldaneker/tf_multiple_outputs
vl-dud Oct 20, 2023
2459f80
Update deeponet.py
mitchelldaneker Oct 20, 2023
b79e0e0
Update triple.py
mitchelldaneker Oct 20, 2023
226ddac
Merge pull request #7 from mitchelldaneker/tf2_multiple_outputs
vl-dud Oct 20, 2023
3 changes: 1 addition & 2 deletions deepxde/data/pde_operator.py
@@ -236,8 +236,7 @@ def _losses(self, outputs, loss_fn, inputs, model, num_func):

losses = []
for i in range(num_func):
out = outputs[i][:, None]

out = outputs[i] if model.net.num_outputs > 1 else outputs[i][:, None]
f = []
if self.pde.pde is not None:
f = self.pde.pde(inputs[1], out, model.net.auxiliary_vars[i][:, None])
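
For shape intuition behind this change: with a single output, outputs[i] is 1-D per function and needs a trailing axis, while with num_outputs > 1 it is already 2-D and is used as-is. A minimal NumPy sketch (sizes are illustrative, not part of the diff):

    import numpy as np

    num_points = 100
    single = np.zeros(num_points)      # shape (100,); single[:, None] gives (100, 1)
    multi = np.zeros((num_points, 3))  # shape (100, 3); already one column per output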
249 changes: 218 additions & 31 deletions deepxde/nn/tensorflow_compat_v1/deeponet.py
@@ -7,6 +7,131 @@
from ... import config
from ...backend import tf
from ...utils import timing
from abc import ABC, abstractmethod


class DeepONetStrategy(ABC):
"""DeepONet building strategy.

See Section 3.1.6 in
L. Lu, X. Meng, S. Cai, Z. Mao, S. Goswami, Z. Zhang, & G. Karniadakis.
A comprehensive and fair comparison of two neural operators
(with practical extensions) based on FAIR data.
Computer Methods in Applied Mechanics and Engineering, 393, 114778, 2022.
"""

def __init__(self, net):
self.net = net

def _build_branch_and_trunk(self):
# Branch net to encode the input function
branch = self.net.build_branch_net()
# Trunk net to encode the domain of the output function
trunk = self.net.build_trunk_net()
return branch, trunk

@abstractmethod
def build(self):
pass


class VanillaStrategy(DeepONetStrategy):
def build(self):
branch, trunk = self._build_branch_and_trunk()
if branch.shape[-1] != trunk.shape[-1]:
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
y = self.net.merge(branch, trunk)
return y


class IndependentStrategy(DeepONetStrategy):
"""Directly use n independent DeepONets,
and each DeepONet outputs only one function.
"""

def build(self):
vanilla_strategy = VanillaStrategy(self.net)
ys = []
for _ in range(self.net.num_outputs):
ys.append(vanilla_strategy.build())
return self.net.concatenate_outputs(ys)


class SplitBothStrategy(DeepONetStrategy):
"""Split the outputs of both the branch net and the trunk net into n groups,
and then the kth group outputs the kth solution.

For example, if n = 2 and both the branch and trunk nets have 100 output neurons,
then the dot product between the first 50 neurons of
the branch and trunk nets generates the first function,
and the remaining 50 neurons generate the second function.
"""

def build(self):
branch, trunk = self._build_branch_and_trunk()
if branch.shape[-1] != trunk.shape[-1]:
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
if branch.shape[-1] % self.net.num_outputs != 0:
raise AssertionError(
f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
)
branch_groups = tf.split(
branch, num_or_size_splits=self.net.num_outputs, axis=1
)
trunk_groups = tf.split(trunk, num_or_size_splits=self.net.num_outputs, axis=1)
ys = []
for i in range(self.net.num_outputs):
y = self.net.merge(branch_groups[i], trunk_groups[i])
ys.append(y)
return self.net.concatenate_outputs(ys)
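
To make the n = 2 example in the docstring concrete, here is a standalone sketch of the split-and-merge (assuming the TensorFlow 1.x compat API; the tensor names and sizes are illustrative):

    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()

    branch = tf.placeholder(tf.float32, [None, 100])  # branch net output, 100 features
    trunk = tf.placeholder(tf.float32, [None, 100])   # trunk net output, 100 features

    # Two groups of 50 features each.
    branch_groups = tf.split(branch, num_or_size_splits=2, axis=1)
    trunk_groups = tf.split(trunk, num_or_size_splits=2, axis=1)

    # Per-group dot product, as in DeepONet.merge: each [None, 50] pair -> [None, 1].
    ys = [
        tf.expand_dims(tf.einsum("bi,bi->b", b, t), axis=1)
        for b, t in zip(branch_groups, trunk_groups)
    ]
    y = tf.concat(ys, axis=1)  # [None, 2]: one column per output function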


class SplitBranchStrategy(DeepONetStrategy):
"""Split the branch net and share the trunk net."""

def build(self):
branch, trunk = self._build_branch_and_trunk()
if branch.shape[-1] % self.net.num_outputs != 0:
raise AssertionError(
f"Output size of the branch net is not evenly divisible by {self.net.num_outputs}."
)
if branch.shape[-1] / self.net.num_outputs != trunk.shape[-1]:
raise AssertionError(
f"Output size of the trunk net does not equal to {branch.shape[-1] // self.net.num_outputs}."
)
branch_groups = tf.split(
branch, num_or_size_splits=self.net.num_outputs, axis=1
)
ys = []
for i in range(self.net.num_outputs):
y = self.net.merge(branch_groups[i], trunk)
ys.append(y)
return self.net.concatenate_outputs(ys)


class SplitTrunkStrategy(DeepONetStrategy):
"""Split the trunk net and share the branch net."""

def build(self):
branch, trunk = self._build_branch_and_trunk()
if trunk.shape[-1] % self.net.num_outputs != 0:
raise AssertionError(
f"Output size of the trunk net is not evenly divisible by {self.net.num_outputs}."
)
if trunk.shape[-1] / self.net.num_outputs != branch.shape[-1]:
raise AssertionError(
f"Output size of the branch net does not equal to {trunk.shape[-1] // self.net.num_outputs}."
)
trunk_groups = tf.split(trunk, num_or_size_splits=self.net.num_outputs, axis=1)
ys = []
for i in range(self.net.num_outputs):
y = self.net.merge(branch, trunk_groups[i])
ys.append(y)
return self.net.concatenate_outputs(ys)


class DeepONet(NN):
@@ -20,7 +145,7 @@ class DeepONet(NN):
layer_sizes_branch: A list of integers as the width of a fully connected
network, or `(dim, f)` where `dim` is the input dimension and `f` is a
network function. The width of the last layer in the branch and trunk net
should be equal.
should be equal, except when using the "split_branch" or "split_trunk" strategy.
layer_sizes_trunk (list): A list of integers as the width of a fully connected
network.
activation: If `activation` is a ``string``, then the same activation is used in
Expand All @@ -29,6 +154,15 @@ class DeepONet(NN):
`activation["branch"]`.
trainable_branch: Boolean.
trainable_trunk: Boolean or a list of booleans.
num_outputs (integer): Number of outputs.
strategy (str): "vanilla", "independent", "split_both", "split_branch" or "split_trunk".
This is only used when num_outputs > 1 (see the usage sketch after this docstring).

- Сhoose "vanilla" for classical implementation of DeepONet. Can not be used with num_outputs > 1.
- Сhoose "independent" to use num_outputs independent DeepONets, and each DeepONet outputs only one function.
- Сhoose "split_both" to split the outputs of both the branch net and the trunk net into num_outputs groups, and then the kth group outputs the kth solution.
- Сhoose "split_branch" to split the branch net and share the trunk net. The width of the last layer in the branch net should be equal to the one in the trunk net multiplied by the number of outputs.
- Сhoose "split_trunk" to split the trunk net and share the branch net. The width of the last layer in the trunk net should be equal to the one in the branch net multiplied by the number of outputs.
"""

def __init__(
@@ -42,6 +176,8 @@ def __init__(
stacked=False,
trainable_branch=True,
trainable_trunk=True,
num_outputs=1,
strategy="independent",
):
super().__init__()
if isinstance(trainable_trunk, (list, tuple)):
@@ -69,6 +205,22 @@ def __init__(
self._inputs = None
self._X_func_default = None

self.num_outputs = num_outputs
if self.num_outputs == 1:
if strategy != "vanilla":
strategy = "vanilla"
print('Strategy is forcibly changed to "vanilla".')
elif strategy == "vanilla":
strategy = "independent"
print('Strategy is forcibly changed to "independent".')
self.strategy = {
"independent": IndependentStrategy,
"split_both": SplitBothStrategy,
"split_branch": SplitBranchStrategy,
"split_trunk": SplitTrunkStrategy,
"vanilla": VanillaStrategy,
}.get(strategy, IndependentStrategy)(self)

@property
def inputs(self):
return self._inputs
@@ -101,7 +253,14 @@ def build(self):
self.X_loc = tf.placeholder(config.real(tf), [None, self.layer_size_loc[0]])
self._inputs = [self.X_func, self.X_loc]

# Branch net to encode the input function
self.y = self.strategy.build()
if self._output_transform is not None:
self.y = self._output_transform(self._inputs, self.y)

self.target = tf.placeholder(config.real(tf), [None, self.num_outputs])
self.built = True

def build_branch_net(self):
y_func = self.X_func
if callable(self.layer_size_func[1]):
# User-defined network
@@ -141,8 +300,9 @@ def build(self):
regularizer=self.regularizer,
trainable=self.trainable_branch,
)
return y_func

# Trunk net to encode the domain of the output function
def build_trunk_net(self):
y_loc = self.X_loc
if self._input_transform is not None:
y_loc = self._input_transform(y_loc)
@@ -156,24 +316,20 @@ def build(self):
if isinstance(self.trainable_trunk, (list, tuple))
else self.trainable_trunk,
)
return y_loc

def merge(self, branch, trunk):
# Dot product
if y_func.shape[-1] != y_loc.shape[-1]:
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
self.y = tf.einsum("bi,bi->b", y_func, y_loc)
self.y = tf.expand_dims(self.y, axis=1)
# Add bias
y = tf.einsum("bi,bi->b", branch, trunk)
y = tf.expand_dims(y, axis=1)
if self.use_bias:
b = tf.Variable(tf.zeros(1, dtype=config.real(tf)))
self.y += b

if self._output_transform is not None:
self.y = self._output_transform(self._inputs, self.y)
y += b
return y

self.target = tf.placeholder(config.real(tf), [None, 1])
self.built = True
@staticmethod
def concatenate_outputs(ys):
return tf.concat(ys, axis=1)

def _dense(
self,
@@ -252,13 +408,22 @@ class DeepONetCartesianProd(NN):
layer_size_branch: A list of integers as the width of a fully connected network,
or `(dim, f)` where `dim` is the input dimension and `f` is a network
function. The width of the last layer in the branch and trunk net should be
equal.
equal, except when using the "split_branch" or "split_trunk" strategy.
layer_size_trunk (list): A list of integers as the width of a fully connected
network.
activation: If `activation` is a ``string``, then the same activation is used in
both trunk and branch nets. If `activation` is a ``dict``, then the trunk
net uses the activation `activation["trunk"]`, and the branch net uses
`activation["branch"]`.
num_outputs (integer): Number of outputs.
strategy (str): "vanilla", "independent", "split_both", "split_branch" or "split_trunk".
This is only used when num_outputs > 1 (see the usage sketch after this docstring).

- Сhoose "vanilla" for classical implementation of DeepONet. Can not be used with num_outputs > 1.
- Сhoose "independent" to use num_outputs independent DeepONets, and each DeepONet outputs only one function.
- Сhoose "split_both" to split the outputs of both the branch net and the trunk net into num_outputs groups, and then the kth group outputs the kth solution.
- Сhoose "split_branch" to split the branch net and share the trunk net. The width of the last layer in the branch net should be equal to the one in the trunk net multiplied by the number of outputs.
- Сhoose "split_trunk" to split the trunk net and share the branch net. The width of the last layer in the trunk net should be equal to the one in the branch net multiplied by the number of outputs.
"""

def __init__(
@@ -268,6 +433,8 @@ def __init__(
activation,
kernel_initializer,
regularization=None,
num_outputs=1,
strategy="independent",
):
super().__init__()
self.layer_size_func = layer_size_branch
@@ -279,9 +446,24 @@ def __init__(
self.activation_branch = self.activation_trunk = activations.get(activation)
self.kernel_initializer = initializers.get(kernel_initializer)
self.regularizer = regularizers.get(regularization)

self._inputs = None

self.num_outputs = num_outputs
if self.num_outputs == 1:
if strategy != "vanilla":
strategy = "vanilla"
print('Strategy is forcibly changed to "vanilla".')
elif strategy == "vanilla":
strategy = "independent"
print('Strategy is forcibly changed to "independent".')
self.strategy = {
"independent": IndependentStrategy,
"split_both": SplitBothStrategy,
"split_branch": SplitBranchStrategy,
"split_trunk": SplitTrunkStrategy,
"vanilla": VanillaStrategy,
}.get(strategy, IndependentStrategy)(self)

@property
def inputs(self):
return self._inputs
@@ -301,7 +483,14 @@ def build(self):
self.X_loc = tf.placeholder(config.real(tf), [None, self.layer_size_loc[0]])
self._inputs = [self.X_func, self.X_loc]

# Branch net to encode the input function
self.y = self.strategy.build()
if self._output_transform is not None:
self.y = self._output_transform(self._inputs, self.y)

self.target = tf.placeholder(config.real(tf), [None, None])
self.built = True

def build_branch_net(self):
y_func = self.X_func
if callable(self.layer_size_func[1]):
# User-defined network
@@ -322,7 +511,9 @@ def build(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
return y_func

def build_trunk_net(self):
# Trunk net to encode the domain of the output function
y_loc = self.X_loc
if self._input_transform is not None:
@@ -335,19 +526,15 @@ def build(self):
kernel_initializer=self.kernel_initializer,
kernel_regularizer=self.regularizer,
)
return y_loc

# Dot product
if y_func.shape[-1] != y_loc.shape[-1]:
raise AssertionError(
"Output sizes of branch net and trunk net do not match."
)
self.y = tf.einsum("bi,ni->bn", y_func, y_loc)
def merge(self, branch, trunk):
y = tf.einsum("bi,ni->bn", branch, trunk)
# Add bias
b = tf.Variable(tf.zeros(1, dtype=config.real(tf)))
self.y += b
y += b
return y

if self._output_transform is not None:
self.y = self._output_transform(self._inputs, self.y)

self.target = tf.placeholder(config.real(tf), [None, None])
self.built = True
@staticmethod
def concatenate_outputs(ys):
return tf.stack(ys, axis=2)
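
The stack axis differs deliberately from DeepONet.concatenate_outputs: there each merged output has shape [batch, 1] and tf.concat along axis 1 yields [batch, num_outputs]; here each merged output has shape [num_functions, num_points], so tf.stack along axis 2 yields [num_functions, num_points, num_outputs]. A shape sketch (sizes are illustrative):

    import tensorflow.compat.v1 as tf

    tf.disable_v2_behavior()

    # DeepONet: k outputs of shape [batch, 1] -> concat -> [batch, k].
    ys_pointwise = [tf.zeros([8, 1]) for _ in range(3)]
    print(tf.concat(ys_pointwise, axis=1).shape)  # (8, 3)

    # DeepONetCartesianProd: k outputs of shape [num_funcs, num_points]
    # -> stack -> [num_funcs, num_points, k].
    ys_prod = [tf.zeros([8, 100]) for _ in range(3)]
    print(tf.stack(ys_prod, axis=2).shape)  # (8, 100, 3)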