Backend TensorFlow 1.x: Fix dense error with Keras 3
lululxvi committed Jun 26, 2024
1 parent 2000be6 commit f34e81a
Showing 4 changed files with 18 additions and 12 deletions.
4 changes: 3 additions & 1 deletion deepxde/model.py
@@ -176,7 +176,9 @@ def losses(losses_fn):
                 losses = [losses]
             # Regularization loss
             if self.net.regularizer is not None:
-                losses.append(tf.losses.get_regularization_loss())
+                losses.append(
+                    tf.losses.get_regularization_loss() + self.net.regularization_loss
+                )
             losses = tf.convert_to_tensor(losses)
             # Weighted losses
             if self.loss_weights is not None:
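Note: the change above makes the TF1-compat backend sum regularization from two sources. A minimal sketch of the combined quantity (illustrative only; `total_regularization_loss` is a hypothetical helper, not DeepXDE API, and `net` stands for any network built with `_dense`):

    import tensorflow.compat.v1 as tf

    def total_regularization_loss(net):
        # Legacy tf.layers.* calls register penalties in the graph
        # collection REGULARIZATION_LOSSES, which this call sums.
        collection_loss = tf.losses.get_regularization_loss()
        # tf.keras.layers.Dense penalties are accumulated manually on the
        # network instead (see the fnn.py and nn.py changes below).
        return collection_loss + net.regularization_loss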
15 changes: 5 additions & 10 deletions deepxde/nn/tensorflow_compat_v1/fnn.py
@@ -105,23 +105,18 @@ def build(self):
         self.built = True

     def _dense(self, inputs, units, activation=None, use_bias=True):
-        # Cannot directly replace tf.layers.dense() with tf.keras.layers.Dense() due to
-        # some differences. One difference is that tf.layers.dense() will add
-        # regularizer loss to the collection REGULARIZATION_LOSSES, but
-        # tf.keras.layers.Dense() will not. Hence, tf.losses.get_regularization_loss()
-        # cannot be used for tf.keras.layers.Dense().
-        # References:
-        # - https://github.com/tensorflow/tensorflow/issues/21587
-        # - https://www.tensorflow.org/guide/migrate
-        return tf.layers.dense(
-            inputs,
+        dense = tf.keras.layers.Dense(
             units,
             activation=activation,
             use_bias=use_bias,
             kernel_initializer=self.kernel_initializer,
             kernel_regularizer=self.regularizer,
             kernel_constraint=self.kernel_constraint,
         )
+        out = dense(inputs)
+        if self.regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out

     @staticmethod
     def _dense_weightnorm(inputs, units, activation=None, use_bias=True):
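Note: a self-contained sketch of the new pattern in `_dense`, assuming TF1-style graph mode with legacy Keras (tf-keras); the layer size, activation, and regularizer here are made up for illustration:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    regularization_loss = 0  # running sum, analogous to NN.regularization_loss

    inputs = tf.placeholder(tf.float32, [None, 3])
    dense = tf.keras.layers.Dense(
        4,
        activation=tf.tanh,
        kernel_regularizer=tf.keras.regularizers.l2(1e-4),
    )
    out = dense(inputs)
    # After the layer has been called, Layer.losses holds its kernel
    # regularization tensors; add_n folds them into the running total.
    regularization_loss += tf.math.add_n(dense.losses)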
9 changes: 9 additions & 0 deletions deepxde/nn/tensorflow_compat_v1/nn.py
@@ -11,6 +11,15 @@ class NN:
     def __init__(self):
         self.training = tf.placeholder(tf.bool)
         self.regularizer = None
+        # tf.layers.dense() is not available for TensorFlow 2.16+ with Keras 3. The
+        # corresponding layer is tf.keras.layers.Dense(). However, tf.layers.dense()
+        # adds regularizer loss to the collection REGULARIZATION_LOSSES, which can be
+        # accessed by tf.losses.get_regularization_loss(), but tf.keras.layers.Dense()
+        # adds regularizer loss to Layer.losses. Hence, we use self.regularization_loss
+        # to collect tf.keras.layers.Dense() regularization loss.
+        # References:
+        # - https://github.com/tensorflow/tensorflow/issues/21587
+        self.regularization_loss = 0

         self._auxiliary_vars = tf.placeholder(config.real(tf), [None, None])
         self._input_transform = None
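Note: the mismatch described in the comment can be checked directly; a small sketch under the same graph-mode assumptions as above:

    import tensorflow.compat.v1 as tf

    tf.disable_eager_execution()

    x = tf.placeholder(tf.float32, [None, 3])
    layer = tf.keras.layers.Dense(4, kernel_regularizer=tf.keras.regularizers.l2(0.01))
    y = layer(x)

    # The Keras layer reports its penalty on Layer.losses ...
    print(layer.losses)
    # ... but does not add it to the legacy collection, so
    # tf.losses.get_regularization_loss() alone would miss it.
    print(tf.get_collection(tf.GraphKeys.REGULARIZATION_LOSSES))  # expected: []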
2 changes: 1 addition & 1 deletion docs/user/installation.rst
@@ -8,7 +8,7 @@ DeepXDE requires one of the following backend-specific dependencies to be installed:

 - TensorFlow 1.x: `TensorFlow <https://www.tensorflow.org>`_>=2.7.0

-  - For TensorFlow 2.16+ with Keras 3, to keep using Keras 2, you can first install `tf-keras <https://pypi.org/project/tf-keras>`_, and then set the environment variable ``TF_USE_LEGACY_KERAS=1`` directly or in your python program with ``import os;os.environ["TF_USE_LEGACY_KERAS"]="1"``. [`Reference <https://keras.io/keras_3>`_]
+  - If you use TensorFlow 2.16+ and have an error with Keras 3, to keep using Keras 2, you can first install `tf-keras <https://pypi.org/project/tf-keras>`_, and then set the environment variable ``TF_USE_LEGACY_KERAS=1`` directly or in your python program with ``import os;os.environ["TF_USE_LEGACY_KERAS"]="1"``. [`Reference <https://keras.io/keras_3>`_]

 - TensorFlow 2.x: `TensorFlow <https://www.tensorflow.org>`_>=2.3.0, `TensorFlow Probability <https://www.tensorflow.org/probability>`_>=0.11.0
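Note: expanded form of the inline snippet above; the environment variable must be set before TensorFlow (and hence DeepXDE) is first imported:

    import os

    # Must run before the first `import tensorflow` / `import deepxde`.
    os.environ["TF_USE_LEGACY_KERAS"] = "1"

    import deepxde as dde  # now backed by Keras 2 via the tf-keras package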
