diff --git a/deepxde/nn/regularizers.py b/deepxde/nn/regularizers.py index a730655eb..1f2727d34 100644 --- a/deepxde/nn/regularizers.py +++ b/deepxde/nn/regularizers.py @@ -1,14 +1,5 @@ from ..backend import tf -REGULARIZER_DICT = { - "l1": tf.keras.regularizers.L1, - "l2": tf.keras.regularizers.L2, - "l1l2": tf.keras.regularizers.L1L2, - "l1+l2": tf.keras.regularizers.L1L2, -} -if hasattr(tf.keras.regularizers, "OrthogonalRegularizer"): - REGULARIZER_DICT["orthogonal"] = tf.keras.regularizers.OrthogonalRegularizer - def get(identifier): """Retrieves a TensorFlow regularizer instance based on the given identifier. @@ -24,35 +15,27 @@ def get(identifier): # TODO: other backends if identifier is None or not identifier: return None + if not isinstance(identifier, (list, tuple)): + raise ValueError("Identifier must be a list or a tuple.") - if isinstance(identifier, (list, tuple)): - name = identifier[0].lower() - factor = identifier[1:] - else: - raise ValueError("Identifier must be a non-empty list or tuple.") - + name = identifier[0].lower() + factor = identifier[1:] if not factor: raise ValueError("Regularization factor must be provided.") - regularizer_class = REGULARIZER_DICT.get(name) - if not regularizer_class: - if name == "orthogonal": + if name == "l1": + return tf.keras.regularizers.L1(l1=factor[0]) + if name == "l2": + return tf.keras.regularizers.L2(l2=factor[0]) + if name == "orthogonal": + if not hasattr(tf.keras.regularizers, "OrthogonalRegularizer"): raise ValueError( "The 'orthogonal' regularizer is not available " "in your version of TensorFlow" ) - raise ValueError(f"Unknown regularizer name: {name}") - - regularizer_kwargs = {} - if name == "l1": - regularizer_kwargs["l1"] = factor[0] - elif name == "l2": - regularizer_kwargs["l2"] = factor[0] - elif name == "orthogonal": - regularizer_kwargs["factor"] = factor[0] - elif name in ("l1l2", "l1+l2"): + return tf.keras.regularizers.OrthogonalRegularizer(factor=factor[0]) + if name in ("l1l2", "l1+l2"): if len(factor) < 2: raise ValueError("L1L2 regularizer requires both L1/L2 penalties.") - regularizer_kwargs["l1"] = factor[0] - regularizer_kwargs["l2"] = factor[1] - return regularizer_class(**regularizer_kwargs) + return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1]) + raise ValueError(f"Unknown regularizer name: {name}")