Skip to content

Commit

Permalink
Link prediction support in Beam sampler (part 2).
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 565158910
  • Loading branch information
mihirparadkar authored and tensorflower-gardener committed Sep 21, 2023
1 parent 14c90ad commit d69bd91
Show file tree
Hide file tree
Showing 5 changed files with 295 additions and 21 deletions.
4 changes: 4 additions & 0 deletions tensorflow_gnn/experimental/sampler/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,10 @@
subgraph_pipeline.create_sampling_model_from_spec
)

create_link_sampling_model_from_spec = (
subgraph_pipeline.create_link_sampling_model_from_spec
)

# Export.
create_program = eval_dag.create_program
save_model = eval_dag.save_model
Expand Down
1 change: 1 addition & 0 deletions tensorflow_gnn/experimental/sampler/beam/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,7 @@ pytype_strict_library(
srcs = ["unigraph_utils.py"],
srcs_version = "PY3ONLY",
deps = [
":executor_lib",
"//third_party/py/apache_beam",
"//:expect_numpy_installed",
"//:expect_tensorflow_installed",
Expand Down
71 changes: 50 additions & 21 deletions tensorflow_gnn/experimental/sampler/beam/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,20 +113,43 @@ def node_features_accessor_factory(
features_spec[name] = tf.TensorSpec(shape, dtype)
accessor = sampler.KeyToTfExampleAccessor(
sampler.InMemStringKeyToBytesAccessor(
keys_to_values={'b': b'b'},
name=f'nodes/{node_set_name}'),
keys_to_values={'b': b'b'}, name=f'nodes/{node_set_name}'
),
features_spec=features_spec,
)
return accessor

counter = {}
return sampler.create_sampling_model_from_spec(
graph_schema,
sampling_spec,
edge_sampler_factory=functools.partial(
edge_sampler_factory, counter=counter),
node_features_accessor_factory=node_features_accessor_factory,
), layer_name_to_edge_set
if sampling_spec.HasField('seed_op'):
if sampling_spec.HasField('symmetric_link_seed_op'):
raise ValueError(
'seed_op is already set, so symmetric_link_seed_op should not be set.'
)
return (
sampler.create_sampling_model_from_spec(
graph_schema,
sampling_spec,
edge_sampler_factory=functools.partial(
edge_sampler_factory, counter=counter
),
node_features_accessor_factory=node_features_accessor_factory,
),
layer_name_to_edge_set,
)
elif sampling_spec.HasField('symmetric_link_seed_op'):
return (
sampler.create_link_sampling_model_from_spec(
graph_schema,
sampling_spec,
edge_sampler_factory=functools.partial(
edge_sampler_factory, counter=counter
),
node_features_accessor_factory=node_features_accessor_factory,
),
layer_name_to_edge_set,
)
else:
raise ValueError('One of seed_op or symmetric_link_seed_op must be set.')


def _create_beam_runner(
Expand Down Expand Up @@ -260,22 +283,28 @@ def app_main(argv) -> None:
graph_schema, data_path
)
feeds = feeds_unique
feeds.update({
layer_name: feeds_unique[layers_mapping[layer_name]]
for layer_name in layers_mapping
})
if FLAGS.input_seeds:
feeds.update(
{
layer_name: feeds_unique[layers_mapping[layer_name]]
for layer_name in layers_mapping
}
)
if FLAGS.input_seeds and sampling_spec.HasField('symmetric_link_seed_op'):
inputs = root | 'ReadLinkSeeds' >> unigraph_utils.ReadLinkSeeds(
graph_schema, FLAGS.input_seeds
)
elif FLAGS.input_seeds and sampling_spec.HasField('seed_op'):
seeds = unigraph_utils.read_seeds(root, FLAGS.input_seeds)
inputs = {
'Input': seeds,
}
else:
seeds = unigraph_utils.seeds_from_graph_dict(feeds, sampling_spec)
inputs = {
'Input': seeds,
}
inputs = {
'Input': seeds,
}
examples = executor_lib.execute(
program_pb,
inputs,
feeds=feeds,
artifacts_path=artifacts_path
program_pb, inputs, feeds=feeds, artifacts_path=artifacts_path
)
# results are tuple: example_id to tf.Example with graph tensors.
coder = beam.coders.ProtoCoder(tf.train.Example)
Expand Down
86 changes: 86 additions & 0 deletions tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,11 @@
import tensorflow_gnn as tfgnn
from tensorflow_gnn import sampler as sampler_lib
from tensorflow_gnn.data import unigraph
from tensorflow_gnn.experimental.sampler.beam import executor_lib


PCollection = beam.pvalue.PCollection
Values = executor_lib.Values


def _create_seeds(node_id: bytes) -> Tuple[bytes, List[List[np.ndarray]]]:
Expand Down Expand Up @@ -53,6 +55,90 @@ def read_seeds(root: beam.Pipeline, data_path: str) -> PCollection:
)


def _make_seed_feature(
example: tf.train.Example, feat_name: str
) -> tuple[bytes, Values]:
"""Formats a particular feature from a tf.train.Example into a seed format.
Args:
example: tf.train.Example with the seed features.
feat_name: The feature to extract in this call
Returns:
bytes key and a list/np.array representation of a ragged tensor
Raises:
ValueError: on a malformed Example without the given feature present
"""

seed_source = example.features.feature[tfgnn.SOURCE_NAME].bytes_list.value[0]
seed_target = example.features.feature[tfgnn.TARGET_NAME].bytes_list.value[0]
key = bytes(
f'S{seed_source.decode("utf-8")}:T{seed_target.decode("utf-8")}', 'utf-8'
)
if example.features.feature[feat_name].HasField('bytes_list'):
bytes_value = example.features.feature[feat_name].bytes_list.value
value = [[
np.array(bytes_value, dtype=np.object_),
np.array([1], dtype=np.int64),
]]
elif example.features.feature[feat_name].HasField('float_list'):
float_value = example.features.feature[feat_name].float_list.value
value = [[
np.array(float_value, dtype=np.float32),
np.array([1], dtype=np.int64),
]]
elif example.features.feature[feat_name].HasField('int64_list'):
int64_value = example.features.feature[feat_name].int64_list.value
value = [[
np.array(int64_value, dtype=np.float32),
np.array([1], dtype=np.int64),
]]
else:
raise ValueError(f'Feature {feat_name} is not present in this example')
return (key, value)


class ReadLinkSeeds(beam.PTransform):
"""Reads seeds for link prediction into PCollections for each seed feature."""

def __init__(self, graph_schema: tfgnn.GraphSchema, data_path: str):
"""Constructor for ReadLinkSeeds PTransform.
Args:
graph_schema: tfgnn.GraphSchema for the input graph.
data_path: A file path for the seed data in accepted file formats.
"""
self._graph_schema = graph_schema
self._data_path = data_path
self._readout_feature_names = [
key
for key in graph_schema.node_sets['_readout'].features.keys()
if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME]
]

