diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py index be23e892..338a7655 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py @@ -60,12 +60,137 @@ Optional[interfaces.KeyToFeaturesAccessor]] +def create_link_sampling_model_from_spec( + graph_schema: tfgnn.GraphSchema, + sampling_spec: sampling_spec_pb2.SamplingSpec, + edge_sampler_factory: EdgeSamplerFactory, + node_features_accessor_factory: Optional[NodeFeaturesLookupFactory] = None, + seed_node_dtype=tf.string, +) -> tf.keras.Model: + """Creates Keras model that accepts link sampling features and seeds to output GraphTensor.""" + _validate_link_sampling_graph_schema(graph_schema) + inputs = { + tfgnn.SOURCE_NAME: tf.keras.Input( + type_spec=tf.RaggedTensorSpec( + shape=[None, None], ragged_rank=1, dtype=seed_node_dtype + ), + name='SeedSource', + ), + tfgnn.TARGET_NAME: tf.keras.Input( + type_spec=tf.RaggedTensorSpec( + shape=[None, None], ragged_rank=1, dtype=seed_node_dtype + ), + name='SeedTarget', + ), + } + inputs.update(_readout_inputs_from_schema(graph_schema)) + pipeline = LinkSamplingPipeline( + graph_schema=graph_schema, sampling_spec=sampling_spec, + edge_sampler_factory=edge_sampler_factory, + node_features_accessor_factory=node_features_accessor_factory, + ) + subgraph = pipeline(inputs) + return tf.keras.Model(inputs=inputs, outputs=subgraph) + + +def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema): + """Validates that the given graph_schema has only one auxiliary node set and its related edge sets. + + Args: + graph_schema: tfgnn.GraphSchema object to validate for link prediction. + """ + aux_node_sets = [ + node_set + for node_set in graph_schema.node_sets.keys() + if tfgnn.get_aux_type_prefix(node_set) is not None + ] + if len(aux_node_sets) != 1 or '_readout' not in aux_node_sets: + raise ValueError( + 'There should be exactly one auxiliary node set for link sampling and' + f' it should be called `_readout`. Instead got {aux_node_sets}' + ) + + aux_edge_sets = [ + edge_set + for edge_set in graph_schema.edge_sets.keys() + if tfgnn.get_aux_type_prefix(edge_set) is not None + ] + if ( + len(aux_edge_sets) != 2 + or '_readout/source' not in aux_edge_sets + or '_readout/target' not in aux_edge_sets + ): + raise ValueError( + 'There should be exactly two auxiliary edge sets for link sampling and' + ' they should be called `_readout/source` and `_readout/target`.' + f' Instead got {aux_edge_sets}' + ) + aux_node_set_name = aux_node_sets[0] + if ( + graph_schema.edge_sets[aux_edge_sets[0]].source + != graph_schema.edge_sets[aux_edge_sets[1]].source + ): + raise ValueError( + '`_readout/source` and `_readout/target` edge sets should point from a' + ' common source node set to `_readout` as a target. Instead' + ' `_readout/source` points from' + f' {graph_schema.edge_sets[aux_edge_sets[0]].source} while' + ' `_readout/target` points from' + f' {graph_schema.edge_sets[aux_edge_sets[1]].source}' + ) + + if not ( + tfgnn.SOURCE_NAME + in graph_schema.node_sets[aux_node_set_name].features.keys() + and tfgnn.TARGET_NAME + in graph_schema.node_sets[aux_node_set_name].features.keys() + ): + raise ValueError( + 'Link sampling requires a #source and a #target feature in the' + ' _readout node set' + ) + + +def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema): + """Populates inputs for sampling model from a graph schema. + + Args: + graph_schema: tfgnn.GraphSchema for generating the input specs. + + Returns: + dict from feature names to input specs + """ + aux_node_set_name = [ + node_set + for node_set in graph_schema.node_sets.keys() + if tfgnn.get_aux_type_prefix(node_set) is not None + ][0] + node_set_spec = graph_schema.node_sets[aux_node_set_name] + feature_names = [ + key + for key in node_set_spec.features.keys() + if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME] + ] + feature_inputs = {} + for feature_name in feature_names: + feature_inputs[feature_name] = tf.keras.Input( + type_spec=tf.RaggedTensorSpec( + shape=[None, None], + ragged_rank=1, + dtype=node_set_spec.features[feature_name].dtype, + ), + name=f'Seed/{feature_name}', + ) + return feature_inputs + + def create_sampling_model_from_spec( graph_schema: tfgnn.GraphSchema, sampling_spec: sampling_spec_pb2.SamplingSpec, edge_sampler_factory: EdgeSamplerFactory, node_features_accessor_factory: Optional[NodeFeaturesLookupFactory] = None, - seed_node_dtype=tf.string) -> tf.keras.Model: + seed_node_dtype=tf.string, +) -> tf.keras.Model: """Creates Keras model that accepts seed node IDs to output GraphTensor. Args: @@ -197,6 +322,106 @@ def node_features_accessor_factory( return self._node_features_accessor_factory +class LinkSamplingPipeline: + """Callable that makes GraphTensor out of seed node IDs and labels.""" + + def __init__( + self, + graph_schema: tfgnn.GraphSchema, + sampling_spec: sampling_spec_pb2.SamplingSpec, + edge_sampler_factory: EdgeSamplerFactory, + node_features_accessor_factory: Optional[ + NodeFeaturesLookupFactory + ] = None, + ): + self._sampling_pipeline = SamplingPipeline( + graph_schema=graph_schema, + sampling_spec=sampling_spec, + edge_sampler_factory=edge_sampler_factory, + node_features_accessor_factory=node_features_accessor_factory, + ) + self._target_node_set_name = graph_schema.edge_sets[ + '_readout/source' + ].source + self._aux_node_set_name = [ + node_set + for node_set in graph_schema.node_sets.keys() + if tfgnn.get_aux_type_prefix(node_set) is not None + ][0] + + def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor: + seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1) + subgraph = self._sampling_pipeline(seed_nodes) + return AddReadoutStruct()( + subgraph, inputs, self._target_node_set_name, self._aux_node_set_name + ) + + +class AddReadoutStruct(tf.keras.layers.Layer): + """Helper class for adding readout node and edge sets to a GraphTensor.""" + + def call( + self, + graph_tensor: tfgnn.GraphTensor, + readout_features: Mapping[str, tf.RaggedTensor], + target_node_set_name: tfgnn.NodeSetName, + aux_node_set_name: tfgnn.NodeSetName, + ) -> tfgnn.GraphTensor: + target_node_set = graph_tensor.node_sets[target_node_set_name] + batch_size = tf.shape(target_node_set.sizes)[0] + row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype) + sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype) + + def get_readout_index(value): + return tf.RaggedTensor.from_row_lengths( + tf.fill( + [batch_size], + tf.constant(value, dtype=graph_tensor.indices_dtype), + row_lengths, + ), + row_lengths, + ) + + readout_source_source = get_readout_index(0) + readout_source_target = get_readout_index(0) + readout_target_source = get_readout_index(1) + readout_target_target = get_readout_index(0) + readout_source_adjacency = tfgnn.Adjacency.from_indices( + (target_node_set_name, readout_source_source), + (aux_node_set_name, readout_source_target), + ) + readout_target_adjacency = tfgnn.Adjacency.from_indices( + (target_node_set_name, readout_target_source), + (aux_node_set_name, readout_target_target), + ) + readout_node_sets = { + aux_node_set_name: tfgnn.NodeSet.from_fields( + features=readout_features, sizes=sizes + ) + } + readout_edge_sets = { + f'{aux_node_set_name}/source': tfgnn.EdgeSet.from_fields( + sizes=sizes, adjacency=readout_source_adjacency + ), + f'{aux_node_set_name}/target': tfgnn.EdgeSet.from_fields( + sizes=sizes, adjacency=readout_target_adjacency + ), + } + node_sets = { + node_set_name: node_set + for node_set_name, node_set in graph_tensor.node_sets.items() + } + edge_sets = { + edge_set_name: edge_set + for edge_set_name, edge_set in graph_tensor.edge_sets.items() + } + node_sets.update(readout_node_sets) + edge_sets.update(readout_edge_sets) + return graph_tensor.from_pieces( + node_sets=node_sets, edge_sets=edge_sets, context=graph_tensor.context + ) + + def sample_edge_sets( seed_node_ids: tf.RaggedTensor, graph_schema: tfgnn.GraphSchema, @@ -226,9 +451,13 @@ def sample_edge_sets( edge_set_sources = collections.defaultdict(list) edge_set_targets = collections.defaultdict(list) - + seed_op = ( + sampling_spec.seed_op + if sampling_spec.HasField('seed_op') + else sampling_spec.symmetric_link_seed_op + ) tensors_by_op_name = { - sampling_spec.seed_op.op_name: seed_node_ids + seed_op.op_name: seed_node_ids } for sampling_op in sampling_spec.sampling_ops: @@ -264,10 +493,17 @@ def _unique_y(x: tf.Tensor) -> tf.Tensor: def assert_sorted_sampling_spec(sampling_spec: sampling_spec_pb2.SamplingSpec): """Raises ValueError if `sampling_spec` is not topologically-sorted.""" - seen_ops = {sampling_spec.seed_op.op_name} + seed_op = ( + sampling_spec.seed_op + if sampling_spec.HasField('seed_op') + else sampling_spec.symmetric_link_seed_op + ) + seen_ops = {seed_op.op_name} for sampling_op in sampling_spec.sampling_ops: for input_op in sampling_op.input_op_names: if input_op not in seen_ops: - raise ValueError(f'Input op {input_op} is used before defined. ' - 'sampling_spec is not topologically-sorted') + raise ValueError( + f'Input op {input_op} is used before defined. ' + 'sampling_spec is not topologically-sorted' + ) seen_ops.add(sampling_op.op_name) diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py index 446948ef..1e762d63 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py @@ -18,6 +18,7 @@ from typing import Dict, List, Tuple, Callable, Optional from absl.testing import parameterized +import numpy as np import tensorflow as tf import tensorflow_gnn as tfgnn @@ -326,5 +327,93 @@ def test_sampling_pipeline(self): self.assertIn(food.decode(), eats_edges[animal.decode()]) +def _get_test_link_edges_sampler_schema_spec(): + reviews_edges = { + 'mike': ['alexa'], + 'alexa': ['mike', 'sam'], + 'arne': ['alexa', 'mike', 'sam'], + 'sam': ['bob'], + } + sampler = StringIdsSampler({'reviews': reviews_edges}) + + graph_schema = tfgnn.GraphSchema() + graph_schema.node_sets['_readout'].features['label'].dtype = 1 + graph_schema.edge_sets['reviews'].source = 'authors' + graph_schema.edge_sets['reviews'].target = 'authors' + graph_schema.edge_sets['_readout/source'].target = '_readout' + graph_schema.edge_sets['_readout/source'].source = 'authors' + graph_schema.edge_sets['_readout/target'].target = '_readout' + graph_schema.edge_sets['_readout/target'].source = 'authors' + + sampling_spec = sampling_spec_pb2.SamplingSpec() + sampling_spec.symmetric_link_seed_op.op_name = 'seed' + sampling_spec.sampling_ops.add( + op_name='A', edge_set_name='reviews', + sample_size=2).input_op_names.append('seed') + + return reviews_edges, sampler, graph_schema, sampling_spec + + +class LinkSamplingPipelineTest(tf.test.TestCase, parameterized.TestCase): + + # TODO: b/301427603 - Test this parametrically on int64 and int32 seeds. + def test_link_sampling_pipeline(self): + (_, sampler, graph_schema, + sampling_spec) = _get_test_link_edges_sampler_schema_spec() + + source_seeds = tf.ragged.constant([['alexa'], ['sam'], ['sam']]) + target_seeds = tf.ragged.constant([['mike'], ['mike'], ['arne']]) + seed_labels = tf.ragged.constant([[0.0], [1.0], [1.0]], dtype=tf.float32) + + inputs = { + tfgnn.SOURCE_NAME: source_seeds, + tfgnn.TARGET_NAME: target_seeds, + 'label': seed_labels, + } + pipeline = subgraph_pipeline.LinkSamplingPipeline( + graph_schema, sampling_spec, sampler + ) + graph_tensor = pipeline(inputs) + dummy_node_features = tf.ragged.constant( + [np.eye(3, 3), np.eye(4, 3), np.eye(4, 3)] + ) + graph_tensor = graph_tensor.replace_features( + node_sets={'authors': {tfgnn.HIDDEN_STATE: dummy_node_features}} + ) + + self.assertAllEqual( + graph_tensor.node_sets['_readout'].sizes, [[1], [1], [1]] + ) + self.assertAllEqual( + graph_tensor.edge_sets['reviews'].sizes, [[3], [2], [3]] + ) + self.assertAllEqual( + graph_tensor.node_sets['authors'].sizes, [[3], [4], [4]] + ) + self.assertAllEqual( + graph_tensor.node_sets['_readout'].features['label'], + [[0.0], [1.0], [1.0]], + ) + merged_tensor = graph_tensor.merge_batch_to_components() + self.assertAllEqual( + merged_tensor.edge_sets['_readout/source'].adjacency.source, [0, 3, 7] + ) + self.assertAllEqual( + tfgnn.structured_readout( + merged_tensor, 'source', feature_name=tfgnn.HIDDEN_STATE + ), + [[1.0, 0.0, 0.0]] * 3, + ) + self.assertAllEqual( + tfgnn.structured_readout( + merged_tensor, 'target', feature_name=tfgnn.HIDDEN_STATE + ), + [[0.0, 1.0, 0.0]] * 3, + ) + self.assertAllEqual( + merged_tensor.node_sets['_readout'].features['label'], [0.0, 1.0, 1.0] + ) + + if __name__ == '__main__': tf.test.main()