From b07efe7e4121b46272e2f865e3fd70d5af9daf8e Mon Sep 17 00:00:00 2001 From: Oleksandr Ferludin Date: Mon, 9 Oct 2023 06:51:44 -0700 Subject: [PATCH] Adds edge features + ragged features support to `SamplingPipeline`s. PiperOrigin-RevId: 571925262 --- .../experimental/sampler/beam/sampler.py | 49 ++++++++++----- .../sampler/beam/unigraph_utils.py | 60 +++++++++---------- .../experimental/sampler/subgraph_pipeline.py | 56 +++++++++-------- .../sampler/subgraph_pipeline_test.py | 43 +++++++++++++ 4 files changed, 137 insertions(+), 71 deletions(-) diff --git a/tensorflow_gnn/experimental/sampler/beam/sampler.py b/tensorflow_gnn/experimental/sampler/beam/sampler.py index 804dff97..3715351a 100644 --- a/tensorflow_gnn/experimental/sampler/beam/sampler.py +++ b/tensorflow_gnn/experimental/sampler/beam/sampler.py @@ -70,16 +70,46 @@ def get_sampling_model( a mapping from the layer's name to the corresponding edge set. """ layer_name_to_edge_set = {} + + def get_tensor_spec( + feature: graph_schema_pb2.Feature, + *, + outer_dims: tuple[Optional[int], ...] = (), + debug_context: str = '', + ) -> tf.TensorSpec: + try: + dtype = tf.dtypes.as_dtype(feature.dtype) + shape = _get_shape(feature) + shape = tf.TensorShape(outer_dims).concatenate(shape) + ragged_rank = sum([int(d is None) for d in shape[1:]]) + if ragged_rank == 0: + return tf.TensorSpec(shape, dtype) + else: + return tf.RaggedTensorSpec(shape, dtype, ragged_rank=ragged_rank) + except Exception as e: + raise ValueError(f'Invalid graph schema for {debug_context}') from e + def edge_sampler_factory( op: sampler_lib.SamplingOp, *, counter: dict[str, int], ) -> sampler.UniformEdgesSampler: + # pylint: disable=g-complex-comprehension + edge_features = graph_schema.edge_sets[op.edge_set_name].features accessor = sampler.KeyToTfExampleAccessor( - sampler.InMemStringKeyToBytesAccessor( - keys_to_values={'b': b'b'}), + sampler.InMemStringKeyToBytesAccessor(keys_to_values={'b': b'b'}), features_spec={ '#target': tf.TensorSpec([None], tf.string), + **{ + name: get_tensor_spec( + feature, + outer_dims=(None,), + debug_context=( + f'{op.edge_set_name} edge set, feature {name}' + ), + ) + for name, feature in edge_features.items() + }, }, ) edge_set_count = counter.setdefault(op.edge_set_name, 0) @@ -107,20 +137,11 @@ def node_features_accessor_factory( return None node_features = graph_schema.node_sets[node_set_name].features - def get_tensor_spec( - name: str, feature: graph_schema_pb2.Feature - ) -> tf.TensorSpec: - try: - shape = _get_shape(feature) - dtype = tf.dtypes.as_dtype(feature.dtype) - return tf.TensorSpec(shape, dtype) - except Exception as e: - raise ValueError( - f'Invalid graph schema for {node_set_name} node set, feature {name}' - ) from e features_spec = { - name: get_tensor_spec(name, feature) + name: get_tensor_spec( + feature, debug_context=f'{node_set_name} node set, feature {name}' + ) for name, feature in node_features.items() } diff --git a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py index 4b9000b3..4d9851b5 100644 --- a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py +++ b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py @@ -16,8 +16,7 @@ from __future__ import annotations -import functools -from typing import Dict, List, Optional, Tuple +from typing import Dict, List, Tuple import apache_beam as beam import numpy as np @@ -46,7 +45,7 @@ def read_seeds(root: beam.Pipeline, data_path: str) -> PCollection: Args: root: The root Beam Pipeline. data_path: The file path for the input node set. - + Returns: PCollection of sampler-compatible seeds. """ @@ -67,7 +66,7 @@ def _make_seed_feature( example: tf.train.Example with the seed features. feat_name: The feature to extract in this call - Returns: + Returns: bytes key and a list/np.array representation of a ragged tensor Raises: @@ -143,21 +142,23 @@ def expand(self, rcoll: PCollection) -> Dict[str, PCollection]: def seeds_from_graph_dict( - graph_dict: Dict[str, PCollection], - sampling_spec: sampler_lib.SamplingSpec) -> PCollection: + graph_dict: Dict[str, PCollection], sampling_spec: sampler_lib.SamplingSpec +) -> PCollection: """Emits sampler-compatible seeds from a collection of graph data and a sampling spec. - + Args: graph_dict: A dict of graph data represented as PCollections. sampling_spec: The sampling spec with the node set used for seeding. - + Returns: PCollection of sampler-compatible seeds. """ seed_nodes = graph_dict[f'nodes/{sampling_spec.seed_op.node_set_name}'] - return (seed_nodes - | 'SeedKeys' >> beam.Keys() - | 'MakeSeeds' >> beam.Map(_create_seeds)) + return ( + seed_nodes + | 'SeedKeys' >> beam.Keys() + | 'MakeSeeds' >> beam.Map(_create_seeds) + ) def _create_node_features( @@ -167,19 +168,18 @@ def _create_node_features( def _create_edge( - unused_source_id: bytes, - unused_target_id: bytes, + source_id: bytes, + target_id: bytes, example: tf.train.Example, - *, - edge_set_name: Optional[tfgnn.EdgeSetName] = None, ) -> tf.train.Example: """Creates input for edge set sampling stages.""" - for feature in (tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME): - if feature not in example.features.feature: - raise ValueError( - f'Required feature {feature} is not present for the edge set' - f' {edge_set_name}' - ) + for name, value in ( + (tfgnn.SOURCE_NAME, source_id), + (tfgnn.TARGET_NAME, target_id), + ): + example.features.feature[name].CopyFrom( + tf.train.Feature(bytes_list=tf.train.BytesList(value=[value])) + ) return example @@ -188,7 +188,7 @@ class ReadAndConvertUnigraph(beam.PTransform): def __init__(self, graph_schema: tfgnn.GraphSchema, data_path: str): """Constructor for ReadAndConvertUnigraph PTransform. - + Args: graph_schema: tfgnn.GraphSchema for the input graph. data_path: A file path for the graph data in accepted file formats. @@ -196,21 +196,17 @@ def __init__(self, graph_schema: tfgnn.GraphSchema, data_path: str): self._graph_schema = graph_schema self._data_path = data_path - def expand(self, - rcoll: PCollection - ) -> Dict[str, PCollection]: + def expand(self, rcoll: PCollection) -> Dict[str, PCollection]: graph_data = unigraph.read_graph(self._graph_schema, self._data_path, rcoll) result_dict = {} for node_set_name in graph_data['nodes'].keys(): - result_dict[f'nodes/{node_set_name}'] = ( - graph_data['nodes'][node_set_name] - | f'ExtractNodeFeatures/{node_set_name}' - >> beam.MapTuple(_create_node_features) + result_dict[f'nodes/{node_set_name}'] = graph_data['nodes'][ + node_set_name + ] | f'ExtractNodeFeatures/{node_set_name}' >> beam.MapTuple( + _create_node_features ) for edge_set_name in graph_data['edges'].keys(): result_dict[f'edges/{edge_set_name}'] = graph_data['edges'][ edge_set_name - ] | f'ExtractEdges/{edge_set_name}' >> beam.MapTuple( - functools.partial(_create_edge, edge_set_name=edge_set_name) - ) + ] | f'ExtractEdges/{edge_set_name}' >> beam.MapTuple(_create_edge) return result_dict diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py index 851af966..3edfb9d2 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline.py @@ -38,6 +38,7 @@ from __future__ import annotations import collections +import functools from typing import Callable, Mapping, Optional, Dict, Tuple import tensorflow as tf @@ -291,14 +292,19 @@ def __call__(self, seed_nodes: tf.RaggedTensor) -> tfgnn.GraphTensor: tfgnn.SOURCE_NAME: ragged_seed, tfgnn.TARGET_NAME: ragged_seed, } + graph_tensor = core.build_graph_tensor(edge_sets=edge_sets) # Remove fake edge set. - graph_tensor = tf.keras.layers.Lambda( - lambda g: tfgnn.GraphTensor.from_pieces( # pylint: disable=g-long-lambda - context=g.context, node_sets=g.node_sets, - edge_sets={n: e for n, e in g.edge_sets.items() - if n != fake_edge_set_name}))(graph_tensor) + graph_tensor = tfgnn.GraphTensor.from_pieces( + context=graph_tensor.context, + node_sets=graph_tensor.node_sets, + edge_sets={ + n: e + for n, e in graph_tensor.edge_sets.items() + if n != fake_edge_set_name + }, + ) features = {} for node_set_name, node_set in graph_tensor.node_sets.items(): @@ -495,40 +501,40 @@ def sample_edge_sets( """ assert_sorted_sampling_spec(sampling_spec) - edge_set_sources = collections.defaultdict(list) - edge_set_targets = collections.defaultdict(list) + features_by_edge_set = collections.defaultdict( + lambda: collections.defaultdict(list) + ) seed_op = ( sampling_spec.seed_op if sampling_spec.HasField('seed_op') else sampling_spec.symmetric_link_seed_op ) - tensors_by_op_name = { + seeds_by_op_name = { seed_op.op_name: seed_node_ids } + # Concatenates inputs `[batch, item, ...]`` along the item dimension. + concat_fn = functools.partial(tf.concat, axis=1) for sampling_op in sampling_spec.sampling_ops: - input_tensors = tf.concat( - [tensors_by_op_name[op_name] for op_name in sampling_op.input_op_names], - axis=-1) + input_tensors = concat_fn( + [seeds_by_op_name[op_name] for op_name in sampling_op.input_op_names] + ) edge_sampler = edge_sampler_factory(sampling_op) sampled_edges = edge_sampler(input_tensors) + features = features_by_edge_set[sampling_op.edge_set_name] + for feature_name, feature_value in sampled_edges.items(): + features[feature_name].append(feature_value) - edge_set_sources[sampling_op.edge_set_name].append( - sampled_edges[tfgnn.SOURCE_NAME]) - edge_set_targets[sampling_op.edge_set_name].append( - sampled_edges[tfgnn.TARGET_NAME]) - tensors_by_op_name[sampling_op.op_name] = sampled_edges[tfgnn.TARGET_NAME] + seeds_by_op_name[sampling_op.op_name] = sampled_edges[tfgnn.TARGET_NAME] edge_sets = {} - for edge_set_name, source_list in edge_set_sources.items(): - target_list = edge_set_targets[edge_set_name] - edge_set_key = ','.join((graph_schema.edge_sets[edge_set_name].source, - edge_set_name, - graph_schema.edge_sets[edge_set_name].target)) - edge_sets[edge_set_key] = { - tfgnn.SOURCE_NAME: tf.concat(source_list, axis=-1), - tfgnn.TARGET_NAME: tf.concat(target_list, axis=-1), - } + for edge_set_name, features in features_by_edge_set.items(): + edge_set_key = ','.join(( + graph_schema.edge_sets[edge_set_name].source, + edge_set_name, + graph_schema.edge_sets[edge_set_name].target, + )) + edge_sets[edge_set_key] = {k: concat_fn(v) for k, v in features.items()} return edge_sets diff --git a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py index d16dfcfe..4dd36e56 100644 --- a/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py +++ b/tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py @@ -22,6 +22,7 @@ import tensorflow as tf import tensorflow_gnn as tfgnn +from tensorflow_gnn.experimental.sampler import core from tensorflow_gnn.experimental.sampler import interfaces from tensorflow_gnn.experimental.sampler import subgraph_pipeline from tensorflow_gnn.sampler import sampling_spec_pb2 @@ -327,6 +328,48 @@ def test_sampling_pipeline(self): self.assertIn(food.decode(), eats_edges[animal.decode()]) +class EdgeFeaturesTest(tf.test.TestCase): + + def test_homogeneous(self): + graph_schema = tfgnn.GraphSchema() + graph_schema.node_sets['a'].description = 'test node set' + graph_schema.edge_sets['a->a'].source = 'a' + graph_schema.edge_sets['a->a'].target = 'a' + graph_schema.edge_sets['a->a'].features['f'].dtype = 1 + + sampling_spec = sampling_spec_pb2.SamplingSpec() + sampling_spec.seed_op.op_name = 'seed' + sampling_spec.seed_op.node_set_name = 'a' + sampling_spec.sampling_ops.add( + op_name='hop1', edge_set_name='a->a', sample_size=100 + ).input_op_names.append('seed') + + def edge_sampler_factory(sampling_op): + self.assertEqual(sampling_op.edge_set_name, 'a->a') + return core.InMemUniformEdgesSampler( + num_source_nodes=3, + source=tf.constant([2, 0], tf.int32), + target=tf.constant([0, 1], tf.int32), + edge_features={'f': [2.0, 0.0]}, + seed=42, + sample_size=sampling_op.sample_size, + name=sampling_op.edge_set_name, + ) + + sampling_model = subgraph_pipeline.create_sampling_model_from_spec( + graph_schema, + sampling_spec, + edge_sampler_factory, + seed_node_dtype=tf.int32, + ) + result = sampling_model(tf.ragged.constant([[2], [0]])) + self.assertIn('a', result.node_sets) + self.assertIn('a->a', result.edge_sets) + edge_features = result.edge_sets['a->a'].get_features_dict() + self.assertIn('f', edge_features) + self.assertAllEqual(edge_features['f'], tf.ragged.constant([[2.0], [0.0]])) + + def _get_test_link_edges_sampler_schema_spec(): reviews_edges = { 'mike': ['alexa'],