diff --git a/tensorflow_gnn/models/contrastive_losses/tasks.py b/tensorflow_gnn/models/contrastive_losses/tasks.py index 7a57a945..c40b547f 100644 --- a/tensorflow_gnn/models/contrastive_losses/tasks.py +++ b/tensorflow_gnn/models/contrastive_losses/tasks.py @@ -53,7 +53,6 @@ class ContrastiveLossTask(runner.Task): Any model-specific preprocessing should be implemented in the `preprocess`. """ - # TODO(tsitsulin): move `preprocess` here. def __init__( self, @@ -96,6 +95,14 @@ def __init__( else: self._projector = None + def preprocess( + self, inputs: GraphTensor + ) -> tuple[Sequence[GraphTensor], runner.Predictions]: + """Applies a `Corruptor` and returns empty pseudo-labels.""" + x = (inputs, self._corruptor(inputs)) + y = tf.zeros((inputs.num_components, 0), dtype=tf.int32) + return x, y + def predict(self, *args: tfgnn.GraphTensor) -> runner.Predictions: """Apply a readout head for use with various contrastive losses. @@ -249,14 +256,6 @@ def __init__( def make_contrastive_layer(self) -> tf.keras.layers.Layer: return tf.keras.layers.Layer() - def preprocess( - self, inputs: GraphTensor - ) -> tuple[Sequence[GraphTensor], Field]: - """Creates unused pseudo-labels.""" - x = (inputs, self._corruptor(inputs)) - y = tf.zeros((inputs.num_components, 0), dtype=tf.int32) - return x, y - def losses(self) -> runner.Losses: def loss_fn(_, x): return losses.barlow_twins_loss( @@ -290,14 +289,6 @@ def __init__( def make_contrastive_layer(self) -> tf.keras.layers.Layer: return tf.keras.layers.Layer() - def preprocess( - self, inputs: GraphTensor - ) -> tuple[Sequence[GraphTensor], Field]: - """Creates unused pseudo-labels.""" - x = (inputs, self._corruptor(inputs)) - y = tf.zeros((inputs.num_components, 0), dtype=tf.int32) - return x, y - def losses(self) -> runner.Losses: def loss_fn(_, x): return losses.vicreg_loss(