Skip to content

Commit

Permalink
Make a default preprocessing fn for contrastive losses.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567070380
  • Loading branch information
xgfs authored and tensorflower-gardener committed Sep 20, 2023
1 parent 6747895 commit 3295009
Showing 1 changed file with 8 additions and 17 deletions.
25 changes: 8 additions & 17 deletions tensorflow_gnn/models/contrastive_losses/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,6 @@ class ContrastiveLossTask(runner.Task):
Any model-specific preprocessing should be implemented in the `preprocess`.
"""
# TODO(tsitsulin): move `preprocess` here.

def __init__(
self,
Expand Down Expand Up @@ -96,6 +95,14 @@ def __init__(
else:
self._projector = None

def preprocess(
self, inputs: GraphTensor
) -> tuple[Sequence[GraphTensor], runner.Predictions]:
"""Applies a `Corruptor` and returns empty pseudo-labels."""
x = (inputs, self._corruptor(inputs))
y = tf.zeros((inputs.num_components, 0), dtype=tf.int32)
return x, y

def predict(self, *args: tfgnn.GraphTensor) -> runner.Predictions:
"""Apply a readout head for use with various contrastive losses.
Expand Down Expand Up @@ -249,14 +256,6 @@ def __init__(
def make_contrastive_layer(self) -> tf.keras.layers.Layer:
return tf.keras.layers.Layer()

def preprocess(
self, inputs: GraphTensor
) -> tuple[Sequence[GraphTensor], Field]:
"""Creates unused pseudo-labels."""
x = (inputs, self._corruptor(inputs))
y = tf.zeros((inputs.num_components, 0), dtype=tf.int32)
return x, y

def losses(self) -> runner.Losses:
def loss_fn(_, x):
return losses.barlow_twins_loss(
Expand Down Expand Up @@ -290,14 +289,6 @@ def __init__(
def make_contrastive_layer(self) -> tf.keras.layers.Layer:
return tf.keras.layers.Layer()

def preprocess(
self, inputs: GraphTensor
) -> tuple[Sequence[GraphTensor], Field]:
"""Creates unused pseudo-labels."""
x = (inputs, self._corruptor(inputs))
y = tf.zeros((inputs.num_components, 0), dtype=tf.int32)
return x, y

def losses(self) -> runner.Losses:
def loss_fn(_, x):
return losses.vicreg_loss(
Expand Down

0 comments on commit 3295009

Please sign in to comment.