diff --git a/deepxde/model.py b/deepxde/model.py
index 48ac2ef46..072462b36 100644
--- a/deepxde/model.py
+++ b/deepxde/model.py
@@ -176,7 +176,9 @@ def losses(losses_fn):
                 losses = [losses]
             # Regularization loss
             if self.net.regularizer is not None:
-                losses.append(tf.losses.get_regularization_loss())
+                losses.append(
+                    tf.losses.get_regularization_loss() + self.net.regularization_loss
+                )
             losses = tf.convert_to_tensor(losses)
             # Weighted losses
             if self.loss_weights is not None:
diff --git a/deepxde/nn/tensorflow_compat_v1/fnn.py b/deepxde/nn/tensorflow_compat_v1/fnn.py
index 74ffb2e5a..4c9689dd3 100644
--- a/deepxde/nn/tensorflow_compat_v1/fnn.py
+++ b/deepxde/nn/tensorflow_compat_v1/fnn.py
@@ -105,16 +105,7 @@ def build(self):
         self.built = True
 
     def _dense(self, inputs, units, activation=None, use_bias=True):
-        # Cannot directly replace tf.layers.dense() with tf.keras.layers.Dense() due to
-        # some differences. One difference is that tf.layers.dense() will add
-        # regularizer loss to the collection REGULARIZATION_LOSSES, but
-        # tf.keras.layers.Dense() will not. Hence, tf.losses.get_regularization_loss()
-        # cannot be used for tf.keras.layers.Dense().
-        # References:
-        # - https://github.com/tensorflow/tensorflow/issues/21587
-        # - https://www.tensorflow.org/guide/migrate
-        return tf.layers.dense(
-            inputs,
+        dense = tf.keras.layers.Dense(
             units,
             activation=activation,
             use_bias=use_bias,
@@ -122,6 +113,10 @@ def _dense(self, inputs, units, activation=None, use_bias=True):
             kernel_regularizer=self.regularizer,
             kernel_constraint=self.kernel_constraint,
         )
+        out = dense(inputs)
+        if self.regularizer:
+            self.regularization_loss += tf.math.add_n(dense.losses)
+        return out
 
     @staticmethod
     def _dense_weightnorm(inputs, units, activation=None, use_bias=True):
diff --git a/deepxde/nn/tensorflow_compat_v1/nn.py b/deepxde/nn/tensorflow_compat_v1/nn.py
index 1b68ae514..9181584d1 100644
--- a/deepxde/nn/tensorflow_compat_v1/nn.py
+++ b/deepxde/nn/tensorflow_compat_v1/nn.py
@@ -11,6 +11,15 @@ class NN:
     def __init__(self):
         self.training = tf.placeholder(tf.bool)
         self.regularizer = None
+        # tf.layers.dense() is not available for TensorFlow 2.16+ with Keras 3. The
+        # corresponding layer is tf.keras.layers.Dense(). However, tf.layers.dense()
+        # adds regularizer loss to the collection REGULARIZATION_LOSSES, which can be
+        # accessed by tf.losses.get_regularization_loss(), but tf.keras.layers.Dense()
+        # adds regularizer loss to Layer.losses. Hence, we use self.regularization_loss
+        # to collect tf.keras.layers.Dense() regularization loss.
+        # References:
+        # - https://github.com/tensorflow/tensorflow/issues/21587
+        self.regularization_loss = 0
 
         self._auxiliary_vars = tf.placeholder(config.real(tf), [None, None])
         self._input_transform = None
diff --git a/docs/user/installation.rst b/docs/user/installation.rst
index 2d2822ae4..2dcf7eede 100644
--- a/docs/user/installation.rst
+++ b/docs/user/installation.rst
@@ -8,7 +8,7 @@ DeepXDE requires one of the following backend-specific dependencies to be instal
 
 - TensorFlow 1.x: `TensorFlow `_>=2.7.0
 
-  - For TensorFlow 2.16+ with Keras 3, to keep using Keras 2, you can first install `tf-keras `_, and then set the environment variable ``TF_USE_LEGACY_KERAS=1`` directly or in your python program with ``import os;os.environ["TF_USE_LEGACY_KERAS"]="1"``. [`Reference `_]
+  - If you use TensorFlow 2.16+ and have an error with Keras 3, to keep using Keras 2, you can first install `tf-keras `_, and then set the environment variable ``TF_USE_LEGACY_KERAS=1`` directly or in your python program with ``import os;os.environ["TF_USE_LEGACY_KERAS"]="1"``. [`Reference `_]
 
 - TensorFlow 2.x: `TensorFlow `_>=2.3.0, `TensorFlow Probability `_>=0.11.0
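
For reviewers, here is a minimal standalone sketch (not part of the patch) of the behavior the new `NN.regularization_loss` attribute works around: when a `kernel_regularizer` is passed to `tf.keras.layers.Dense`, the penalty tensor is exposed on `Layer.losses` and is not added to the `REGULARIZATION_LOSSES` collection, so `tf.compat.v1.losses.get_regularization_loss()` no longer sees it. The sketch assumes TensorFlow 2.x running in TF1-style graph mode with Keras 2 behavior (e.g. tf-keras with `TF_USE_LEGACY_KERAS=1`, as described in the installation note); the variable names are illustrative only.

```python
# Standalone sketch, not part of this patch. Assumes TF 2.x with Keras 2
# behavior (e.g. tf-keras + TF_USE_LEGACY_KERAS=1) in TF1-style graph mode,
# mirroring DeepXDE's tensorflow.compat.v1 backend.
import tensorflow as tf

tf.compat.v1.disable_eager_execution()

x = tf.compat.v1.placeholder(tf.float32, [None, 3])
dense = tf.keras.layers.Dense(8, kernel_regularizer=tf.keras.regularizers.L2(1e-4))
y = dense(x)

# The L2 penalty is stored on the layer itself ...
print(dense.losses)  # [<tf.Tensor ...>], one tensor per regularized weight

# ... and is NOT registered in GraphKeys.REGULARIZATION_LOSSES, so the old
# aggregation used with tf.layers.dense() returns a constant 0 here.
print(tf.compat.v1.losses.get_regularization_loss())

# Manual aggregation, which is what _dense() now accumulates into
# NN.regularization_loss and Model adds to the loss vector.
regularization_loss = tf.math.add_n(dense.losses)
```

Accumulating the per-layer penalties on the network object keeps the model-side change to a single `+ self.net.regularization_loss` term in `Model`, and initializing `self.regularization_loss = 0` keeps that sum valid for networks built without a regularizer.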