From 1896e975b0ef80117278be0fcce61f19f1fa9bf5 Mon Sep 17 00:00:00 2001
From: Arno Eigenwillig
Date: Tue, 14 May 2024 06:54:38 -0700
Subject: [PATCH] Add a `kernel_regularizer` kwarg to all Classification and
 Regression tasks, which gets forwarded to the output (logits) Dense layer,
 aligning it with the options for weight regularization in MtAlbis and other
 models.

PiperOrigin-RevId: 633570099
---
 tensorflow_gnn/runner/tasks/classification.py |  70 ++++++--
 .../runner/tasks/classification_test.py       |  47 +++--
 tensorflow_gnn/runner/tasks/regression.py     | 162 ++++++++++++++----
 .../runner/tasks/regression_test.py           |  99 +++++++----
 4 files changed, 286 insertions(+), 92 deletions(-)

diff --git a/tensorflow_gnn/runner/tasks/classification.py b/tensorflow_gnn/runner/tasks/classification.py
index 7982a9df..002f962a 100644
--- a/tensorflow_gnn/runner/tasks/classification.py
+++ b/tensorflow_gnn/runner/tasks/classification.py
@@ -106,7 +106,9 @@ def __init__(
       *,
       name: str = "classification_logits",
       label_fn: Optional[LabelFn] = None,
-      label_feature_name: Optional[str] = None):
+      label_feature_name: Optional[str] = None,
+      kernel_regularizer: Any = None,
+  ):
     """Sets `Task` parameters.
 
     Args:
@@ -120,6 +122,9 @@ def __init__(
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     if (label_fn is None) == (label_feature_name is None):
       raise ValueError(
@@ -131,6 +136,7 @@ def __init__(
     self._name = name
     self._label_fn = label_fn
     self._label_feature_name = label_feature_name
+    self._kernel_regularizer = kernel_regularizer
 
   @abc.abstractmethod
   def gather_activations(self, inputs: GraphTensor) -> Field:
@@ -148,7 +154,7 @@ def predict(self, inputs: tfgnn.GraphTensor) -> interfaces.Predictions:
     tfgnn.check_scalar_graph_tensor(inputs, name="Classification")
     activations = self.gather_activations(inputs)
     logits = tf.keras.layers.Dense(
-        self._units,
+        self._units, kernel_regularizer=self._kernel_regularizer,
         name=self._name)(activations)
     return logits
 
@@ -374,7 +380,8 @@ def __init__(self,
                reduce_type: str = "mean",
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Graph binary (or multi-label) classification.
 
     This task performs binary classification (or multiple independent ones:
@@ -394,6 +401,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         node_set_name,
@@ -402,7 +412,9 @@ def __init__(self,
         reduce_type=reduce_type,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
 
 
 class GraphMulticlassClassification(_GraphClassification,
@@ -419,7 +431,8 @@ def __init__(self,
                reduce_type: str = "mean",
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Graph multiclass classification from pooled node states.
 
     Args:
@@ -439,6 +452,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         node_set_name,
@@ -449,7 +465,9 @@ def __init__(self,
         reduce_type=reduce_type,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
 
 
 class RootNodeBinaryClassification(_RootNodeClassification,
@@ -463,7 +481,8 @@ def __init__(self,
                state_name: str = tfgnn.HIDDEN_STATE,
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Root node binary (or multi-label) classification.
 
     This task performs binary classification (or multiple independent ones:
@@ -486,6 +505,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         node_set_name,
@@ -493,7 +515,9 @@ def __init__(self,
         state_name=state_name,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
 
 
 class RootNodeMulticlassClassification(_RootNodeClassification,
@@ -509,7 +533,8 @@ def __init__(self,
                state_name: str = tfgnn.HIDDEN_STATE,
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Root node multiclass classification.
 
     This task can be used on graph datasets without a readout structure.
@@ -532,6 +557,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         node_set_name,
@@ -541,7 +569,9 @@ def __init__(self,
         state_name=state_name,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
 
 
 class NodeBinaryClassification(_NodeClassification, _BinaryClassification):
@@ -556,7 +586,8 @@ def __init__(self,
                validate: bool = True,
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Node binary (or multi-label) classification.
 
     This task performs binary classification (or multiple independent ones:
@@ -582,6 +613,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         key,
@@ -591,7 +625,9 @@ def __init__(self,
         validate=validate,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
 
 
 class NodeMulticlassClassification(_NodeClassification,
@@ -609,7 +645,8 @@ def __init__(self,
                per_class_statistics: bool = False,
                name: str = "classification_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Node multiclass classification via structured readout.
 
     Args:
@@ -635,6 +672,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        classification logits layer.
     """
     super().__init__(
         key,
@@ -646,4 +686,6 @@ def __init__(self,
         per_class_statistics=per_class_statistics,
         name=name,
         label_fn=label_fn,
-        label_feature_name=label_feature_name)
+        label_feature_name=label_feature_name,
+        kernel_regularizer=kernel_regularizer,
+    )
diff --git a/tensorflow_gnn/runner/tasks/classification_test.py b/tensorflow_gnn/runner/tasks/classification_test.py
index dc30898e..325ff91b 100644
--- a/tensorflow_gnn/runner/tasks/classification_test.py
+++ b/tensorflow_gnn/runner/tasks/classification_test.py
@@ -53,6 +53,10 @@ def fn(inputs):
   return fn
 
 
+def l2(rate):
+  return tf.keras.regularizers.L2(rate)
+
+
 def add_readout_from_first_node(gt: GraphTensor) -> GraphTensor:
   return tfgnn.add_readout_from_first_node(
       gt,
@@ -205,63 +209,76 @@ def test_preprocess(
           testcase_name="GraphBinaryClassification",
           task=classification.GraphBinaryClassification(
               "nodes",
-              label_fn=label_fn(2)),
+              label_fn=label_fn(2),
+              kernel_regularizer=l2(0.125)),
           gt=TEST_GRAPH_TENSOR,
           expected_loss=tf.keras.losses.BinaryCrossentropy,
-          expected_shape=tf.TensorShape((None, 1))),
+          expected_shape=tf.TensorShape((None, 1)),
+          expected_l2_regularization=0.125),
       dict(
           testcase_name="GraphMulticlassClassification",
           task=classification.GraphMulticlassClassification(
               "nodes",
               num_classes=4,
-              label_feature_name="labels"),
+              label_feature_name="labels",
+              kernel_regularizer=l2(0.25)),
           gt=context_readout_into_feature(4, TEST_GRAPH_TENSOR),
           expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
-          expected_shape=tf.TensorShape((None, 4))),
+          expected_shape=tf.TensorShape((None, 4)),
+          expected_l2_regularization=0.25),
       dict(
           testcase_name="RootNodeBinaryClassification",
           task=classification.RootNodeBinaryClassification(
               "nodes",
-              label_fn=label_fn(2)),
+              label_fn=label_fn(2),
+              kernel_regularizer=l2(0.5)),
           gt=TEST_GRAPH_TENSOR,
           expected_loss=tf.keras.losses.BinaryCrossentropy,
-          expected_shape=tf.TensorShape((None, 1))),
+          expected_shape=tf.TensorShape((None, 1)),
+          expected_l2_regularization=0.5),
       dict(
           testcase_name="RootNodeMulticlassClassification",
           task=classification.RootNodeMulticlassClassification(
               "nodes",
               num_classes=3,
-              label_feature_name="labels"),
+              label_feature_name="labels",
+              kernel_regularizer=l2(0.75)),
          gt=context_readout_into_feature(3, TEST_GRAPH_TENSOR),
          expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
-          expected_shape=tf.TensorShape((None, 3))),
+          expected_shape=tf.TensorShape((None, 3)),
+          expected_l2_regularization=0.75),
       dict(
           testcase_name="NodeBinaryClassification",
           task=classification.NodeBinaryClassification(
               READOUT_KEY,
-              label_fn=label_fn(2)),
+              label_fn=label_fn(2),
+              kernel_regularizer=l2(1.0)),
           gt=add_readout_from_first_node(TEST_GRAPH_TENSOR),
           expected_loss=tf.keras.losses.BinaryCrossentropy,
-          expected_shape=tf.TensorShape((None, 1))),
+          expected_shape=tf.TensorShape((None, 1)),
+          expected_l2_regularization=1.0),
       dict(
           testcase_name="NodeMulticlassClassification",
           task=classification.NodeMulticlassClassification(
               READOUT_KEY,
               num_classes=3,
-              label_feature_name="labels"),
+              label_feature_name="labels",
+              kernel_regularizer=l2(0.375)),
           gt=add_readout_from_first_node(context_readout_into_feature(
               3, TEST_GRAPH_TENSOR)),
           expected_loss=tf.keras.losses.SparseCategoricalCrossentropy,
-          expected_shape=tf.TensorShape((None, 3))),
+          expected_shape=tf.TensorShape((None, 3)),
+          expected_l2_regularization=0.375),
   ])
   def test_predict(
       self,
       task: interfaces.Task,
       gt: GraphTensor,
       expected_loss: Type[tf.keras.losses.Loss],
-      expected_shape: tf.TensorShape):
-    # Assert head readout, activation and shape.
+      expected_shape: tf.TensorShape,
+      expected_l2_regularization: float):
+    # Assert head readout, activation, shape and regularization.
     inputs = tf.keras.layers.Input(type_spec=gt.spec)
     model = tf.keras.Model(inputs, task.predict(inputs))
     self.assertLen(model.layers, 3)
@@ -279,6 +296,8 @@ def test_predict(
     _, _, dense = model.layers
     self.assertEqual(dense.get_config()["activation"], "linear")
     self.assertTrue(expected_shape.is_compatible_with(dense.output_shape))
+    self.assertEqual(dense.kernel_regularizer.get_config()["l2"],
+                     expected_l2_regularization)
 
     # Assert losses.
     loss = task.losses()
diff --git a/tensorflow_gnn/runner/tasks/regression.py b/tensorflow_gnn/runner/tasks/regression.py
index 977c0791..c5e3b696 100644
--- a/tensorflow_gnn/runner/tasks/regression.py
+++ b/tensorflow_gnn/runner/tasks/regression.py
@@ -16,7 +16,7 @@
 from __future__ import annotations
 
 import abc
-from typing import Callable, Optional, Tuple
+from typing import Any, Callable, Optional, Tuple
 
 import tensorflow as tf
 import tensorflow_gnn as tfgnn
@@ -43,7 +43,9 @@ def __init__(
       *,
       name: str = "regression_logits",
       label_fn: Optional[LabelFn] = None,
-      label_feature_name: Optional[str] = None):
+      label_feature_name: Optional[str] = None,
+      kernel_regularizer: Any = None,
+  ):
     """Sets `Task` parameters.
 
     Args:
@@ -56,6 +58,9 @@ def __init__(
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        output layer.
     """
     if (label_fn is None) == (label_feature_name is None):
       raise ValueError(
@@ -67,6 +72,7 @@ def __init__(
     self._name = name
     self._label_fn = label_fn
     self._label_feature_name = label_feature_name
+    self._kernel_regularizer = kernel_regularizer
 
   @abc.abstractmethod
   def gather_activations(self, inputs: GraphTensor) -> Field:
@@ -84,7 +90,7 @@ def predict(self, inputs: tfgnn.GraphTensor) -> interfaces.Predictions:
     tfgnn.check_scalar_graph_tensor(inputs, name="_Regression")
     activations = self.gather_activations(inputs)
     logits = tf.keras.layers.Dense(
-        self._units,
+        self._units, kernel_regularizer=self._kernel_regularizer,
         name=self._name)(activations)
     return logits
 
@@ -310,7 +316,8 @@ def __init__(self,
                reduce_type: str = "mean",
                name: str = "regression_logits",
                label_fn: Optional[LabelFn] = None,
-               label_feature_name: Optional[str] = None):
+               label_feature_name: Optional[str] = None,
+               kernel_regularizer: Any = None):
     """Regression from pooled node states with mean absolute error.
 
     Args:
@@ -326,6 +333,9 @@ def __init__(self,
       label_feature_name: A label feature name for readout from the auxiliary
         '_readout' node set. Readout does not mutate the input `GraphTensor`.
         Mutually exclusive with `label_fn`.
+      kernel_regularizer: Can be set to a `kernel_regularizer` as understood
+        by `tf.keras.layers.Dense` etc. to perform weight regularization of the
+        output layer.
""" super().__init__( node_set_name, @@ -334,7 +344,9 @@ def __init__(self, reduce_type=reduce_type, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class GraphMeanAbsolutePercentageError(_MeanAbsolutePercentageErrorLossMixIn, @@ -349,7 +361,8 @@ def __init__(self, reduce_type: str = "mean", name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Regression from pooled node states with mean absolute percentage error. Args: @@ -365,6 +378,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -373,7 +389,9 @@ def __init__(self, reduce_type=reduce_type, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class GraphMeanSquaredError(_MeanSquaredErrorLossMixIn, _GraphRegression): @@ -387,7 +405,8 @@ def __init__(self, reduce_type: str = "mean", name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Regression from pooled node states with mean squared error. Args: @@ -403,6 +422,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -411,7 +433,9 @@ def __init__(self, reduce_type=reduce_type, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class GraphMeanSquaredLogarithmicError(_MeanSquaredLogarithmicErrorLossMixIn, @@ -426,7 +450,8 @@ def __init__(self, reduce_type: str = "mean", name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Regression from pooled node states with mean squared logarithmic error. Args: @@ -442,6 +467,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. 
""" super().__init__( node_set_name, @@ -450,7 +478,9 @@ def __init__(self, reduce_type=reduce_type, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class GraphMeanSquaredLogScaledError(_MeanSquaredLogScaledErrorLossMixIn, @@ -468,7 +498,8 @@ def __init__(self, label_feature_name: Optional[str] = None, alpha_loss_param: float = 5., epsilon_loss_param: float = 1e-8, - reduction: tf.keras.losses.Reduction = AUTO): + reduction: tf.keras.losses.Reduction = AUTO, + kernel_regularizer: Any = None): """Regression from pooled node states with mean squared log scaled error. Args: @@ -487,6 +518,9 @@ def __init__(self, alpha_loss_param: Alpha for the mean squared log scaled error. epsilon_loss_param: Epsilon for the mean squared log scaled error. reduction: Reduction for the mean squared log scaled error. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -498,7 +532,9 @@ def __init__(self, label_feature_name=label_feature_name, alpha_loss_param=alpha_loss_param, epsilon_loss_param=epsilon_loss_param, - reduction=reduction) + reduction=reduction, + kernel_regularizer=kernel_regularizer, + ) class RootNodeMeanAbsoluteError(_MeanAbsoluteErrorLossMixIn, @@ -512,7 +548,8 @@ def __init__(self, state_name: str = tfgnn.HIDDEN_STATE, name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Root node regression with mean absolute error. Args: @@ -527,6 +564,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -534,7 +574,9 @@ def __init__(self, state_name=state_name, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class RootNodeMeanAbsolutePercentageError(_MeanAbsolutePercentageErrorLossMixIn, @@ -548,7 +590,8 @@ def __init__(self, state_name: str = tfgnn.HIDDEN_STATE, name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Root node regression with mean absolute percentage error. Args: @@ -563,6 +606,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. 
""" super().__init__( node_set_name, @@ -570,7 +616,9 @@ def __init__(self, state_name=state_name, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class RootNodeMeanSquaredError(_MeanSquaredErrorLossMixIn, _RootNodeRegression): @@ -583,7 +631,8 @@ def __init__(self, state_name: str = tfgnn.HIDDEN_STATE, name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Root node regression with mean squared error. Args: @@ -598,6 +647,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -605,7 +657,9 @@ def __init__(self, state_name=state_name, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class RootNodeMeanSquaredLogarithmicError(_MeanSquaredLogarithmicErrorLossMixIn, @@ -619,7 +673,8 @@ def __init__(self, state_name: str = tfgnn.HIDDEN_STATE, name: str = "regression_logits", label_fn: Optional[LabelFn] = None, - label_feature_name: Optional[str] = None): + label_feature_name: Optional[str] = None, + kernel_regularizer: Any = None): """Root node regression with mean squared logarithmic error. Args: @@ -634,6 +689,9 @@ def __init__(self, label_feature_name: A label feature name for readout from the auxiliary '_readout' node set. Readout does not mutate the input `GraphTensor`. Mutually exclusive with `label_fn`. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( node_set_name, @@ -641,7 +699,9 @@ def __init__(self, state_name=state_name, name=name, label_fn=label_fn, - label_feature_name=label_feature_name) + label_feature_name=label_feature_name, + kernel_regularizer=kernel_regularizer, + ) class RootNodeMeanSquaredLogScaledError(_MeanSquaredLogScaledErrorLossMixIn, @@ -658,7 +718,8 @@ def __init__(self, label_feature_name: Optional[str] = None, alpha_loss_param: float = 5., epsilon_loss_param: float = 1e-8, - reduction: tf.keras.losses.Reduction = AUTO): + reduction: tf.keras.losses.Reduction = AUTO, + kernel_regularizer: Any = None): """Root node regression with mean squared log scaled error. Args: @@ -676,6 +737,9 @@ def __init__(self, alpha_loss_param: Alpha for the mean squared log scaled error. epsilon_loss_param: Epsilon for the mean squared log scaled error. reduction: Reduction for the mean squared log scaled error. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. 
""" super().__init__( node_set_name, @@ -686,7 +750,9 @@ def __init__(self, label_feature_name=label_feature_name, alpha_loss_param=alpha_loss_param, epsilon_loss_param=epsilon_loss_param, - reduction=reduction) + reduction=reduction, + kernel_regularizer=kernel_regularizer, + ) class NodeMeanAbsoluteError(_MeanAbsoluteErrorLossMixIn, _NodeRegression): @@ -701,7 +767,8 @@ def __init__(self, label_feature_name: Optional[str] = None, feature_name: str = tfgnn.HIDDEN_STATE, readout_node_set: tfgnn.NodeSetName = "_readout", - validate: bool = True): + validate: bool = True, + kernel_regularizer: Any = None): """Node regression with mean absolute error via structured readout. This task defines regression via structured readout (see @@ -730,6 +797,9 @@ def __init__(self, auxiliary edge sets. This is stronlgy discouraged, unless great care is taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on structurally unchanged GraphTensors. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( key, @@ -739,7 +809,9 @@ def __init__(self, label_feature_name=label_feature_name, feature_name=feature_name, readout_node_set=readout_node_set, - validate=validate) + validate=validate, + kernel_regularizer=kernel_regularizer, + ) class NodeMeanAbsolutePercentageError(_MeanAbsolutePercentageErrorLossMixIn, @@ -756,7 +828,8 @@ def __init__(self, label_feature_name: Optional[str] = None, feature_name: str = tfgnn.HIDDEN_STATE, readout_node_set: tfgnn.NodeSetName = "_readout", - validate: bool = True): + validate: bool = True, + kernel_regularizer: Any = None): """Node regression with mean absolute percentage error via structured readout. This task defines regression via structured readout (see @@ -785,6 +858,9 @@ def __init__(self, auxiliary edge sets. This is stronlgy discouraged, unless great care is taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on structurally unchanged GraphTensors. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( key, @@ -794,7 +870,9 @@ def __init__(self, label_feature_name=label_feature_name, feature_name=feature_name, readout_node_set=readout_node_set, - validate=validate) + validate=validate, + kernel_regularizer=kernel_regularizer, + ) class NodeMeanSquaredError(_MeanSquaredErrorLossMixIn, _NodeRegression): @@ -809,7 +887,8 @@ def __init__(self, label_feature_name: Optional[str] = None, feature_name: str = tfgnn.HIDDEN_STATE, readout_node_set: tfgnn.NodeSetName = "_readout", - validate: bool = True): + validate: bool = True, + kernel_regularizer: Any = None): """Node regression with mean squared error via structured readout. This task defines regression via structured readout (see @@ -838,6 +917,9 @@ def __init__(self, auxiliary edge sets. This is stronlgy discouraged, unless great care is taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on structurally unchanged GraphTensors. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. 
""" super().__init__( key, @@ -847,7 +929,9 @@ def __init__(self, label_feature_name=label_feature_name, feature_name=feature_name, readout_node_set=readout_node_set, - validate=validate) + validate=validate, + kernel_regularizer=kernel_regularizer, + ) class NodeMeanSquaredLogarithmicError(_MeanSquaredLogarithmicErrorLossMixIn, @@ -863,7 +947,8 @@ def __init__(self, label_feature_name: Optional[str] = None, feature_name: str = tfgnn.HIDDEN_STATE, readout_node_set: tfgnn.NodeSetName = "_readout", - validate: bool = True): + validate: bool = True, + kernel_regularizer: Any = None): """Node regression with mean squared log error via structured readout. This task defines regression via structured readout (see @@ -892,6 +977,9 @@ def __init__(self, auxiliary edge sets. This is stronlgy discouraged, unless great care is taken to run `tfgnn.validate_graph_tensor_for_readout()` earlier on structurally unchanged GraphTensors. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. """ super().__init__( key, @@ -901,7 +989,9 @@ def __init__(self, label_feature_name=label_feature_name, feature_name=feature_name, readout_node_set=readout_node_set, - validate=validate) + validate=validate, + kernel_regularizer=kernel_regularizer, + ) class NodeMeanSquaredLogScaledError(_MeanSquaredLogScaledErrorLossMixIn, @@ -921,7 +1011,8 @@ def __init__(self, validate: bool = True, alpha_loss_param: float = 5., epsilon_loss_param: float = 1e-8, - reduction: tf.keras.losses.Reduction = AUTO): + reduction: tf.keras.losses.Reduction = AUTO, + kernel_regularizer: Any = None): """Node regression with mean squared log scaled error via structured readout. This task defines regression via structured readout (see @@ -953,6 +1044,9 @@ def __init__(self, alpha_loss_param: Alpha for the mean squared log scaled error. epsilon_loss_param: Epsilon for the mean squared log scaled error. reduction: Reduction for the mean squared log scaled error. + kernel_regularizer: Can be set to a `kernel_regularizer` as understood + by `tf.keras.layers.Dense` etc. to perform weight regularization of the + output layer. 
""" super().__init__( key, @@ -965,4 +1059,6 @@ def __init__(self, validate=validate, alpha_loss_param=alpha_loss_param, epsilon_loss_param=epsilon_loss_param, - reduction=reduction) + reduction=reduction, + kernel_regularizer=kernel_regularizer, + ) diff --git a/tensorflow_gnn/runner/tasks/regression_test.py b/tensorflow_gnn/runner/tasks/regression_test.py index 5ef4a911..7c19282d 100644 --- a/tensorflow_gnn/runner/tasks/regression_test.py +++ b/tensorflow_gnn/runner/tasks/regression_test.py @@ -51,6 +51,10 @@ def label_fn(inputs: GraphTensor) -> tuple[GraphTensor, Field]: return x, y +def l2(rate): + return tf.keras.regularizers.L2(rate) + + def add_readout_from_first_node(gt: GraphTensor) -> GraphTensor: return tfgnn.add_readout_from_first_node( gt, @@ -356,138 +360,169 @@ def test_preprocess( testcase_name="GraphMeanAbsoluteError", task=regression.GraphMeanAbsoluteError( "nodes", - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=TEST_GRAPH_TENSOR, expected_loss=tf.keras.losses.MeanAbsoluteError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="GraphMeanAbsolutePercentageError", task=regression.GraphMeanAbsolutePercentageError( "nodes", - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=context_readout_into_feature(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanAbsolutePercentageError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="GraphMeanSquaredError", task=regression.GraphMeanSquaredError( "nodes", - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=TEST_GRAPH_TENSOR, expected_loss=tf.keras.losses.MeanSquaredError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="GraphMeanSquaredLogarithmicError", task=regression.GraphMeanSquaredLogarithmicError( "nodes", units=3, - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=context_readout_into_feature(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanSquaredLogarithmicError, - expected_shape=tf.TensorShape((None, 3))), + expected_shape=tf.TensorShape((None, 3)), + expected_l2_regularization=0.125), dict( testcase_name="GraphMeanSquaredLogScaledError", task=regression.GraphMeanSquaredLogScaledError( "nodes", units=2, - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=TEST_GRAPH_TENSOR, expected_loss=regression.MeanSquaredLogScaledError, - expected_shape=tf.TensorShape((None, 2))), + expected_shape=tf.TensorShape((None, 2)), + expected_l2_regularization=0.125), dict( testcase_name="RootNodeMeanAbsoluteError", task=regression.RootNodeMeanAbsoluteError( "nodes", - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=context_readout_into_feature(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanAbsoluteError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="RootNodeMeanAbsolutePercentageError", task=regression.RootNodeMeanAbsolutePercentageError( "nodes", - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=TEST_GRAPH_TENSOR, expected_loss=tf.keras.losses.MeanAbsolutePercentageError, - 
expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="RootNodeMeanSquaredError", task=regression.RootNodeMeanSquaredError( "nodes", - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=context_readout_into_feature(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanSquaredError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="RootNodeMeanSquaredLogarithmicError", task=regression.RootNodeMeanSquaredLogarithmicError( "nodes", units=3, - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=TEST_GRAPH_TENSOR, expected_loss=tf.keras.losses.MeanSquaredLogarithmicError, - expected_shape=tf.TensorShape((None, 3))), + expected_shape=tf.TensorShape((None, 3)), + expected_l2_regularization=0.125), dict( testcase_name="RootNodeMeanSquaredLogScaledError", task=regression.RootNodeMeanSquaredLogScaledError( "nodes", units=2, - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=context_readout_into_feature(TEST_GRAPH_TENSOR), expected_loss=regression.MeanSquaredLogScaledError, - expected_shape=tf.TensorShape((None, 2))), + expected_shape=tf.TensorShape((None, 2)), + expected_l2_regularization=0.125), dict( testcase_name="NodeMeanAbsoluteError", task=regression.NodeMeanAbsoluteError( READOUT_KEY, - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=add_readout_from_first_node( context_readout_into_feature(TEST_GRAPH_TENSOR)), expected_loss=tf.keras.losses.MeanAbsoluteError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="NodeMeanAbsolutePercentageError", task=regression.NodeMeanAbsolutePercentageError( READOUT_KEY, - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=add_readout_from_first_node(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanAbsolutePercentageError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="NodeMeanSquaredError", task=regression.NodeMeanSquaredError( READOUT_KEY, - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=add_readout_from_first_node( context_readout_into_feature(TEST_GRAPH_TENSOR)), expected_loss=tf.keras.losses.MeanSquaredError, - expected_shape=tf.TensorShape((None, 1))), + expected_shape=tf.TensorShape((None, 1)), + expected_l2_regularization=0.125), dict( testcase_name="NodeMeanSquaredLogarithmicError", task=regression.NodeMeanSquaredLogarithmicError( READOUT_KEY, units=3, - label_fn=label_fn), + label_fn=label_fn, + kernel_regularizer=l2(0.125)), gt=add_readout_from_first_node(TEST_GRAPH_TENSOR), expected_loss=tf.keras.losses.MeanSquaredLogarithmicError, - expected_shape=tf.TensorShape((None, 3))), + expected_shape=tf.TensorShape((None, 3)), + expected_l2_regularization=0.125), dict( testcase_name="NodeMeanSquaredLogScaledError", task=regression.NodeMeanSquaredLogScaledError( READOUT_KEY, units=2, - label_feature_name="labels"), + label_feature_name="labels", + kernel_regularizer=l2(0.125)), gt=add_readout_from_first_node( context_readout_into_feature(TEST_GRAPH_TENSOR)), expected_loss=regression.MeanSquaredLogScaledError, - 
expected_shape=tf.TensorShape((None, 2))), + expected_shape=tf.TensorShape((None, 2)), + expected_l2_regularization=0.125), ]) def test_predict( self, task: interfaces.Task, gt: GraphTensor, expected_loss: Type[tf.keras.losses.Loss], - expected_shape: tf.TensorShape): + expected_shape: tf.TensorShape, + expected_l2_regularization: float): # Assert head readout, activation and shape. inputs = tf.keras.layers.Input(type_spec=gt.spec) model = tf.keras.Model(inputs, task.predict(inputs)) @@ -505,6 +540,8 @@ def test_predict( _, _, dense = model.layers self.assertTrue(expected_shape.is_compatible_with(dense.output_shape)) + self.assertEqual(dense.kernel_regularizer.get_config()["l2"], + expected_l2_regularization) # Assert losses. loss = task.losses()
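
Usage sketch (illustrative, not part of the diff): the new kwarg accepts any
regularizer understood by `tf.keras.layers.Dense` and is forwarded unchanged
to the output (logits) layer. The node-set name, class count, label feature
name, and L2 rate below are example values mirroring the unit tests:

    import tensorflow as tf
    from tensorflow_gnn.runner.tasks import classification

    # The task forwards `kernel_regularizer` to its output Dense layer, so the
    # L2 penalty below is applied to the logits kernel during training.
    task = classification.GraphMulticlassClassification(
        "nodes",
        num_classes=4,
        label_feature_name="labels",
        kernel_regularizer=tf.keras.regularizers.L2(1e-4))

The same kwarg works on every Regression task in this patch; passing the
default `None` leaves the output layer unregularized, as before.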