diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py index e6c4d575..851af966 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py @@ -43,6 +43,7 @@ import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.experimental.sampler import core +from tensorflow_gnn.experimental.sampler import ext_ops from tensorflow_gnn.experimental.sampler import interfaces from tensorflow_gnn.sampler import sampling_spec_pb2 @@ -356,7 +357,9 @@ def __init__( ) def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor: - seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1) + seed_nodes = tf.concat( + [inputs[tfgnn.SOURCE_NAME], inputs[tfgnn.TARGET_NAME]], axis=-1 + ) subgraph = self._sampling_pipeline(seed_nodes) return AddLinkReadoutStruct( readout_node_set=self._readout_node_set, @@ -367,13 +370,15 @@ def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor: class AddLinkReadoutStruct(tf.keras.layers.Layer): """Helper class for adding readout structure for the link prediction task. - It is assumed that link sampling is done using `SamplingPipeline` starting - from two seed nodes that belong to the same node set (`seed_node_set`). In the - current implementation those two seed become node 0 and node 1 in each sampled - subgraph. The readout structure is created by creating `_readout` node set and - adding single readout node to each graph with its features (passed as 2nd - tuple value in the call method). The seed nodes are connected to readout nodes - by adding auxiliary readout edges. + NOTE: that the link source and target nodes must belong to the same node set. + + The input graph tensor is assumed to be sampled by the `SamplingPipeline` + starting from the link source and target nodes. + + To facilitate link prediction, the class adds `_readout` node set with one + readout node for each link. It accepts link features to copy over to their + corresponding readout nodes. The readout nodes are connected to link source + and target nodes using a pair of auxiliary readout edge sets. NOTE: following the TF-GNN implementation, we direct readout edges from seed nodes (source) to readout nodes (target). @@ -402,18 +407,38 @@ def call( graph_tensor, readout_features = inputs seed_node_set = graph_tensor.node_sets[self._seed_node_set] - batch_size = tf.shape(seed_node_set.sizes)[0] - row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype) - sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype) - - def get_readout_index(value): - return tf.RaggedTensor.from_row_lengths( - tf.fill( - [batch_size], - tf.constant(value, dtype=graph_tensor.indices_dtype), - ), - row_lengths=row_lengths, + + ids = seed_node_set[core.NODE_ID_NAME] + link_source_idx = tf.cast( + ext_ops.ragged_lookup( + readout_features[tfgnn.SOURCE_NAME], ids, global_indices=False + ), + graph_tensor.indices_dtype, + ) + link_target_idx = tf.cast( + ext_ops.ragged_lookup( + readout_features[tfgnn.TARGET_NAME], ids, global_indices=False + ), + graph_tensor.indices_dtype, + ) + + with tf.control_dependencies( + [ + tf.debugging.assert_equal( + link_source_idx.row_lengths(), + link_target_idx.row_lengths(), + message=( + 'The number of link source and target nodes must be the' + ' same' + ), + ) + ] + ): + num_seeds = tf.cast( + link_source_idx.row_lengths(), graph_tensor.indices_dtype ) + readout_index = tf.ragged.range(num_seeds, dtype=num_seeds.dtype) + sizes = tf.expand_dims(num_seeds, axis=-1) readout_node_sets = { self._readout_node_set: tfgnn.NodeSet.from_fields( @@ -424,15 +449,15 @@ def get_readout_index(value): f'{self._readout_node_set}/source': tfgnn.EdgeSet.from_fields( sizes=sizes, adjacency=tfgnn.Adjacency.from_indices( - (self._seed_node_set, get_readout_index(0)), - (self._readout_node_set, get_readout_index(0)), + (self._seed_node_set, link_source_idx), + (self._readout_node_set, readout_index), ), ), f'{self._readout_node_set}/target': tfgnn.EdgeSet.from_fields( sizes=sizes, adjacency=tfgnn.Adjacency.from_indices( - (self._seed_node_set, get_readout_index(1)), - (self._readout_node_set, get_readout_index(0)), + (self._seed_node_set, link_target_idx), + (self._readout_node_set, readout_index), ), ), } diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py index 1e762d63..d16dfcfe 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py @@ -333,6 +333,7 @@ def _get_test_link_edges_sampler_schema_spec(): 'alexa': ['mike', 'sam'], 'arne': ['alexa', 'mike', 'sam'], 'sam': ['bob'], + 'bob': [] } sampler = StringIdsSampler({'reviews': reviews_edges}) @@ -414,6 +415,73 @@ def test_link_sampling_pipeline(self): merged_tensor.node_sets['_readout'].features['label'], [0.0, 1.0, 1.0] ) + def test_seeds_with_no_edges(self): + (_, sampler, graph_schema, + sampling_spec) = _get_test_link_edges_sampler_schema_spec() + + source_seeds = tf.ragged.constant([['bob'], ['sam'], ['sam'], ['bob']]) + target_seeds = tf.ragged.constant([['sam'], ['bob'], ['sam'], ['bob']]) + labels = tf.ragged.constant([[0.0], [1.0], [2.0], [3.0]], dtype=tf.float32) + inputs = { + tfgnn.SOURCE_NAME: source_seeds, + tfgnn.TARGET_NAME: target_seeds, + 'label': labels, + } + pipeline = subgraph_pipeline.LinkSamplingPipeline( + graph_schema, sampling_spec, sampler + ) + graph_tensor = pipeline(inputs) + self.assertAllEqual( + graph_tensor.node_sets['_readout'].sizes, [[1], [1], [1], [1]] + ) + self.assertAllEqual( + graph_tensor.node_sets['_readout'].features['label'], + labels, + ) + graph_tensor = graph_tensor.merge_batch_to_components() + self.assertAllEqual( + tfgnn.structured_readout(graph_tensor, 'source', feature_name='#id'), + source_seeds.values, + ) + self.assertAllEqual( + tfgnn.structured_readout(graph_tensor, 'target', feature_name='#id'), + target_seeds.values, + ) + + def test_multiple_seeds(self): + (_, sampler, graph_schema, sampling_spec) = ( + _get_test_link_edges_sampler_schema_spec() + ) + + source_seeds = tf.ragged.constant([['bob', 'sam'], ['arne', 'mike']]) + target_seeds = tf.ragged.constant([['sam', 'bob'], ['mike', 'arne']]) + labels = tf.ragged.constant([[0.0, 1.0], [2.0, 3.0]], dtype=tf.float32) + inputs = { + tfgnn.SOURCE_NAME: source_seeds, + tfgnn.TARGET_NAME: target_seeds, + 'label': labels, + } + pipeline = subgraph_pipeline.LinkSamplingPipeline( + graph_schema, sampling_spec, sampler + ) + graph_tensor = pipeline(inputs) + self.assertAllEqual( + graph_tensor.node_sets['_readout'].sizes, [[2], [2]] + ) + self.assertAllEqual( + graph_tensor.node_sets['_readout'].features['label'], + labels, + ) + graph_tensor = graph_tensor.merge_batch_to_components() + self.assertAllEqual( + tfgnn.structured_readout(graph_tensor, 'source', feature_name='#id'), + source_seeds.values, + ) + self.assertAllEqual( + tfgnn.structured_readout(graph_tensor, 'target', feature_name='#id'), + target_seeds.values, + ) + if __name__ == '__main__': tf.test.main()