Skip to content

Commit

Permalink
Creates link sampling pipeline for edge prediction.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565069852
  • Loading branch information
mihirparadkar authored and tensorflower-gardener committed Sep 21, 2023
1 parent 3295009 commit 66d1582
Show file tree
Hide file tree
Showing 2 changed files with 331 additions and 6 deletions.
248 changes: 242 additions & 6 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,137 @@
Optional[interfaces.KeyToFeaturesAccessor]]


def create_link_sampling_model_from_spec(
graph_schema: tfgnn.GraphSchema,
sampling_spec: sampling_spec_pb2.SamplingSpec,
edge_sampler_factory: EdgeSamplerFactory,
node_features_accessor_factory: Optional[NodeFeaturesLookupFactory] = None,
seed_node_dtype=tf.string,
) -> tf.keras.Model:
"""Creates Keras model that accepts link sampling features and seeds to output GraphTensor."""
_validate_link_sampling_graph_schema(graph_schema)
inputs = {
tfgnn.SOURCE_NAME: tf.keras.Input(
type_spec=tf.RaggedTensorSpec(
shape=[None, None], ragged_rank=1, dtype=seed_node_dtype
),
name='SeedSource',
),
tfgnn.TARGET_NAME: tf.keras.Input(
type_spec=tf.RaggedTensorSpec(
shape=[None, None], ragged_rank=1, dtype=seed_node_dtype
),
name='SeedTarget',
),
}
inputs.update(_readout_inputs_from_schema(graph_schema))
pipeline = LinkSamplingPipeline(
graph_schema=graph_schema, sampling_spec=sampling_spec,
edge_sampler_factory=edge_sampler_factory,
node_features_accessor_factory=node_features_accessor_factory,
)
subgraph = pipeline(inputs)
return tf.keras.Model(inputs=inputs, outputs=subgraph)


def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema):
"""Validates that the given graph_schema has only one auxiliary node set and its related edge sets.
Args:
graph_schema: tfgnn.GraphSchema object to validate for link prediction.
"""
aux_node_sets = [
node_set
for node_set in graph_schema.node_sets.keys()
if tfgnn.get_aux_type_prefix(node_set) is not None
]
if len(aux_node_sets) != 1 or '_readout' not in aux_node_sets:
raise ValueError(
'There should be exactly one auxiliary node set for link sampling and'
f' it should be called `_readout`. Instead got {aux_node_sets}'
)

aux_edge_sets = [
edge_set
for edge_set in graph_schema.edge_sets.keys()
if tfgnn.get_aux_type_prefix(edge_set) is not None
]
if (
len(aux_edge_sets) != 2
or '_readout/source' not in aux_edge_sets
or '_readout/target' not in aux_edge_sets
):
raise ValueError(
'There should be exactly two auxiliary edge sets for link sampling and'
' they should be called `_readout/source` and `_readout/target`.'
f' Instead got {aux_edge_sets}'
)
aux_node_set_name = aux_node_sets[0]
if (
graph_schema.edge_sets[aux_edge_sets[0]].source
!= graph_schema.edge_sets[aux_edge_sets[1]].source
):
raise ValueError(
'`_readout/source` and `_readout/target` edge sets should point from a'
' common source node set to `_readout` as a target. Instead'
' `_readout/source` points from'
f' {graph_schema.edge_sets[aux_edge_sets[0]].source} while'
' `_readout/target` points from'
f' {graph_schema.edge_sets[aux_edge_sets[1]].source}'
)

if not (
tfgnn.SOURCE_NAME
in graph_schema.node_sets[aux_node_set_name].features.keys()
and tfgnn.TARGET_NAME
in graph_schema.node_sets[aux_node_set_name].features.keys()
):
raise ValueError(
'Link sampling requires a #source and a #target feature in the'
' _readout node set'
)


def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema):
"""Populates inputs for sampling model from a graph schema.
Args:
graph_schema: tfgnn.GraphSchema for generating the input specs.
Returns:
dict from feature names to input specs
"""
aux_node_set_name = [
node_set
for node_set in graph_schema.node_sets.keys()
if tfgnn.get_aux_type_prefix(node_set) is not None
][0]
node_set_spec = graph_schema.node_sets[aux_node_set_name]
feature_names = [
key
for key in node_set_spec.features.keys()
if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME]
]
feature_inputs = {}
for feature_name in feature_names:
feature_inputs[feature_name] = tf.keras.Input(
type_spec=tf.RaggedTensorSpec(
shape=[None, None],
ragged_rank=1,
dtype=node_set_spec.features[feature_name].dtype,
),
name=f'Seed/{feature_name}',
)
return feature_inputs


