
Commit

Small bug fixes and link prediction code refactoring.
PiperOrigin-RevId: 568779736
aferludin authored and tensorflower-gardener committed Sep 27, 2023
1 parent 14be914 commit 0e330e1
Showing 2 changed files with 103 additions and 70 deletions.
23 changes: 18 additions & 5 deletions tensorflow_gnn/experimental/sampler/beam/sampler.py
@@ -105,12 +105,25 @@ def node_features_accessor_factory(
) -> sampler.KeyToTfExampleAccessor:
if not graph_schema.node_sets[node_set_name].features:
return None

node_features = graph_schema.node_sets[node_set_name].features
features_spec = {}
for name, feature in node_features.items():
shape = _get_shape(feature)
dtype = tf.dtypes.as_dtype(feature.dtype)
features_spec[name] = tf.TensorSpec(shape, dtype)
def get_tensor_spec(
name: str, feature: graph_schema_pb2.Feature
) -> tf.TensorSpec:
try:
shape = _get_shape(feature)
dtype = tf.dtypes.as_dtype(feature.dtype)
return tf.TensorSpec(shape, dtype)
except Exception as e:
raise ValueError(
f'Invalid graph schema for {node_set_name} node set, feature {name}'
) from e

features_spec = {
name: get_tensor_spec(name, feature)
for name, feature in node_features.items()
}

accessor = sampler.KeyToTfExampleAccessor(
sampler.InMemStringKeyToBytesAccessor(
keys_to_values={'b': b'b'}, name=f'nodes/{node_set_name}'
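For readers skimming the sampler.py hunk: the refactoring wraps per-feature TensorSpec construction so that a malformed feature is reported together with its node set and feature name. A minimal sketch of that pattern, assuming plain (shape, dtype) pairs in place of the real schema protos and the private _get_shape helper; the build_features_spec name is hypothetical:

import tensorflow as tf

def build_features_spec(node_set_name, features):
  """Maps feature name to tf.TensorSpec, attributing failures to the feature."""

  def get_tensor_spec(name, shape, dtype):
    try:
      return tf.TensorSpec(shape, tf.dtypes.as_dtype(dtype))
    except Exception as e:
      raise ValueError(
          f'Invalid graph schema for {node_set_name} node set, feature {name}'
      ) from e

  return {
      name: get_tensor_spec(name, shape, dtype)
      for name, (shape, dtype) in features.items()
  }

# A malformed entry now surfaces with the offending node set and feature name.
specs = build_features_spec('paper', {'embedding': ([None, 128], 'float32')})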
150 changes: 85 additions & 65 deletions tensorflow_gnn/experimental/sampler/subgraph_pipeline.py
@@ -39,7 +39,7 @@
from __future__ import annotations
import collections

from typing import Callable, Mapping, Optional, Dict
from typing import Callable, Mapping, Optional, Dict, Tuple
import tensorflow as tf
import tensorflow_gnn as tfgnn
from tensorflow_gnn.experimental.sampler import core
@@ -126,7 +126,6 @@ def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema):
' they should be called `_readout/source` and `_readout/target`.'
f' Instead got {aux_edge_sets}'
)
aux_node_set_name = aux_node_sets[0]
if (
graph_schema.edge_sets[aux_edge_sets[0]].source
!= graph_schema.edge_sets[aux_edge_sets[1]].source
@@ -140,17 +139,6 @@ def _validate_link_sampling_graph_schema(graph_schema: tfgnn.GraphSchema):
f' {graph_schema.edge_sets[aux_edge_sets[1]].source}'
)

if not (
tfgnn.SOURCE_NAME
in graph_schema.node_sets[aux_node_set_name].features.keys()
and tfgnn.TARGET_NAME
in graph_schema.node_sets[aux_node_set_name].features.keys()
):
raise ValueError(
'Link sampling requires a #source and a #target feature in the'
' _readout node set'
)


def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema):
"""Populates inputs for sampling model from a graph schema.
@@ -170,7 +158,6 @@ def _readout_inputs_from_schema(graph_schema: tfgnn.GraphSchema):
feature_names = [
key
for key in node_set_spec.features.keys()
if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME]
]
feature_inputs = {}
for feature_name in feature_names:
@@ -237,6 +224,7 @@ def __init__(
edge_sampler_factory: EdgeSamplerFactory,
node_features_accessor_factory: Optional[
NodeFeaturesLookupFactory] = None,
seed_node_set_name: Optional[tfgnn.NodeSetName] = None,
):
assert_sorted_sampling_spec(sampling_spec)
self._graph_schema = graph_schema
@@ -247,6 +235,14 @@
tfgnn.NodeSetName, Optional[interfaces.KeyToFeaturesAccessor]
] = {}

if seed_node_set_name is None:
seed_node_set_name = self._sampling_spec.seed_op.node_set_name

if not seed_node_set_name:
raise ValueError('Seed node set name is not specified.')

self._seed_node_set_name = seed_node_set_name

def get_node_features_accessor(
self, node_set_name: tfgnn.NodeSetName
) -> Optional[interfaces.KeyToFeaturesAccessor]:
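The new seed_node_set_name argument added to SamplingPipeline.__init__ resolves in a simple order: an explicitly passed name takes precedence, otherwise the sampling spec's seed op supplies it, and an empty result raises. A standalone sketch of just that precedence, with an illustrative helper name:

from typing import Optional

def resolve_seed_node_set_name(
    explicit_name: Optional[str], spec_seed_node_set_name: str
) -> str:
  # An explicitly passed name wins; otherwise fall back to the sampling
  # spec's seed op, and reject an empty result.
  name = explicit_name if explicit_name is not None else spec_seed_node_set_name
  if not name:
    raise ValueError('Seed node set name is not specified.')
  return name

assert resolve_seed_node_set_name(None, 'paper') == 'paper'
assert resolve_seed_node_set_name('author', 'paper') == 'author'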
@@ -260,7 +256,7 @@ def get_node_features_accessor(

@property
def seed_node_set_name(self) -> tfgnn.NodeSetName:
return self._sampling_spec.seed_op.node_set_name
return self._seed_node_set_name

def __call__(self, seed_nodes: tf.RaggedTensor) -> tfgnn.GraphTensor:
"""Returns subgraph (as `GraphTensor`) sampled around `seed_nodes`.
@@ -335,41 +331,78 @@ def __init__(
NodeFeaturesLookupFactory
] = None,
):
self._readout_node_set = '_readout'
assert self._readout_node_set in graph_schema.node_sets, graph_schema
assert (
graph_schema.edge_sets[f'{self._readout_node_set}/source'].target
== graph_schema.edge_sets[f'{self._readout_node_set}/target'].target
== self._readout_node_set
), graph_schema
assert (
graph_schema.edge_sets[f'{self._readout_node_set}/source'].source
== graph_schema.edge_sets[f'{self._readout_node_set}/target'].source
), graph_schema

self._seed_node_set = graph_schema.edge_sets[
f'{self._readout_node_set}/source'
].source

self._sampling_pipeline = SamplingPipeline(
graph_schema=graph_schema,
sampling_spec=sampling_spec,
edge_sampler_factory=edge_sampler_factory,
node_features_accessor_factory=node_features_accessor_factory,
seed_node_set_name=self._seed_node_set,
)
self._target_node_set_name = graph_schema.edge_sets[
'_readout/source'
].source
self._aux_node_set_name = [
node_set
for node_set in graph_schema.node_sets.keys()
if tfgnn.get_aux_type_prefix(node_set) is not None
][0]

def __call__(self, inputs: dict[str, tf.RaggedTensor]) -> tfgnn.GraphTensor:
seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1)
subgraph = self._sampling_pipeline(seed_nodes)
return AddReadoutStruct()(
subgraph, inputs, self._target_node_set_name, self._aux_node_set_name
)
return AddLinkReadoutStruct(
readout_node_set=self._readout_node_set,
seed_node_set=self._seed_node_set,
)((subgraph, inputs))


class AddLinkReadoutStruct(tf.keras.layers.Layer):
"""Helper class for adding readout structure for the link prediction task.
class AddReadoutStruct(tf.keras.layers.Layer):
"""Helper class for adding readout node and edge sets to a GraphTensor."""
It is assumed that link sampling is done using `SamplingPipeline`, starting
from two seed nodes that belong to the same node set (`seed_node_set`). In the
current implementation those two seeds become node 0 and node 1 in each sampled
subgraph. The readout structure is created by adding a `_readout` node set with
a single readout node per graph, carrying the readout features (passed as the
2nd tuple value in the call method). The seed nodes are connected to the
readout nodes by auxiliary readout edges.
NOTE: following the TF-GNN implementation, we direct readout edges from
seed nodes (source) to readout nodes (target).
"""

def __init__(
self, readout_node_set: tfgnn.NodeSet, seed_node_set: tfgnn.NodeSetName
):
super().__init__()
assert readout_node_set
assert seed_node_set
self._readout_node_set = readout_node_set
self._seed_node_set = seed_node_set

def get_config(self):
return {
'readout_node_set': self._readout_node_set,
'seed_node_set': self._seed_node_set,
**super().get_config(),
}

def call(
self,
graph_tensor: tfgnn.GraphTensor,
readout_features: Mapping[str, tf.RaggedTensor],
target_node_set_name: tfgnn.NodeSetName,
aux_node_set_name: tfgnn.NodeSetName,
inputs: Tuple[tfgnn.GraphTensor, Mapping[str, tf.RaggedTensor]],
) -> tfgnn.GraphTensor:
target_node_set = graph_tensor.node_sets[target_node_set_name]
batch_size = tf.shape(target_node_set.sizes)[0]
graph_tensor, readout_features = inputs

seed_node_set = graph_tensor.node_sets[self._seed_node_set]
batch_size = tf.shape(seed_node_set.sizes)[0]
row_lengths = tf.ones([batch_size], dtype=graph_tensor.row_splits_dtype)
sizes = tf.ones([batch_size, 1], dtype=graph_tensor.indices_dtype)
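In the refactored LinkSamplingPipeline.__call__ above, the '#source' and '#target' inputs are concatenated, so within every sampled subgraph the link's source endpoint becomes seed node 0 and its target becomes seed node 1; AddLinkReadoutStruct relies on this ordering. A small illustration of the concatenation (the id values are made up):

import tensorflow as tf

inputs = {
    '#source': tf.ragged.constant([[b'u1'], [b'u3']]),
    '#target': tf.ragged.constant([[b'u2'], [b'u4']]),
}
# Per example, the source id lands at position 0 and the target id at
# position 1 of the combined seed tensor.
seed_nodes = tf.concat([inputs['#source'], inputs['#target']], axis=-1)
# seed_nodes == [[b'u1', b'u2'], [b'u3', b'u4']]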

@@ -378,48 +411,35 @@ def get_readout_index(value):
tf.fill(
[batch_size],
tf.constant(value, dtype=graph_tensor.indices_dtype),
row_lengths,
),
row_lengths,
row_lengths=row_lengths,
)

readout_source_source = get_readout_index(0)
readout_source_target = get_readout_index(0)
readout_target_source = get_readout_index(1)
readout_target_target = get_readout_index(0)
readout_source_adjacency = tfgnn.Adjacency.from_indices(
(target_node_set_name, readout_source_source),
(aux_node_set_name, readout_source_target),
)
readout_target_adjacency = tfgnn.Adjacency.from_indices(
(target_node_set_name, readout_target_source),
(aux_node_set_name, readout_target_target),
)
readout_node_sets = {
aux_node_set_name: tfgnn.NodeSet.from_fields(
self._readout_node_set: tfgnn.NodeSet.from_fields(
features=readout_features, sizes=sizes
)
}
readout_edge_sets = {
f'{aux_node_set_name}/source': tfgnn.EdgeSet.from_fields(
sizes=sizes, adjacency=readout_source_adjacency
f'{self._readout_node_set}/source': tfgnn.EdgeSet.from_fields(
sizes=sizes,
adjacency=tfgnn.Adjacency.from_indices(
(self._seed_node_set, get_readout_index(0)),
(self._readout_node_set, get_readout_index(0)),
),
),
f'{aux_node_set_name}/target': tfgnn.EdgeSet.from_fields(
sizes=sizes, adjacency=readout_target_adjacency
f'{self._readout_node_set}/target': tfgnn.EdgeSet.from_fields(
sizes=sizes,
adjacency=tfgnn.Adjacency.from_indices(
(self._seed_node_set, get_readout_index(1)),
(self._readout_node_set, get_readout_index(0)),
),
),
}
node_sets = {
node_set_name: node_set
for node_set_name, node_set in graph_tensor.node_sets.items()
}
edge_sets = {
edge_set_name: edge_set
for edge_set_name, edge_set in graph_tensor.edge_sets.items()
}
node_sets.update(readout_node_sets)
edge_sets.update(readout_edge_sets)
return graph_tensor.from_pieces(
node_sets=node_sets, edge_sets=edge_sets, context=graph_tensor.context
return tfgnn.GraphTensor.from_pieces(
node_sets={**graph_tensor.node_sets, **readout_node_sets},
edge_sets={**graph_tensor.edge_sets, **readout_edge_sets},
context=graph_tensor.context,
)
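To make the readout structure concrete, here is a hand-built, non-batched GraphTensor with the same layout that AddLinkReadoutStruct produces: a '_readout' node set with one node per graph holding the link features, plus '_readout/source' and '_readout/target' aux edge sets pointing from seed nodes 0 and 1 to that readout node. The 'paper' node set, its ids, and the label value are invented for illustration:

import tensorflow as tf
import tensorflow_gnn as tfgnn

graph = tfgnn.GraphTensor.from_pieces(
    node_sets={
        'paper': tfgnn.NodeSet.from_fields(
            sizes=tf.constant([3]),
            features={'id': tf.constant([101, 202, 303])},
        ),
        # One readout node per graph, carrying the link's readout features.
        '_readout': tfgnn.NodeSet.from_fields(
            sizes=tf.constant([1]),
            features={'label': tf.constant([1.0])},
        ),
    },
    edge_sets={
        # Seed node 0 (the link source) -> readout node 0.
        '_readout/source': tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([1]),
            adjacency=tfgnn.Adjacency.from_indices(
                ('paper', tf.constant([0])),
                ('_readout', tf.constant([0])),
            ),
        ),
        # Seed node 1 (the link target) -> readout node 0.
        '_readout/target': tfgnn.EdgeSet.from_fields(
            sizes=tf.constant([1]),
            adjacency=tfgnn.Adjacency.from_indices(
                ('paper', tf.constant([1])),
                ('_readout', tf.constant([0])),
            ),
        ),
    },
)

The layer itself builds the same wiring in batched form, using ragged row lengths of one readout node per example, which is why get_readout_index(0) and get_readout_index(1) appear in the adjacency construction above.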

