Skip to content

Commit

Permalink
Add SVDmetrics to DGI.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 566700095
  • Loading branch information
xgfs authored and tensorflower-gardener committed Sep 19, 2023
1 parent 0ea0452 commit 30b220e
Show file tree
Hide file tree
Showing 2 changed files with 35 additions and 5 deletions.
2 changes: 1 addition & 1 deletion tensorflow_gnn/models/contrastive_losses/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ def metrics(self) -> runner.Metrics:
tf.keras.metrics.BinaryCrossentropy(from_logits=True),
tf.keras.metrics.BinaryAccuracy(),
),
"representations": (),
"representations": (metrics.AllSvdMetrics(),),
}


Expand Down
38 changes: 34 additions & 4 deletions tensorflow_gnn/models/contrastive_losses/tasks_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,6 +271,34 @@ def test_output_dictionary(self):
for value in layer_output.values():
self.assertIsInstance(value, tf.Tensor)

def test_metrics(self):
# TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN
if int(tf.__version__.split(".")[1]) < 13:
self.skipTest(
"Dictionary metrics are unsupported in TF older than 2.13 "
f"but got TF {tf.__version__}"
)
y_pred = {
"predictions": [[0.0, 0.0]],
"representations": [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]],
}
_, fake_y = self.task.preprocess(graph_tensor())

self.assertIsInstance(self.task.metrics(), Mapping)
self.assertEqual(fake_y.keys(), y_pred.keys())
self.assertEqual(fake_y.keys(), self.task.metrics().keys())

for metric_key, metric_fns in self.task.metrics().items():
self.assertIsInstance(metric_fns, Sequence)
for metric_fn in metric_fns:
metric_value = metric_fn(fake_y[metric_key], y_pred[metric_key])
if isinstance(metric_value, dict):
# SVDMetrics returns a dictionary.
for metric_val in metric_value.values():
self.assertEqual(metric_val.shape, ())
else:
self.assertEqual(metric_value.shape, ())


class BarlowTwinsTaskTest(tf.test.TestCase):
task = tasks.BarlowTwinsTask("node", seed=8191)
Expand All @@ -281,10 +309,12 @@ def test_pseudolabels(self):
self.assertAllEqual(pseudolabels, ((),))

def test_metrics(self):
# TODO(tsitsulin): Remove when TF 2.13+ is required by all of TFGNN
# TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN
if int(tf.__version__.split(".")[1]) < 13:
self.skipTest("Dictionary metrics are unsupported in TF older than 2.13 "
f"but got TF {tf.__version__}")
self.skipTest(
"Dictionary metrics are unsupported in TF older than 2.13 "
f"but got TF {tf.__version__}"
)
# Clean and corrupted representations (shape (1, 4)) packed in one Tensor.
y_pred = [[[0.0, 0.0, 0.0, 0.0], [0.0, 0.0, 0.0, 0.0]]]
_, fake_y = self.task.preprocess(graph_tensor())
Expand Down Expand Up @@ -322,7 +352,7 @@ def test_pseudolabels(self):
self.assertAllEqual(pseudolabels, ((),))

def test_metrics(self):
# TODO(tsitsulin): Remove when TF 2.13+ is required by all of TFGNN
# TODO(b/294224429): Remove when TF 2.13+ is required by all of TF-GNN
if int(tf.__version__.split(".")[1]) < 13:
self.skipTest("Dictionary metrics are unsupported in TF older than 2.13 "
f"but got TF {tf.__version__}")
Expand Down

0 comments on commit 30b220e

Please sign in to comment.