Skip to content

Commit

Permalink
Relaxes the requirement that the sampled source and target link nodes…
Browse files Browse the repository at this point in the history
… be the first and second sampled nodes in their node sets.

Allows to generate subgraphs with multiple links per/example.

PiperOrigin-RevId: 568850636
  • Loading branch information
aferludin authored and tensorflower-gardener committed Sep 27, 2023
1 parent ec9bd66 commit f6f72a0
Show file tree
Hide file tree
Showing 2 changed files with 116 additions and 23 deletions.
71 changes: 48 additions & 23 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.experimental.sampler import core
from tensorflow_gnn.experimental.sampler import ext_ops
from tensorflow_gnn.experimental.sampler import interfaces
from tensorflow_gnn.sampler import sampling_spec_pb2

Expand Down Expand Up @@ -356,7 +357,9 @@ def __init__(
)

def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor:
seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1)
seed_nodes = tf.concat(
[inputs[tfgnn.SOURCE_NAME], inputs[tfgnn.TARGET_NAME]], axis=-1
)
subgraph = self._sampling_pipeline(seed_nodes)
return AddLinkReadoutStruct(
readout_node_set=self._readout_node_set,
Expand All @@ -367,13 +370,15 @@ def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor:
class AddLinkReadoutStruct(tf.keras.layers.Layer):
"""Helper class for adding readout structure for the link prediction task.
It is assumed that link sampling is done using `SamplingPipeline` starting
from two seed nodes that belong to the same node set (`seed_node_set`). In the
current implementation those two seed become node 0 and node 1 in each sampled
subgraph. The readout structure is created by creating `_readout` node set and
adding single readout node to each graph with its features (passed as 2nd
tuple value in the call method). The seed nodes are connected to readout nodes
by adding auxiliary readout edges.
NOTE: that the link source and target nodes must belong to the same node set.
The input graph tensor is assumed to be sampled by the `SamplingPipeline`
starting from the link source and target nodes.
To facilitate link prediction, the class adds `_readout` node set with one
readout node for each link. It accepts link features to copy over to their
corresponding readout nodes. The readout nodes are connected to link source
and target nodes using a pair of auxiliary readout edge sets.
NOTE: following the TF-GNN implementation, we direct readout edges from
seed nodes (source) to readout nodes (target).
Expand Down Expand Up @@ -402,18 +407,38 @@ def call(
graph_tensor, readout_features = inputs

seed_node_set = graph_tensor.node_sets[self._seed_node_set]
batch_size = tf.shape(seed_node_set.sizes)[0]
row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype)
sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype)

def get_readout_index(value):
return tf.RaggedTensor.from_row_lengths(
tf.fill(
[batch_size],
tf.constant(value, dtype=graph_tensor.indices_dtype),
),
row_lengths=row_lengths,

ids = seed_node_set[core.NODE_ID_NAME]
link_source_idx = tf.cast(
ext_ops.ragged_lookup(
readout_features[tfgnn.SOURCE_NAME], ids, global_indices=False
),
graph_tensor.indices_dtype,
)
link_target_idx = tf.cast(
ext_ops.ragged_lookup(
readout_features[tfgnn.TARGET_NAME], ids, global_indices=False
),
graph_tensor.indices_dtype,
)

with tf.control_dependencies(
[
tf.debugging.assert_equal(
link_source_idx.row_lengths(),
link_target_idx.row_lengths(),
message=(
'The number of link source and target nodes must be the'
' same'
),
)
]
):
num_seeds = tf.cast(
link_source_idx.row_lengths(), graph_tensor.indices_dtype
)
readout_index = tf.ragged.range(num_seeds, dtype=num_seeds.dtype)
sizes = tf.expand_dims(num_seeds, axis=-1)

readout_node_sets = {
self._readout_node_set: tfgnn.NodeSet.from_fields(
Expand All @@ -424,15 +449,15 @@ def get_readout_index(value):
f'{self._readout_node_set}/source': tfgnn.EdgeSet.from_fields(
sizes=sizes,
adjacency=tfgnn.Adjacency.from_indices(
(self._seed_node_set, get_readout_index(0)),
(self._readout_node_set, get_readout_index(0)),
(self._seed_node_set, link_source_idx),
(self._readout_node_set, readout_index),
),
),
f'{self._readout_node_set}/target': tfgnn.EdgeSet.from_fields(
sizes=sizes,
adjacency=tfgnn.Adjacency.from_indices(
(self._seed_node_set, get_readout_index(1)),
(self._readout_node_set, get_readout_index(0)),
(self._seed_node_set, link_target_idx),
(self._readout_node_set, readout_index),
),
),
}
Expand Down
68 changes: 68 additions & 0 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,6 +333,7 @@ def _get_test_link_edges_sampler_schema_spec():
'alexa': ['mike', 'sam'],
'arne': ['alexa', 'mike', 'sam'],
'sam': ['bob'],
'bob': []
}
sampler = StringIdsSampler({'reviews': reviews_edges})

Expand Down Expand Up @@ -414,6 +415,73 @@ def test_link_sampling_pipeline(self):
merged_tensor.node_sets['_readout'].features['label'], [0.0, 1.0, 1.0]
)

def test_seeds_with_no_edges(self):
(_, sampler, graph_schema,
sampling_spec) = _get_test_link_edges_sampler_schema_spec()

source_seeds = tf.ragged.constant([['bob'], ['sam'], ['sam'], ['bob']])
target_seeds = tf.ragged.constant([['sam'], ['bob'], ['sam'], ['bob']])
labels = tf.ragged.constant([[0.0], [1.0], [2.0], [3.0]], dtype=tf.float32)
inputs = {
tfgnn.SOURCE_NAME: source_seeds,
tfgnn.TARGET_NAME: target_seeds,
'label': labels,
}
pipeline = subgraph_pipeline.LinkSamplingPipeline(
graph_schema, sampling_spec, sampler
)
graph_tensor = pipeline(inputs)
self.assertAllEqual(
graph_tensor.node_sets['_readout'].sizes, [[1], [1], [1], [1]]
)
self.assertAllEqual(
graph_tensor.node_sets['_readout'].features['label'],
labels,
)
graph_tensor = graph_tensor.merge_batch_to_components()
self.assertAllEqual(
tfgnn.structured_readout(graph_tensor, 'source', feature_name='#id'),
source_seeds.values,
)
self.assertAllEqual(
tfgnn.structured_readout(graph_tensor, 'target', feature_name='#id'),
target_seeds.values,
)

def test_multiple_seeds(self):
(_, sampler, graph_schema, sampling_spec) = (
_get_test_link_edges_sampler_schema_spec()
)

source_seeds = tf.ragged.constant([['bob', 'sam'], ['arne', 'mike']])
target_seeds = tf.ragged.constant([['sam', 'bob'], ['mike', 'arne']])
labels = tf.ragged.constant([[0.0, 1.0], [2.0, 3.0]], dtype=tf.float32)
inputs = {
tfgnn.SOURCE_NAME: source_seeds,
tfgnn.TARGET_NAME: target_seeds,
'label': labels,
}
pipeline = subgraph_pipeline.LinkSamplingPipeline(
graph_schema, sampling_spec, sampler
)
graph_tensor = pipeline(inputs)
self.assertAllEqual(
graph_tensor.node_sets['_readout'].sizes, [[2], [2]]
)
self.assertAllEqual(
graph_tensor.node_sets['_readout'].features['label'],
labels,
)
graph_tensor = graph_tensor.merge_batch_to_components()
self.assertAllEqual(
tfgnn.structured_readout(graph_tensor, 'source', feature_name='#id'),
source_seeds.values,
)
self.assertAllEqual(
tfgnn.structured_readout(graph_tensor, 'target', feature_name='#id'),
target_seeds.values,
)


if __name__ == '__main__':
tf.test.main()

0 comments on commit f6f72a0

Please sign in to comment.