def expand(self, rcoll: PCollection) -> Dict[str, PCollection]:
seed_table = rcoll | 'ReadSeedTable' >> unigraph.ReadTable(
self._data_path,
converters=unigraph.build_converter_from_schema(
self._graph_schema.node_sets['_readout'].features
),
)
pcolls_out = {}
pcolls_out['SeedSource'] = seed_table | 'MakeSeedSource' >> beam.Map(
_make_seed_feature, tfgnn.SOURCE_NAME
)
pcolls_out['SeedTarget'] = seed_table | 'MakeSeedTarget' >> beam.Map(
_make_seed_feature, tfgnn.TARGET_NAME
)
for feature in self._readout_feature_names:
pcolls_out[f'Seed/{feature}'] = (
seed_table
| f'MakeSeed/{feature}' >> beam.Map(_make_seed_feature, feature)
)
return pcolls_out


def seeds_from_graph_dict(
graph_dict: Dict[str, PCollection],
sampling_spec: sampler_lib.SamplingSpec) -> PCollection:
Expand Down
154 changes: 154 additions & 0 deletions tensorflow_gnn/experimental/sampler/beam/unigraph_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -1075,6 +1075,160 @@ def test_read_and_convert_edge_features(self):
root.run()


class UnigraphUtilsLinkPredictionTest(tf.test.TestCase):

def setUp(self):
super().setUp()
self.resource_dir = test_utils.get_resource_dir('testdata/homogeneous')
self.seed_path = test_utils.get_resource(
'testdata/homogeneous/tastelike.csv'
)

def test_read_link_seeds(self):
data_path = self.seed_path
graph_schema = tfgnn.GraphSchema()
graph_schema.node_sets['_readout'].features[
'weight'
].dtype = tf.float32.as_datatype_enum
graph_schema.node_sets['_readout'].features[
'#source'
].dtype = tf.string.as_datatype_enum
graph_schema.node_sets['_readout'].features[
'#target'
].dtype = tf.string.as_datatype_enum
expected_source_seeds = [
(
bytes('Samanatsu:Tdaidai', 'utf-8'),
[[
np.array([b'amanatsu'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Samanatsu:Tlumia', 'utf-8'),
[[
np.array([b'amanatsu'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Skiyomi:Tkomikan', 'utf-8'),
[[
np.array([b'kiyomi'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Tkomikan', 'utf-8'),
[[
np.array([b'mandora'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Ttangelo', 'utf-8'),
[[
np.array([b'mandora'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
]
expected_target_seeds = [
(
bytes('Samanatsu:Tdaidai', 'utf-8'),
[[
np.array([b'daidai'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Samanatsu:Tlumia', 'utf-8'),
[[
np.array([b'lumia'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Skiyomi:Tkomikan', 'utf-8'),
[[
np.array([b'komikan'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Tkomikan', 'utf-8'),
[[
np.array([b'komikan'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Ttangelo', 'utf-8'),
[[
np.array([b'tangelo'], dtype=np.object_),
np.array([1], dtype=np.int64),
]],
),
]
expected_weight_seeds = [
(
bytes('Samanatsu:Tdaidai', 'utf-8'),
[[
np.array([0.1], dtype=np.float32),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Samanatsu:Tlumia', 'utf-8'),
[[
np.array([0.2], dtype=np.float32),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Skiyomi:Tkomikan', 'utf-8'),
[[
np.array([0.3], dtype=np.float32),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Tkomikan', 'utf-8'),
[[
np.array([0.4], dtype=np.float32),
np.array([1], dtype=np.int64),
]],
),
(
bytes('Smandora:Ttangelo', 'utf-8'),
[[
np.array([0.5], dtype=np.float32),
np.array([1], dtype=np.int64),
]],
),
]
with test_pipeline.TestPipeline() as root:
link_seeds = root | 'ReadLinkSeeds' >> unigraph_utils.ReadLinkSeeds(
graph_schema, data_path
)
util.assert_that(
link_seeds['SeedSource'],
util.equal_to(expected_source_seeds),
label='assert_source_seeds',
)
util.assert_that(
link_seeds['SeedTarget'],
util.equal_to(expected_target_seeds),
label='assert_target_seeds',
)
util.assert_that(
link_seeds['Seed/weight'],
util.equal_to(expected_weight_seeds),
label='assert_weight_seeds',
)
root.run()


# This function is needed because serialization to bytes in Python
# is non-deterministic
def _tf_example_from_bytes(s: bytes):
Expand Down

0 comments on commit d69bd91

Please sign in to comment.