Skip to content

Commit

Permalink
Adds edge features + ragged features support to SamplingPipelines.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 571925262
  • Loading branch information
aferludin authored and tensorflower-gardener committed Oct 9, 2023
1 parent ca8974d commit b07efe7
Show file tree
Hide file tree
Showing 4 changed files with 137 additions and 71 deletions.
49 changes: 35 additions & 14 deletions tensorflow_gnn/experimental/sampler/beam/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,16 +70,46 @@ def get_sampling_model(
a mapping from the layer's name to the corresponding edge set.
"""
layer_name_to_edge_set = {}

def get_tensor_spec(
feature: graph_schema_pb2.Feature,
*,
outer_dims: tuple[Optional[int], ...] = (),
debug_context: str = '',
) -> tf.TensorSpec:
try:
dtype = tf.dtypes.as_dtype(feature.dtype)
shape = _get_shape(feature)
shape = tf.TensorShape(outer_dims).concatenate(shape)
ragged_rank = sum([int(d is None) for d in shape[1:]])
if ragged_rank == 0:
return tf.TensorSpec(shape, dtype)
else:
return tf.RaggedTensorSpec(shape, dtype, ragged_rank=ragged_rank)
except Exception as e:
raise ValueError(f'Invalid graph schema for {debug_context}') from e

def edge_sampler_factory(
op: sampler_lib.SamplingOp,
*,
counter: dict[str, int],
) -> sampler.UniformEdgesSampler:
# pylint: disable=g-complex-comprehension
edge_features = graph_schema.edge_sets[op.edge_set_name].features
accessor = sampler.KeyToTfExampleAccessor(
sampler.InMemStringKeyToBytesAccessor(
keys_to_values={'b': b'b'}),
sampler.InMemStringKeyToBytesAccessor(keys_to_values={'b': b'b'}),
features_spec={
'#target': tf.TensorSpec([None], tf.string),
**{
name: get_tensor_spec(
feature,
outer_dims=(None,),
debug_context=(
f'{op.edge_set_name} edge set, feature {name}'
),
)
for name, feature in edge_features.items()
},
},
)
edge_set_count = counter.setdefault(op.edge_set_name, 0)
Expand Down Expand Up @@ -107,20 +137,11 @@ def node_features_accessor_factory(
return None

node_features = graph_schema.node_sets[node_set_name].features
def get_tensor_spec(
name: str, feature: graph_schema_pb2.Feature
) -> tf.TensorSpec:
try:
shape = _get_shape(feature)
dtype = tf.dtypes.as_dtype(feature.dtype)
return tf.TensorSpec(shape, dtype)
except Exception as e:
raise ValueError(
f'Invalid graph schema for {node_set_name} node set, feature {name}'
) from e

features_spec = {
name: get_tensor_spec(name, feature)
name: get_tensor_spec(
feature, debug_context=f'{node_set_name} node set, feature {name}'
)
for name, feature in node_features.items()
}

Expand Down
60 changes: 28 additions & 32 deletions tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,7 @@

from __future__ import annotations

import functools
from typing import Dict, List, Optional, Tuple
from typing import Dict, List, Tuple

import apache_beam as beam
import numpy as np
Expand Down Expand Up @@ -46,7 +45,7 @@ def read_seeds(root: beam.Pipeline, data_path: str) -> PCollection:
Args:
root: The root Beam Pipeline.
data_path: The file path for the input node set.
Returns:
PCollection of sampler-compatible seeds.
"""
Expand All @@ -67,7 +66,7 @@ def _make_seed_feature(
example: tf.train.Example with the seed features.
feat_name: The feature to extract in this call
Returns:
Returns:
bytes key and a list/np.array representation of a ragged tensor
Raises:
Expand Down Expand Up @@ -143,21 +142,23 @@ def expand(self, rcoll: PCollection) -> Dict[str, PCollection]:


def seeds_from_graph_dict(
graph_dict: Dict[str, PCollection],
sampling_spec: sampler_lib.SamplingSpec) -> PCollection:
graph_dict: Dict[str, PCollection], sampling_spec: sampler_lib.SamplingSpec
) -> PCollection:
"""Emits sampler-compatible seeds from a collection of graph data and a sampling spec.
Args:
graph_dict: A dict of graph data represented as PCollections.
sampling_spec: The sampling spec with the node set used for seeding.
Returns:
PCollection of sampler-compatible seeds.
"""
seed_nodes = graph_dict[f'nodes/{sampling_spec.seed_op.node_set_name}']
return (seed_nodes
| 'SeedKeys' >> beam.Keys()
| 'MakeSeeds' >> beam.Map(_create_seeds))
return (
seed_nodes
| 'SeedKeys' >> beam.Keys()
| 'MakeSeeds' >> beam.Map(_create_seeds)
)


def _create_node_features(
Expand All @@ -167,19 +168,18 @@ def _create_node_features(


def _create_edge(
unused_source_id: bytes,
unused_target_id: bytes,
source_id: bytes,
target_id: bytes,
example: tf.train.Example,
*,
edge_set_name: Optional[tfgnn.EdgeSetName] = None,
) -> tf.train.Example:
"""Creates input for edge set sampling stages."""
for feature in (tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME):
if feature not in example.features.feature:
raise ValueError(
f'Required feature {feature} is not present for the edge set'
f' {edge_set_name}'
)
for name, value in (
(tfgnn.SOURCE_NAME, source_id),
(tfgnn.TARGET_NAME, target_id),
):
example.features.feature[name].CopyFrom(
tf.train.Feature(bytes_list=tf.train.BytesList(value=[value]))
)
return example


Expand All @@ -188,29 +188,25 @@ class ReadAndConvertUnigraph(beam.PTransform):

def __init__(self, graph_schema: tfgnn.GraphSchema, data_path: str):
"""Constructor for ReadAndConvertUnigraph PTransform.
Args:
graph_schema: tfgnn.GraphSchema for the input graph.
data_path: A file path for the graph data in accepted file formats.
"""
self._graph_schema = graph_schema
self._data_path = data_path

def expand(self,
rcoll: PCollection
) -> Dict[str, PCollection]:
def expand(self, rcoll: PCollection) -> Dict[str, PCollection]:
graph_data = unigraph.read_graph(self._graph_schema, self._data_path, rcoll)
result_dict = {}
for node_set_name in graph_data['nodes'].keys():
result_dict[f'nodes/{node_set_name}'] = (
graph_data['nodes'][node_set_name]
| f'ExtractNodeFeatures/{node_set_name}'
>> beam.MapTuple(_create_node_features)
result_dict[f'nodes/{node_set_name}'] = graph_data['nodes'][
node_set_name
] | f'ExtractNodeFeatures/{node_set_name}' >> beam.MapTuple(
_create_node_features
)
for edge_set_name in graph_data['edges'].keys():
result_dict[f'edges/{edge_set_name}'] = graph_data['edges'][
edge_set_name
] | f'ExtractEdges/{edge_set_name}' >> beam.MapTuple(
functools.partial(_create_edge, edge_set_name=edge_set_name)
)
] | f'ExtractEdges/{edge_set_name}' >> beam.MapTuple(_create_edge)
return result_dict
56 changes: 31 additions & 25 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@

from __future__ import annotations
import collections
import functools

from typing import Callable, Mapping, Optional, Dict, Tuple
import tensorflow as tf
Expand Down Expand Up @@ -291,14 +292,19 @@ def __call__(self, seed_nodes: tf.RaggedTensor) -> tfgnn.GraphTensor:
tfgnn.SOURCE_NAME: ragged_seed,
tfgnn.TARGET_NAME: ragged_seed,
}

graph_tensor = core.build_graph_tensor(edge_sets=edge_sets)

# Remove fake edge set.
graph_tensor = tf.keras.layers.Lambda(
lambda g: tfgnn.GraphTensor.from_pieces( # pylint: disable=g-long-lambda
context=g.context, node_sets=g.node_sets,
edge_sets={n: e for n, e in g.edge_sets.items()
if n != fake_edge_set_name}))(graph_tensor)
graph_tensor = tfgnn.GraphTensor.from_pieces(
context=graph_tensor.context,
node_sets=graph_tensor.node_sets,
edge_sets={
n: e
for n, e in graph_tensor.edge_sets.items()
if n != fake_edge_set_name
},
)

features = {}
for node_set_name, node_set in graph_tensor.node_sets.items():
Expand Down Expand Up @@ -495,40 +501,40 @@ def sample_edge_sets(
"""
assert_sorted_sampling_spec(sampling_spec)

edge_set_sources = collections.defaultdict(list)
edge_set_targets = collections.defaultdict(list)
features_by_edge_set = collections.defaultdict(
lambda: collections.defaultdict(list)
)
seed_op = (
sampling_spec.seed_op
if sampling_spec.HasField('seed_op')
else sampling_spec.symmetric_link_seed_op
)
tensors_by_op_name = {
seeds_by_op_name = {
seed_op.op_name: seed_node_ids
}

# Concatenates inputs `[batch, item, ...]`` along the item dimension.
concat_fn = functools.partial(tf.concat, axis=1)
for sampling_op in sampling_spec.sampling_ops:
input_tensors = tf.concat(
[tensors_by_op_name[op_name] for op_name in sampling_op.input_op_names],
axis=-1)
input_tensors = concat_fn(
[seeds_by_op_name[op_name] for op_name in sampling_op.input_op_names]
)
edge_sampler = edge_sampler_factory(sampling_op)
sampled_edges = edge_sampler(input_tensors)
features = features_by_edge_set[sampling_op.edge_set_name]
for feature_name, feature_value in sampled_edges.items():
features[feature_name].append(feature_value)

edge_set_sources[sampling_op.edge_set_name].append(
sampled_edges[tfgnn.SOURCE_NAME])
edge_set_targets[sampling_op.edge_set_name].append(
sampled_edges[tfgnn.TARGET_NAME])
tensors_by_op_name[sampling_op.op_name] = sampled_edges[tfgnn.TARGET_NAME]
seeds_by_op_name[sampling_op.op_name] = sampled_edges[tfgnn.TARGET_NAME]

edge_sets = {}
for edge_set_name, source_list in edge_set_sources.items():
target_list = edge_set_targets[edge_set_name]
edge_set_key = ','.join((graph_schema.edge_sets[edge_set_name].source,
edge_set_name,
graph_schema.edge_sets[edge_set_name].target))
edge_sets[edge_set_key] = {
tfgnn.SOURCE_NAME: tf.concat(source_list, axis=-1),
tfgnn.TARGET_NAME: tf.concat(target_list, axis=-1),
}
for edge_set_name, features in features_by_edge_set.items():
edge_set_key = ','.join((
graph_schema.edge_sets[edge_set_name].source,
edge_set_name,
graph_schema.edge_sets[edge_set_name].target,
))
edge_sets[edge_set_key] = {k: concat_fn(v) for k, v in features.items()}

return edge_sets

Expand Down
43 changes: 43 additions & 0 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import tensorflow as tf

import tensorflow_gnn as tfgnn
from tensorflow_gnn.experimental.sampler import core
from tensorflow_gnn.experimental.sampler import interfaces
from tensorflow_gnn.experimental.sampler import subgraph_pipeline
from tensorflow_gnn.sampler import sampling_spec_pb2
Expand Down Expand Up @@ -327,6 +328,48 @@ def test_sampling_pipeline(self):
self.assertIn(food.decode(), eats_edges[animal.decode()])


class EdgeFeaturesTest(tf.test.TestCase):

def test_homogeneous(self):
graph_schema = tfgnn.GraphSchema()
graph_schema.node_sets['a'].description = 'test node set'
graph_schema.edge_sets['a->a'].source = 'a'
graph_schema.edge_sets['a->a'].target = 'a'
graph_schema.edge_sets['a->a'].features['f'].dtype = 1

sampling_spec = sampling_spec_pb2.SamplingSpec()
sampling_spec.seed_op.op_name = 'seed'
sampling_spec.seed_op.node_set_name = 'a'
sampling_spec.sampling_ops.add(
op_name='hop1', edge_set_name='a->a', sample_size=100
).input_op_names.append('seed')

def edge_sampler_factory(sampling_op):
self.assertEqual(sampling_op.edge_set_name, 'a->a')
return core.InMemUniformEdgesSampler(
num_source_nodes=3,
source=tf.constant([2, 0], tf.int32),
target=tf.constant([0, 1], tf.int32),
edge_features={'f': [2.0, 0.0]},
seed=42,
sample_size=sampling_op.sample_size,
name=sampling_op.edge_set_name,
)

sampling_model = subgraph_pipeline.create_sampling_model_from_spec(
graph_schema,
sampling_spec,
edge_sampler_factory,
seed_node_dtype=tf.int32,
)
result = sampling_model(tf.ragged.constant([[2], [0]]))
self.assertIn('a', result.node_sets)
self.assertIn('a->a', result.edge_sets)
edge_features = result.edge_sets['a->a'].get_features_dict()
self.assertIn('f', edge_features)
self.assertAllEqual(edge_features['f'], tf.ragged.constant([[2.0], [0.0]]))


def _get_test_link_edges_sampler_schema_spec():
reviews_edges = {
'mike': ['alexa'],
Expand Down

0 comments on commit b07efe7

Please sign in to comment.