Skip to content

Commit

Permalink
Simplify the code
Browse files Browse the repository at this point in the history
  • Loading branch information
vl-dud committed Oct 31, 2024
1 parent 65d1208 commit 3de1044
Showing 1 changed file with 14 additions and 31 deletions.
45 changes: 14 additions & 31 deletions deepxde/nn/regularizers.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,5 @@
from ..backend import tf

REGULARIZER_DICT = {
"l1": tf.keras.regularizers.L1,
"l2": tf.keras.regularizers.L2,
"l1l2": tf.keras.regularizers.L1L2,
"l1+l2": tf.keras.regularizers.L1L2,
}
if hasattr(tf.keras.regularizers, "OrthogonalRegularizer"):
REGULARIZER_DICT["orthogonal"] = tf.keras.regularizers.OrthogonalRegularizer


def get(identifier):
"""Retrieves a TensorFlow regularizer instance based on the given identifier.
Expand All @@ -24,35 +15,27 @@ def get(identifier):
# TODO: other backends
if identifier is None or not identifier:
return None
if not isinstance(identifier, (list, tuple)):
raise ValueError("Identifier must be a list or a tuple.")

if isinstance(identifier, (list, tuple)):
name = identifier[0].lower()
factor = identifier[1:]
else:
raise ValueError("Identifier must be a non-empty list or tuple.")

name = identifier[0].lower()
factor = identifier[1:]
if not factor:
raise ValueError("Regularization factor must be provided.")

regularizer_class = REGULARIZER_DICT.get(name)
if not regularizer_class:
if name == "orthogonal":
if name == "l1":
return tf.keras.regularizers.L1(l1=factor[0])
if name == "l2":
return tf.keras.regularizers.L2(l2=factor[0])
if name == "orthogonal":
if not hasattr(tf.keras.regularizers, "OrthogonalRegularizer"):
raise ValueError(
"The 'orthogonal' regularizer is not available "
"in your version of TensorFlow"
)
raise ValueError(f"Unknown regularizer name: {name}")

regularizer_kwargs = {}
if name == "l1":
regularizer_kwargs["l1"] = factor[0]
elif name == "l2":
regularizer_kwargs["l2"] = factor[0]
elif name == "orthogonal":
regularizer_kwargs["factor"] = factor[0]
elif name in ("l1l2", "l1+l2"):
return tf.keras.regularizers.OrthogonalRegularizer(factor=factor[0])
if name in ("l1l2", "l1+l2"):
if len(factor) < 2:
raise ValueError("L1L2 regularizer requires both L1/L2 penalties.")
regularizer_kwargs["l1"] = factor[0]
regularizer_kwargs["l2"] = factor[1]
return regularizer_class(**regularizer_kwargs)
return tf.keras.regularizers.L1L2(l1=factor[0], l2=factor[1])
raise ValueError(f"Unknown regularizer name: {name}")

0 comments on commit 3de1044

Please sign in to comment.