diff --git a/tensorflow_gnn/models/contrastive_losses/__init__.py b/tensorflow_gnn/models/contrastive_losses/__init__.py index 1e557b04..320d54be 100644 --- a/tensorflow_gnn/models/contrastive_losses/__init__.py +++ b/tensorflow_gnn/models/contrastive_losses/__init__.py @@ -31,6 +31,7 @@ TripletEmbeddingSquaredDistances = layers.TripletEmbeddingSquaredDistances BarlowTwinsTask = tasks.BarlowTwinsTask +ContrastiveLossTask = tasks.ContrastiveLossTask DeepGraphInfomaxTask = tasks.DeepGraphInfomaxTask TripletLossTask = tasks.TripletLossTask VicRegTask = tasks.VicRegTask diff --git a/tensorflow_gnn/models/contrastive_losses/tasks.py b/tensorflow_gnn/models/contrastive_losses/tasks.py index 1a18556b..29114620 100644 --- a/tensorflow_gnn/models/contrastive_losses/tasks.py +++ b/tensorflow_gnn/models/contrastive_losses/tasks.py @@ -32,12 +32,16 @@ GraphTensor = tfgnn.GraphTensor -class _ConstrastiveLossTask(runner.Task, abc.ABC): +class ConstrastiveLossTask(runner.Task): """Base class for unsupervised contrastive representation learning tasks. - The default `predict` method implementation shuffles feature across batch - examples to create positive and negative activations. There are multiple ways - proposed in the literature to learn representations based on the activations. + The process is separated into preprocessing and contrastive parts, with the + focus on reusability of individual components. The `preprocess` produces + input GraphTensors to be used with the `predict` as well as labels for the + task. The default `predict` method implementation expects a pair of positive + and negative GraphTensors. There are multiple ways proposed in the literature + to learn representations based on the activations - we achieve that by using + custom losses. Any subclass must implement `make_contrastive_layer` method, which produces the final prediction outputs. @@ -49,6 +53,7 @@ class _ConstrastiveLossTask(runner.Task, abc.ABC): Any model-specific preprocessing should be implemented in the `preprocess`. """ + # TODO(tsitsulin): move `preprocess` here. def __init__( self, @@ -167,7 +172,7 @@ def call( return tf.zeros_like(y_pred[..., 0]) -class DeepGraphInfomaxTask(_ConstrastiveLossTask): +class DeepGraphInfomaxTask(ConstrastiveLossTask): """A Deep Graph Infomax (DGI) Task.""" def __init__( @@ -227,7 +232,7 @@ def wrapper_fn(_, y_pred): return wrapper_fn -class BarlowTwinsTask(_ConstrastiveLossTask): +class BarlowTwinsTask(ConstrastiveLossTask): """A Barlow Twins (BT) Task.""" def __init__( @@ -266,7 +271,7 @@ def metrics(self) -> runner.Metrics: return (metrics.AllSvdMetrics(),) -class VicRegTask(_ConstrastiveLossTask): +class VicRegTask(ConstrastiveLossTask): """A VICReg Task.""" def __init__( @@ -308,7 +313,7 @@ def metrics(self) -> runner.Metrics: return (metrics.AllSvdMetrics(),) -class TripletLossTask(_ConstrastiveLossTask): +class TripletLossTask(ConstrastiveLossTask): """The triplet loss task.""" def __init__(self, *args, margin: float = 1.0, **kwargs):