Skip to content

Commit

Permalink
Expose ContrastiveLossTask interface.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 567004173
  • Loading branch information
xgfs authored and tensorflower-gardener committed Sep 20, 2023
1 parent 30b220e commit c99af3f
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 8 deletions.
1 change: 1 addition & 0 deletions tensorflow_gnn/models/contrastive_losses/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
TripletEmbeddingSquaredDistances = layers.TripletEmbeddingSquaredDistances

BarlowTwinsTask = tasks.BarlowTwinsTask
ContrastiveLossTask = tasks.ContrastiveLossTask
DeepGraphInfomaxTask = tasks.DeepGraphInfomaxTask
TripletLossTask = tasks.TripletLossTask
VicRegTask = tasks.VicRegTask
Expand Down
21 changes: 13 additions & 8 deletions tensorflow_gnn/models/contrastive_losses/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,12 +32,16 @@
GraphTensor = tfgnn.GraphTensor


class _ConstrastiveLossTask(runner.Task, abc.ABC):
class ContrastiveLossTask(runner.Task):
"""Base class for unsupervised contrastive representation learning tasks.
The default `predict` method implementation shuffles features across batch
examples to create positive and negative activations. There are multiple ways
proposed in the literature to learn representations based on these activations.
The process is separated into preprocessing and contrastive parts, with a
focus on the reusability of individual components. The `preprocess` method
produces input GraphTensors to be used with `predict`, as well as labels for
the task. The default `predict` method implementation expects a pair of
positive and negative GraphTensors. There are multiple ways proposed in the
literature to learn representations based on these activations; we achieve
this by using custom losses.
Any subclass must implement the `make_contrastive_layer` method, which
produces the final prediction outputs.
Expand All @@ -49,6 +53,7 @@ class _ConstrastiveLossTask(runner.Task, abc.ABC):
Any model-specific preprocessing should be implemented in the `preprocess` method.
"""
# TODO(tsitsulin): move `preprocess` here.

def __init__(
self,
Expand Down Expand Up @@ -167,7 +172,7 @@ def call(
return tf.zeros_like(y_pred[..., 0])


class DeepGraphInfomaxTask(_ConstrastiveLossTask):
class DeepGraphInfomaxTask(ContrastiveLossTask):
"""A Deep Graph Infomax (DGI) Task."""

def __init__(
Expand Down Expand Up @@ -227,7 +232,7 @@ def wrapper_fn(_, y_pred):
return wrapper_fn


class BarlowTwinsTask(_ConstrastiveLossTask):
class BarlowTwinsTask(ContrastiveLossTask):
"""A Barlow Twins (BT) Task."""

def __init__(
Expand Down Expand Up @@ -266,7 +271,7 @@ def metrics(self) -> runner.Metrics:
return (metrics.AllSvdMetrics(),)


class VicRegTask(_ConstrastiveLossTask):
class VicRegTask(ContrastiveLossTask):
"""A VICReg Task."""

def __init__(
Expand Down Expand Up @@ -308,7 +313,7 @@ def metrics(self) -> runner.Metrics:
return (metrics.AllSvdMetrics(),)


class TripletLossTask(_ConstrastiveLossTask):
class TripletLossTask(ContrastiveLossTask):
"""The triplet loss task."""

def __init__(self, *args, margin: float = 1.0, **kwargs):
Expand Down

0 comments on commit c99af3f

Please sign in to comment.