Skip to content

Commit

Permalink
Expose Corruptor in contrastive models __init__.py
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 568835418
  • Loading branch information
theDebugger811 authored and tensorflower-gardener committed Sep 27, 2023
1 parent d6298ab commit ec9bd66
Show file tree
Hide file tree
Showing 4 changed files with 8 additions and 7 deletions.
1 change: 1 addition & 0 deletions tensorflow_gnn/models/contrastive_losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
from tensorflow_gnn.models.contrastive_losses import tasks

CorruptionSpec = layers.CorruptionSpec
Corruptor = layers.Corruptor
DeepGraphInfomaxLogits = layers.DeepGraphInfomaxLogits
DropoutFeatures = layers.DropoutFeatures
ShuffleFeaturesGlobally = layers.ShuffleFeaturesGlobally
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_gnn/models/contrastive_losses/layers.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ def with_default(self, default: T):
)


class _Corruptor(tfgnn.keras.layers.MapFeatures, Generic[T]):
class Corruptor(tfgnn.keras.layers.MapFeatures, Generic[T]):
"""Base class for graph corruptor."""

def __init__(
Expand Down Expand Up @@ -166,7 +166,7 @@ def wrapper_fn(tensor, rate):


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class ShuffleFeaturesGlobally(_Corruptor[float]):
class ShuffleFeaturesGlobally(Corruptor[float]):
"""A corruptor that shuffles features."""

def __init__(self, *args, seed: Optional[float] = None, **kwargs):
Expand All @@ -179,7 +179,7 @@ def get_config(self):


@tf.keras.utils.register_keras_serializable(package=_PACKAGE)
class DropoutFeatures(_Corruptor[float]):
class DropoutFeatures(Corruptor[float]):

def __init__(self, *args, seed: Optional[float] = None, **kwargs):
self._seed = seed
Expand Down
6 changes: 3 additions & 3 deletions tensorflow_gnn/models/contrastive_losses/layers_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -277,7 +277,7 @@ def test_corrupt(
])
def test_shuffle_features_globally(
self,
corruptor: layers._Corruptor,
corruptor: layers.Corruptor,
context: tfgnn.Context,
node_set: tfgnn.NodeSet,
edge_set: tfgnn.EdgeSet,
Expand Down Expand Up @@ -346,7 +346,7 @@ def test_shuffle_features_globally(
])
def test_dropout_features(
self,
corruptor: layers._Corruptor,
corruptor: layers.Corruptor,
context: tfgnn.Context,
node_set: tfgnn.NodeSet,
edge_set: tfgnn.EdgeSet,
Expand All @@ -371,7 +371,7 @@ def test_dropout_features(

def test_throws_empty_spec_error(self):
with self.assertRaisesRegex(ValueError, r"At least one of .*"):
_ = layers._Corruptor(corruption_fn=lambda: None)
_ = layers.Corruptor(corruption_fn=lambda: None)

@parameterized.named_parameters([
dict(
Expand Down
2 changes: 1 addition & 1 deletion tensorflow_gnn/models/contrastive_losses/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ def __init__(
*,
feature_name: str = tfgnn.HIDDEN_STATE,
representations_layer_name: Optional[str] = None,
corruptor: Optional[layers._Corruptor] = None,
corruptor: Optional[layers.Corruptor] = None,
projector_units: Optional[Sequence[int]] = None,
seed: Optional[int] = None,
):
Expand Down

0 comments on commit ec9bd66

Please sign in to comment.