From ec9bd661cf6cd4e770e377c2d5ade4d001736827 Mon Sep 17 00:00:00 2001 From: Parth Kothari Date: Wed, 27 Sep 2023 06:44:44 -0700 Subject: [PATCH] Expose Corruptor in contrastive models __init__.py PiperOrigin-RevId: 568835418 --- tensorflow_gnn/models/contrastive_losses/__init__.py | 1 + tensorflow_gnn/models/contrastive_losses/layers.py | 6 +++--- tensorflow_gnn/models/contrastive_losses/layers_test.py | 6 +++--- tensorflow_gnn/models/contrastive_losses/tasks.py | 2 +- 4 files changed, 8 insertions(+), 7 deletions(-) diff --git a/tensorflow_gnn/models/contrastive_losses/__init__.py b/tensorflow_gnn/models/contrastive_losses/__init__.py index 320d54be..ab9e5737 100644 --- a/tensorflow_gnn/models/contrastive_losses/__init__.py +++ b/tensorflow_gnn/models/contrastive_losses/__init__.py @@ -25,6 +25,7 @@ from tensorflow_gnn.models.contrastive_losses import tasks CorruptionSpec = layers.CorruptionSpec +Corruptor = layers.Corruptor DeepGraphInfomaxLogits = layers.DeepGraphInfomaxLogits DropoutFeatures = layers.DropoutFeatures ShuffleFeaturesGlobally = layers.ShuffleFeaturesGlobally diff --git a/tensorflow_gnn/models/contrastive_losses/layers.py b/tensorflow_gnn/models/contrastive_losses/layers.py index 6e734b4d..cb47df48 100644 --- a/tensorflow_gnn/models/contrastive_losses/layers.py +++ b/tensorflow_gnn/models/contrastive_losses/layers.py @@ -82,7 +82,7 @@ def with_default(self, default: T): ) -class _Corruptor(tfgnn.keras.layers.MapFeatures, Generic[T]): +class Corruptor(tfgnn.keras.layers.MapFeatures, Generic[T]): """Base class for graph corruptor.""" def __init__( @@ -166,7 +166,7 @@ def wrapper_fn(tensor, rate): @tf.keras.utils.register_keras_serializable(package=_PACKAGE) -class ShuffleFeaturesGlobally(_Corruptor[float]): +class ShuffleFeaturesGlobally(Corruptor[float]): """A corruptor that shuffles features.""" def __init__(self, *args, seed: Optional[float] = None, **kwargs): @@ -179,7 +179,7 @@ def get_config(self): @tf.keras.utils.register_keras_serializable(package=_PACKAGE) -class DropoutFeatures(_Corruptor[float]): +class DropoutFeatures(Corruptor[float]): def __init__(self, *args, seed: Optional[float] = None, **kwargs): self._seed = seed diff --git a/tensorflow_gnn/models/contrastive_losses/layers_test.py b/tensorflow_gnn/models/contrastive_losses/layers_test.py index 10d4d4da..0925ae11 100644 --- a/tensorflow_gnn/models/contrastive_losses/layers_test.py +++ b/tensorflow_gnn/models/contrastive_losses/layers_test.py @@ -277,7 +277,7 @@ def test_corrupt( ]) def test_shuffle_features_globally( self, - corruptor: layers._Corruptor, + corruptor: layers.Corruptor, context: tfgnn.Context, node_set: tfgnn.NodeSet, edge_set: tfgnn.EdgeSet, @@ -346,7 +346,7 @@ def test_shuffle_features_globally( ]) def test_dropout_features( self, - corruptor: layers._Corruptor, + corruptor: layers.Corruptor, context: tfgnn.Context, node_set: tfgnn.NodeSet, edge_set: tfgnn.EdgeSet, @@ -371,7 +371,7 @@ def test_dropout_features( def test_throws_empty_spec_error(self): with self.assertRaisesRegex(ValueError, r"At least one of .*"): - _ = layers._Corruptor(corruption_fn=lambda: None) + _ = layers.Corruptor(corruption_fn=lambda: None) @parameterized.named_parameters([ dict( diff --git a/tensorflow_gnn/models/contrastive_losses/tasks.py b/tensorflow_gnn/models/contrastive_losses/tasks.py index c40b547f..53270f8f 100644 --- a/tensorflow_gnn/models/contrastive_losses/tasks.py +++ b/tensorflow_gnn/models/contrastive_losses/tasks.py @@ -60,7 +60,7 @@ def __init__( *, feature_name: str = tfgnn.HIDDEN_STATE, representations_layer_name: Optional[str] = None, - corruptor: Optional[layers._Corruptor] = None, + corruptor: Optional[layers.Corruptor] = None, projector_units: Optional[Sequence[int]] = None, seed: Optional[int] = None, ):