def create_sampling_model_from_spec(
graph_schema: tfgnn.GraphSchema,
sampling_spec: sampling_spec_pb2.SamplingSpec,
edge_sampler_factory: EdgeSamplerFactory,
node_features_accessor_factory: Optional[NodeFeaturesLookupFactory] = None,
seed_node_dtype=tf.string) -> tf.keras.Model:
seed_node_dtype=tf.string,
) -> tf.keras.Model:
"""Creates Keras model that accepts seed node IDs to output GraphTensor.
Args:
Expand Down Expand Up @@ -197,6 +322,106 @@ def node_features_accessor_factory(
return self._node_features_accessor_factory


class LinkSamplingPipeline:
"""Callable that makes GraphTensor out of seed node IDs and labels."""

def __init__(
self,
graph_schema: tfgnn.GraphSchema,
sampling_spec: sampling_spec_pb2.SamplingSpec,
edge_sampler_factory: EdgeSamplerFactory,
node_features_accessor_factory: Optional[
NodeFeaturesLookupFactory
] = None,
):
self._sampling_pipeline = SamplingPipeline(
graph_schema=graph_schema,
sampling_spec=sampling_spec,
edge_sampler_factory=edge_sampler_factory,
node_features_accessor_factory=node_features_accessor_factory,
)
self._target_node_set_name = graph_schema.edge_sets[
'_readout/source'
].source
self._aux_node_set_name = [
node_set
for node_set in graph_schema.node_sets.keys()
if tfgnn.get_aux_type_prefix(node_set) is not None
][0]

def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor:
seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1)
subgraph = self._sampling_pipeline(seed_nodes)
return AddReadoutStruct()(
subgraph, inputs, self._target_node_set_name, self._aux_node_set_name
)


class AddReadoutStruct(tf.keras.layers.Layer):
"""Helper class for adding readout node and edge sets to a GraphTensor."""

def call(
self,
graph_tensor: tfgnn.GraphTensor,
readout_features: Mapping[str, tf.RaggedTensor],
target_node_set_name: tfgnn.NodeSetName,
aux_node_set_name: tfgnn.NodeSetName,
) -> tfgnn.GraphTensor:
target_node_set = graph_tensor.node_sets[target_node_set_name]
batch_size = tf.shape(target_node_set.sizes)[0]
row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype)
sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype)

def get_readout_index(value):
return tf.RaggedTensor.from_row_lengths(
tf.fill(
[batch_size],
tf.constant(value, dtype=graph_tensor.indices_dtype),
row_lengths,
),
row_lengths,
)

readout_source_source = get_readout_index(0)
readout_source_target = get_readout_index(0)
readout_target_source = get_readout_index(1)
readout_target_target = get_readout_index(0)
readout_source_adjacency = tfgnn.Adjacency.from_indices(
(target_node_set_name, readout_source_source),
(aux_node_set_name, readout_source_target),
)
readout_target_adjacency = tfgnn.Adjacency.from_indices(
(target_node_set_name, readout_target_source),
(aux_node_set_name, readout_target_target),
)
readout_node_sets = {
aux_node_set_name: tfgnn.NodeSet.from_fields(
features=readout_features, sizes=sizes
)
}
readout_edge_sets = {
f'{aux_node_set_name}/source': tfgnn.EdgeSet.from_fields(
sizes=sizes, adjacency=readout_source_adjacency
),
f'{aux_node_set_name}/target': tfgnn.EdgeSet.from_fields(
sizes=sizes, adjacency=readout_target_adjacency
),
}
node_sets = {
node_set_name: node_set
for node_set_name, node_set in graph_tensor.node_sets.items()
}
edge_sets = {
edge_set_name: edge_set
for edge_set_name, edge_set in graph_tensor.edge_sets.items()
}
node_sets.update(readout_node_sets)
edge_sets.update(readout_edge_sets)
return graph_tensor.from_pieces(
node_sets=node_sets, edge_sets=edge_sets, context=graph_tensor.context
)


def sample_edge_sets(
seed_node_ids: tf.RaggedTensor,
graph_schema: tfgnn.GraphSchema,
Expand Down Expand Up @@ -226,9 +451,13 @@ def sample_edge_sets(

edge_set_sources = collections.defaultdict(list)
edge_set_targets = collections.defaultdict(list)

seed_op = (
sampling_spec.seed_op
if sampling_spec.HasField('seed_op')
else sampling_spec.symmetric_link_seed_op
)
tensors_by_op_name = {
sampling_spec.seed_op.op_name: seed_node_ids
seed_op.op_name: seed_node_ids
}

for sampling_op in sampling_spec.sampling_ops:
Expand Down Expand Up @@ -264,10 +493,17 @@ def _unique_y(x: tf.Tensor) -> tf.Tensor:

def assert_sorted_sampling_spec(sampling_spec: sampling_spec_pb2.SamplingSpec):
"""Raises ValueError if `sampling_spec` is not topologically-sorted."""
seen_ops = {sampling_spec.seed_op.op_name}
seed_op = (
sampling_spec.seed_op
if sampling_spec.HasField('seed_op')
else sampling_spec.symmetric_link_seed_op
)
seen_ops = {seed_op.op_name}
for sampling_op in sampling_spec.sampling_ops:
for input_op in sampling_op.input_op_names:
if input_op not in seen_ops:
raise ValueError(f'Input op {input_op} is used before defined. '
'sampling_spec is not topologically-sorted')
raise ValueError(
f'Input op {input_op} is used before defined. '
'sampling_spec is not topologically-sorted'
)
seen_ops.add(sampling_op.op_name)
89 changes: 89 additions & 0 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from typing import Dict, List, Tuple, Callable, Optional

from absl.testing import parameterized
import numpy as np
import tensorflow as tf

import tensorflow_gnn as tfgnn
Expand Down Expand Up @@ -326,5 +327,93 @@ def test_sampling_pipeline(self):
self.assertIn(food.decode(), eats_edges[animal.decode()])


def _get_test_link_edges_sampler_schema_spec():
reviews_edges = {
'mike': ['alexa'],
'alexa': ['mike', 'sam'],
'arne': ['alexa', 'mike', 'sam'],
'sam': ['bob'],
}
sampler = StringIdsSampler({'reviews': reviews_edges})

graph_schema = tfgnn.GraphSchema()
graph_schema.node_sets['_readout'].features['label'].dtype = 1
graph_schema.edge_sets['reviews'].source = 'authors'
graph_schema.edge_sets['reviews'].target = 'authors'
graph_schema.edge_sets['_readout/source'].target = '_readout'
graph_schema.edge_sets['_readout/source'].source = 'authors'
graph_schema.edge_sets['_readout/target'].target = '_readout'
graph_schema.edge_sets['_readout/target'].source = 'authors'

sampling_spec = sampling_spec_pb2.SamplingSpec()
sampling_spec.symmetric_link_seed_op.op_name = 'seed'
sampling_spec.sampling_ops.add(
op_name='A', edge_set_name='reviews',
sample_size=2).input_op_names.append('seed')

return reviews_edges, sampler, graph_schema, sampling_spec


class LinkSamplingPipelineTest(tf.test.TestCase, parameterized.TestCase):

# TODO: b/301427603 - Test this parametrically on int64 and int32 seeds.
def test_link_sampling_pipeline(self):
(_, sampler, graph_schema,
sampling_spec) = _get_test_link_edges_sampler_schema_spec()

source_seeds = tf.ragged.constant([['alexa'], ['sam'], ['sam']])
target_seeds = tf.ragged.constant([['mike'], ['mike'], ['arne']])
seed_labels = tf.ragged.constant([[0.0], [1.0], [1.0]], dtype=tf.float32)

inputs = {
tfgnn.SOURCE_NAME: source_seeds,
tfgnn.TARGET_NAME: target_seeds,
'label': seed_labels,
}
pipeline = subgraph_pipeline.LinkSamplingPipeline(
graph_schema, sampling_spec, sampler
)
graph_tensor = pipeline(inputs)
dummy_node_features = tf.ragged.constant(
[np.eye(3, 3), np.eye(4, 3), np.eye(4, 3)]
)
graph_tensor = graph_tensor.replace_features(
node_sets={'authors': {tfgnn.HIDDEN_STATE: dummy_node_features}}
)

self.assertAllEqual(
graph_tensor.node_sets['_readout'].sizes, [[1], [1], [1]]
)
self.assertAllEqual(
graph_tensor.edge_sets['reviews'].sizes, [[3], [2], [3]]
)
self.assertAllEqual(
graph_tensor.node_sets['authors'].sizes, [[3], [4], [4]]
)
self.assertAllEqual(
graph_tensor.node_sets['_readout'].features['label'],
[[0.0], [1.0], [1.0]],
)
merged_tensor = graph_tensor.merge_batch_to_components()
self.assertAllEqual(
merged_tensor.edge_sets['_readout/source'].adjacency.source, [0, 3, 7]
)
self.assertAllEqual(
tfgnn.structured_readout(
merged_tensor, 'source', feature_name=tfgnn.HIDDEN_STATE
),
[[1.0, 0.0, 0.0]] * 3,
)
self.assertAllEqual(
tfgnn.structured_readout(
merged_tensor, 'target', feature_name=tfgnn.HIDDEN_STATE
),
[[0.0, 1.0, 0.0]] * 3,
)
self.assertAllEqual(
merged_tensor.node_sets['_readout'].features['label'], [0.0, 1.0, 1.0]
)


if __name__ == '__main__':
tf.test.main()

0 comments on commit 66d1582

Please sign in to comment.