From 3375033b302869e61728c7e7cbaedb68c14ac581 Mon Sep 17 00:00:00 2001 From: Oleksandr Ferludin Date: Wed, 27 Sep 2023 02:34:55 -0700 Subject: [PATCH] Small bug fixes and link prediction code refactoring. PiperOrigin-RevId: 568789927 --- .../experimental/sampler/beam/sampler.py | 23 ++- .../experimental/sampler/subgraph_pipeline.py | 150 ++++++++++-------- 2 files changed, 103 insertions(+), 70 deletions(-) diff --git a/tensorflow_gnn/experimental/sampler/beam/sampler.py b/tensorflow_gnn/experimental/sampler/beam/sampler.py index d24e993f..9bf20936 100644 --- a/tensorflow_gnn/experimental/sampler/beam/sampler.py +++ b/tensorflow_gnn/experimental/sampler/beam/sampler.py @@ -105,12 +105,25 @@ def node_features_accessor_factory( ) -> sampler.KeyToTfExampleAccessor: if not graph_schema.node_sets[node_set_name].features: return None + node_features = graph_schema.node_sets[node_set_name].features - features_spec = {} - for name, feature in node_features.items(): - shape = _get_shape(feature) - dtype = tf.dtypes.as_dtype(feature.dtype) - features_spec[name] = tf.TensorSpec(shape, dtype) + def get_tensor_spec( + name: str, feature: graph_schema_pb2.Feature + ) -> tf.TensorSpec: + try: + shape = _get_shape(feature) + dtype = tf.dtypes.as_dtype(feature.dtype) + return tf.TensorSpec(shape, dtype) + except Exception as e: + raise ValueError( + f'Invalid graph schema for {node_set_name} node set, feature {name}' + ) from e + + features_spec = { + name: get_tensor_spec(name, feature) + for name, feature in node_features.items() + } + accessor = sampler.KeyToTfExampleAccessor( sampler.InMemStringKeyToBytesAccessor( keys_to_values={'b': b'b'}, name=f'nodes/{node_set_name}' diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py index b5f6c569..e6c4d575 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py @@ -39,7 +39,7 @@ from __future__ import annotations import collections -from typing import Callable, Mapping, Optional, Dict +from typing import Callable, Mapping, Optional, Dict, Tuple import tensorflow as tf import tensorflow_gnn as tfgnn from tensorflow_gnn.experimental.sampler import core @@ -126,7 +126,6 @@ def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema): ' they should be called `_readout/source` and `_readout/target`.' f' Instead got {aux_edge_sets}' ) - aux_node_set_name = aux_node_sets[0] if ( graph_schema.edge_sets[aux_edge_sets[0]].source != graph_schema.edge_sets[aux_edge_sets[1]].source @@ -140,17 +139,6 @@ def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema): f' {graph_schema.edge_sets[aux_edge_sets[1]].source}' ) - if not ( - tfgnn.SOURCE_NAME - in graph_schema.node_sets[aux_node_set_name].features.keys() - and tfgnn.TARGET_NAME - in graph_schema.node_sets[aux_node_set_name].features.keys() - ): - raise ValueError( - 'Link sampling requires a #source and a #target feature in the' - ' _readout node set' - ) - def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema): """Populates inputs for sampling model from a graph schema. @@ -170,7 +158,6 @@ def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema): feature_names = [ key for key in node_set_spec.features.keys() - if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME] ] feature_inputs = {} for feature_name in feature_names: @@ -237,6 +224,7 @@ def __init__( edge_sampler_factory: EdgeSamplerFactory, node_features_accessor_factory: Optional[ NodeFeaturesLookupFactory] = None, + seed_node_set_name: Optional[tfgnn.NodeSetName] = None, ): assert_sorted_sampling_spec(sampling_spec) self._graph_schema = graph_schema @@ -247,6 +235,14 @@ def __init__( tfgnn.NodeSetName, Optional[interfaces.KeyToFeaturesAccessor] ] = {} + if seed_node_set_name is None: + seed_node_set_name = self._sampling_spec.seed_op.node_set_name + + if not seed_node_set_name: + raise ValueError('Seed node set name is not specified.') + + self._seed_node_set_name = seed_node_set_name + def get_node_features_accessor( self, node_set_name: tfgnn.NodeSetName ) -> Optional[interfaces.KeyToFeaturesAccessor]: @@ -260,7 +256,7 @@ def get_node_features_accessor( @property def seed_node_set_name(self) -> tfgnn.NodeSetName: - return self._sampling_spec.seed_op.node_set_name + return self._seed_node_set_name def __call__(self, seed_nodes: tf.RaggedTensor) -> tfgnn.GraphTensor: """Returns subgraph (as `GraphTensor`) sampled around `seed_nodes`. @@ -335,41 +331,78 @@ def __init__( NodeFeaturesLookupFactory ] = None, ): + self._readout_node_set = '_readout' + assert self._readout_node_set in graph_schema.node_sets, graph_schema + assert ( + graph_schema.edge_sets[f'{self._readout_node_set}/source'].target + == graph_schema.edge_sets[f'{self._readout_node_set}/target'].target + == self._readout_node_set + ), graph_schema + assert ( + graph_schema.edge_sets[f'{self._readout_node_set}/source'].source + == graph_schema.edge_sets[f'{self._readout_node_set}/target'].source + ), graph_schema + + self._seed_node_set = graph_schema.edge_sets[ + f'{self._readout_node_set}/source' + ].source + self._sampling_pipeline = SamplingPipeline( graph_schema=graph_schema, sampling_spec=sampling_spec, edge_sampler_factory=edge_sampler_factory, node_features_accessor_factory=node_features_accessor_factory, + seed_node_set_name=self._seed_node_set, ) - self._target_node_set_name = graph_schema.edge_sets[ - '_readout/source' - ].source - self._aux_node_set_name = [ - node_set - for node_set in graph_schema.node_sets.keys() - if tfgnn.get_aux_type_prefix(node_set) is not None - ][0] def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor: seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1) subgraph = self._sampling_pipeline(seed_nodes) - return AddReadoutStruct()( - subgraph, inputs, self._target_node_set_name, self._aux_node_set_name - ) + return AddLinkReadoutStruct( + readout_node_set=self._readout_node_set, + seed_node_set=self._seed_node_set, + )((subgraph, inputs)) + +class AddLinkReadoutStruct(tf.keras.layers.Layer): + """Helper class for adding readout structure for the link prediction task. -class AddReadoutStruct(tf.keras.layers.Layer): - """Helper class for adding readout node and edge sets to a GraphTensor.""" + It is assumed that link sampling is done using `SamplingPipeline` starting + from two seed nodes that belong to the same node set (`seed_node_set`). In the + current implementation those two seed become node 0 and node 1 in each sampled + subgraph. The readout structure is created by creating `_readout` node set and + adding single readout node to each graph with its features (passed as 2nd + tuple value in the call method). The seed nodes are connected to readout nodes + by adding auxiliary readout edges. + + NOTE: following the TF-GNN implementation, we direct readout edges from + seed nodes (source) to readout nodes (target). + """ + + def __init__( + self, readout_node_set: tfgnn.NodeSet, seed_node_set: tfgnn.NodeSetName + ): + super().__init__() + assert readout_node_set + assert seed_node_set + self._readout_node_set = readout_node_set + self._seed_node_set = seed_node_set + + def get_config(self): + return { + 'readout_node_set': self._readout_node_set, + 'seed_node_set': self._seed_node_set, + **super().get_config(), + } def call( self, - graph_tensor: tfgnn.GraphTensor, - readout_features: Mapping[str, tf.RaggedTensor], - target_node_set_name: tfgnn.NodeSetName, - aux_node_set_name: tfgnn.NodeSetName, + inputs: Tuple[tfgnn.GraphTensor, Mapping[str, tf.RaggedTensor]], ) -> tfgnn.GraphTensor: - target_node_set = graph_tensor.node_sets[target_node_set_name] - batch_size = tf.shape(target_node_set.sizes)[0] + graph_tensor, readout_features = inputs + + seed_node_set = graph_tensor.node_sets[self._seed_node_set] + batch_size = tf.shape(seed_node_set.sizes)[0] row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype) sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype) @@ -378,48 +411,35 @@ def get_readout_index(value): tf.fill( [batch_size], tf.constant(value, dtype=graph_tensor.indices_dtype), - row_lengths, ), - row_lengths, + row_lengths=row_lengths, ) - readout_source_source = get_readout_index(0) - readout_source_target = get_readout_index(0) - readout_target_source = get_readout_index(1) - readout_target_target = get_readout_index(0) - readout_source_adjacency = tfgnn.Adjacency.from_indices( - (target_node_set_name, readout_source_source), - (aux_node_set_name, readout_source_target), - ) - readout_target_adjacency = tfgnn.Adjacency.from_indices( - (target_node_set_name, readout_target_source), - (aux_node_set_name, readout_target_target), - ) readout_node_sets = { - aux_node_set_name: tfgnn.NodeSet.from_fields( + self._readout_node_set: tfgnn.NodeSet.from_fields( features=readout_features, sizes=sizes ) } readout_edge_sets = { - f'{aux_node_set_name}/source': tfgnn.EdgeSet.from_fields( - sizes=sizes, adjacency=readout_source_adjacency + f'{self._readout_node_set}/source': tfgnn.EdgeSet.from_fields( + sizes=sizes, + adjacency=tfgnn.Adjacency.from_indices( + (self._seed_node_set, get_readout_index(0)), + (self._readout_node_set, get_readout_index(0)), + ), ), - f'{aux_node_set_name}/target': tfgnn.EdgeSet.from_fields( - sizes=sizes, adjacency=readout_target_adjacency + f'{self._readout_node_set}/target': tfgnn.EdgeSet.from_fields( + sizes=sizes, + adjacency=tfgnn.Adjacency.from_indices( + (self._seed_node_set, get_readout_index(1)), + (self._readout_node_set, get_readout_index(0)), + ), ), } - node_sets = { - node_set_name: node_set - for node_set_name, node_set in graph_tensor.node_sets.items() - } - edge_sets = { - edge_set_name: edge_set - for edge_set_name, edge_set in graph_tensor.edge_sets.items() - } - node_sets.update(readout_node_sets) - edge_sets.update(readout_edge_sets) - return graph_tensor.from_pieces( - node_sets=node_sets, edge_sets=edge_sets, context=graph_tensor.context + return tfgnn.GraphTensor.from_pieces( + node_sets={**graph_tensor.node_sets, **readout_node_sets}, + edge_sets={**graph_tensor.edge_sets, **readout_edge_sets}, + context=graph_tensor.context, )