From 30b220ec26e6bff9457a15d26045f6a5dbd999c2 Mon Sep 17 00:00:00 2001 From: Anton Tsitsulin Date: Tue, 19 Sep 2023 11:54:27 -0700 Subject: [PATCH] Add SVDmetrics to DGI. PiperOrigin-RevId: 566700095 --- .../models/contrastive_losses/tasks.py | 2 +- .../models/contrastive_losses/tasks_test.py | 38 +++++++++++++++++-- 2 files changed, 35 insertions(+), 5 deletions(-) diff --git a/tensorflow_gnn/models/contrastive_losses/tasks.py b/tensorflow_gnn/models/contrastive_losses/tasks.py index fa435be8..1a18556b 100644 --- a/tensorflow_gnn/models/contrastive_losses/tasks.py +++ b/tensorflow_gnn/models/contrastive_losses/tasks.py @@ -210,7 +210,7 @@ def metrics(self) -> runner.Metrics: tf.keras.metrics.BinaryCrossentropy(from_logits=True), tf.keras.metrics.BinaryAccuracy(), ), - "representations": (), + "representations": (metrics.AllSvdMetrics(),), } diff --git a/tensorflow_gnn/models/contrastive_losses/tasks_test.py b/tensorflow_gnn/models/contrastive_losses/tasks_test.py index cf665610..7073a830 100644 --- a/tensorflow_gnn/models/contrastive_losses/tasks_test.py +++ b/tensorflow_gnn/models/contrastive_losses/tasks_test.py @@ -271,6 +271,34 @@ def test_output_dictionary(self): for value in layer_output.values(): self.assertIsInstance(value, tf.Tensor) + def test_metrics(self): + # TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN + if int(tf.__version__.split(".")[1]) < 13: + self.skipTest( + "Dictionary metrics are unsupported in TF older than 2.13 " + f"but got TF {tf.__version__}" + ) + y_pred = { + "predictions": [[0.0, 0.0]], + "representations": [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]], + } + _, fake_y = self.task.preprocess(graph_tensor()) + + self.assertIsInstance(self.task.metrics(), Mapping) + self.assertEqual(fake_y.keys(), y_pred.keys()) + self.assertEqual(fake_y.keys(), self.task.metrics().keys()) + + for metric_key, metric_fns in self.task.metrics().items(): + self.assertIsInstance(metric_fns, Sequence) + for metric_fn in metric_fns: + metric_value = metric_fn(fake_y[metric_key], y_pred[metric_key]) + if isinstance(metric_value, dict): + # SVDMetrics returns a dictionary. + for metric_val in metric_value.values(): + self.assertEqual(metric_val.shape, ()) + else: + self.assertEqual(metric_value.shape, ()) + class BarlowTwinsTaskTest(tf.test.TestCase): task = tasks.BarlowTwinsTask("node", seed=8191) @@ -281,10 +309,12 @@ def test_pseudolabels(self): self.assertAllEqual(pseudolabels, ((),)) def test_metrics(self): - # TODO(tsitsulin): Remove when TF 2.13+ is required by all of TFGNN + # TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN if int(tf.__version__.split(".")[1]) < 13: - self.skipTest("Dictionary metrics are unsupported in TF older than 2.13 " - f"but got TF {tf.__version__}") + self.skipTest( + "Dictionary metrics are unsupported in TF older than 2.13 " + f"but got TF {tf.__version__}" + ) # Clean and corrupted representations (shape (1, 4)) packed in one Tensor. y_pred = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]] _, fake_y = self.task.preprocess(graph_tensor()) @@ -322,7 +352,7 @@ def test_pseudolabels(self): self.assertAllEqual(pseudolabels, ((),)) def test_metrics(self): - # TODO(tsitsulin): Remove when TF 2.13+ is required by all of TFGNN + # TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN if int(tf.__version__.split(".")[1]) < 13: self.skipTest("Dictionary metrics are unsupported in TF older than 2.13 " f"but got TF {tf.__version__}")