diff --git a/tensorflow_gnn/experimental/sampler/__init__.py b/tensorflow_gnn/experimental/sampler/__init__.py index 74975a57..c59697cb 100644 --- a/tensorflow_gnn/experimental/sampler/__init__.py +++ b/tensorflow_gnn/experimental/sampler/__init__.py @@ -31,6 +31,10 @@ subgraph_pipeline.create_sampling_model_from_spec ) +create_link_sampling_model_from_spec = ( + subgraph_pipeline.create_link_sampling_model_from_spec +) + # Export. create_program = eval_dag.create_program save_model = eval_dag.save_model diff --git a/tensorflow_gnn/experimental/sampler/beam/BUILD b/tensorflow_gnn/experimental/sampler/beam/BUILD index 3b3d3f18..693889d6 100644 --- a/tensorflow_gnn/experimental/sampler/beam/BUILD +++ b/tensorflow_gnn/experimental/sampler/beam/BUILD @@ -135,6 +135,7 @@ pytype_strict_library( srcs = ["unigraph_utils.py"], srcs_version = "PY3ONLY", deps = [ + ":executor_lib", "//third_party/py/apache_beam", "//:expect_numpy_installed", "//:expect_tensorflow_installed", @@ -190,7 +191,10 @@ py_binary( pytype_strict_contrib_test( name = "unigraph_utils_test", srcs = ["unigraph_utils_test.py"], - data = ["@tensorflow_gnn//testdata/heterogeneous"], + data = [ + "@tensorflow_gnn//testdata/heterogeneous", + "@tensorflow_gnn//testdata/homogeneous", + ], python_version = "PY3", srcs_version = "PY3ONLY", deps = [ diff --git a/tensorflow_gnn/experimental/sampler/beam/sampler.py b/tensorflow_gnn/experimental/sampler/beam/sampler.py index ba6da2aa..d24e993f 100644 --- a/tensorflow_gnn/experimental/sampler/beam/sampler.py +++ b/tensorflow_gnn/experimental/sampler/beam/sampler.py @@ -113,20 +113,43 @@ def node_features_accessor_factory( features_spec[name] = tf.TensorSpec(shape, dtype) accessor = sampler.KeyToTfExampleAccessor( sampler.InMemStringKeyToBytesAccessor( - keys_to_values={'b': b'b'}, - name=f'nodes/{node_set_name}'), + keys_to_values={'b': b'b'}, name=f'nodes/{node_set_name}' + ), features_spec=features_spec, ) return accessor counter = {} - return sampler.create_sampling_model_from_spec( - graph_schema, - sampling_spec, - edge_sampler_factory=functools.partial( - edge_sampler_factory, counter=counter), - node_features_accessor_factory=node_features_accessor_factory, - ), layer_name_to_edge_set + if sampling_spec.HasField('seed_op'): + if sampling_spec.HasField('symmetric_link_seed_op'): + raise ValueError( + 'seed_op is already set, so symmetric_link_seed_op should not be set.' 
+ ) + return ( + sampler.create_sampling_model_from_spec( + graph_schema, + sampling_spec, + edge_sampler_factory=functools.partial( + edge_sampler_factory, counter=counter + ), + node_features_accessor_factory=node_features_accessor_factory, + ), + layer_name_to_edge_set, + ) + elif sampling_spec.HasField('symmetric_link_seed_op'): + return ( + sampler.create_link_sampling_model_from_spec( + graph_schema, + sampling_spec, + edge_sampler_factory=functools.partial( + edge_sampler_factory, counter=counter + ), + node_features_accessor_factory=node_features_accessor_factory, + ), + layer_name_to_edge_set, + ) + else: + raise ValueError('One of seed_op or symmetric_link_seed_op must be set.') def _create_beam_runner( @@ -260,22 +283,28 @@ def app_main(argv) -> None: graph_schema, data_path ) feeds = feeds_unique - feeds.update({ - layer_name: feeds_unique[layers_mapping[layer_name]] - for layer_name in layers_mapping - }) - if FLAGS.input_seeds: + feeds.update( + { + layer_name: feeds_unique[layers_mapping[layer_name]] + for layer_name in layers_mapping + } + ) + if FLAGS.input_seeds and sampling_spec.HasField('symmetric_link_seed_op'): + inputs = root | 'ReadLinkSeeds' >> unigraph_utils.ReadLinkSeeds( + graph_schema, FLAGS.input_seeds + ) + elif FLAGS.input_seeds and sampling_spec.HasField('seed_op'): seeds = unigraph_utils.read_seeds(root, FLAGS.input_seeds) + inputs = { + 'Input': seeds, + } else: seeds = unigraph_utils.seeds_from_graph_dict(feeds, sampling_spec) - inputs = { - 'Input': seeds, - } + inputs = { + 'Input': seeds, + } examples = executor_lib.execute( - program_pb, - inputs, - feeds=feeds, - artifacts_path=artifacts_path + program_pb, inputs, feeds=feeds, artifacts_path=artifacts_path ) # results are tuple: example_id to tf.Example with graph tensors. coder = beam.coders.ProtoCoder(tf.train.Example) diff --git a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py index 14147735..941e754c 100644 --- a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py +++ b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils.py @@ -14,6 +14,7 @@ # ============================================================================== """Functions to read from unigraph format into that accepted by sampler_v2.""" +from __future__ import annotations from typing import Dict, List, Optional, Tuple import apache_beam as beam @@ -22,9 +23,11 @@ import tensorflow_gnn as tfgnn from tensorflow_gnn import sampler as sampler_lib from tensorflow_gnn.data import unigraph +from tensorflow_gnn.experimental.sampler.beam import executor_lib PCollection = beam.pvalue.PCollection +Values = executor_lib.Values def _create_seeds(node_id: bytes) -> Tuple[bytes, List[List[np.ndarray]]]: @@ -53,6 +56,90 @@ def read_seeds(root: beam.Pipeline, data_path: str) -> PCollection: ) +def _make_seed_feature( + example: tf.train.Example, feat_name: str +) -> tuple[bytes, Values]: + """Formats a particular feature from a tf.train.Example into a seed format. + + Args: + example: tf.train.Example with the seed features. 
+    feat_name: The feature to extract in this call.
+
+  Returns:
+    A bytes key and a list/np.array representation of a ragged tensor.
+
+  Raises:
+    ValueError: if the given feature is not present in the Example.
+  """
+
+  seed_source = example.features.feature[tfgnn.SOURCE_NAME].bytes_list.value[0]
+  seed_target = example.features.feature[tfgnn.TARGET_NAME].bytes_list.value[0]
+  key = bytes(
+      f'S{seed_source.decode("utf-8")}:T{seed_target.decode("utf-8")}', 'utf-8'
+  )
+  if example.features.feature[feat_name].HasField('bytes_list'):
+    bytes_value = example.features.feature[feat_name].bytes_list.value
+    value = [[
+        np.array(bytes_value, dtype=np.object_),
+        np.array([1], dtype=np.int64),
+    ]]
+  elif example.features.feature[feat_name].HasField('float_list'):
+    float_value = example.features.feature[feat_name].float_list.value
+    value = [[
+        np.array(float_value, dtype=np.float32),
+        np.array([1], dtype=np.int64),
+    ]]
+  elif example.features.feature[feat_name].HasField('int64_list'):
+    int64_value = example.features.feature[feat_name].int64_list.value
+    value = [[
+        np.array(int64_value, dtype=np.int64),
+        np.array([1], dtype=np.int64),
+    ]]
+  else:
+    raise ValueError(f'Feature {feat_name} is not present in this example')
+  return (key, value)
+
+
+class ReadLinkSeeds(beam.PTransform):
+  """Reads seeds for link prediction into PCollections for each seed feature."""
+
+  def __init__(self, graph_schema: tfgnn.GraphSchema, data_path: str):
+    """Constructor for ReadLinkSeeds PTransform.
+
+    Args:
+      graph_schema: tfgnn.GraphSchema for the input graph.
+      data_path: A file path for the seed data in accepted file formats.
+    """
+    self._graph_schema = graph_schema
+    self._data_path = data_path
+    self._readout_feature_names = [
+        key
+        for key in graph_schema.node_sets['_readout'].features.keys()
+        if key not in [tfgnn.SOURCE_NAME, tfgnn.TARGET_NAME]
+    ]
+
+  def expand(self, rcoll: PCollection) -> Dict[str, PCollection]:
+    seed_table = rcoll | 'ReadSeedTable' >> unigraph.ReadTable(
+        self._data_path,
+        converters=unigraph.build_converter_from_schema(
+            self._graph_schema.node_sets['_readout'].features
+        ),
+    )
+    pcolls_out = {}
+    pcolls_out['SeedSource'] = seed_table | 'MakeSeedSource' >> beam.Map(
+        _make_seed_feature, tfgnn.SOURCE_NAME
+    )
+    pcolls_out['SeedTarget'] = seed_table | 'MakeSeedTarget' >> beam.Map(
+        _make_seed_feature, tfgnn.TARGET_NAME
+    )
+    for feature in self._readout_feature_names:
+      pcolls_out[f'Seed/{feature}'] = (
+          seed_table
+          | f'MakeSeed/{feature}' >> beam.Map(_make_seed_feature, feature)
+      )
+    return pcolls_out
+
+
 def seeds_from_graph_dict(
     graph_dict: Dict[str, PCollection],
     sampling_spec: sampler_lib.SamplingSpec) -> PCollection:
diff --git a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils_test.py b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils_test.py
index 161be3f2..9a16b846 100644
--- a/tensorflow_gnn/experimental/sampler/beam/unigraph_utils_test.py
+++ b/tensorflow_gnn/experimental/sampler/beam/unigraph_utils_test.py
@@ -1075,6 +1075,160 @@ def test_read_and_convert_edge_features(self):
     root.run()
 
 
+class UnigraphUtilsLinkPredictionTest(tf.test.TestCase):
+
+  def setUp(self):
+    super().setUp()
+    self.resource_dir = test_utils.get_resource_dir('testdata/homogeneous')
+    self.seed_path = test_utils.get_resource(
+        'testdata/homogeneous/tastelike.csv'
+    )
+
+  def test_read_link_seeds(self):
+    data_path = self.seed_path
+    graph_schema = tfgnn.GraphSchema()
+    graph_schema.node_sets['_readout'].features[
+        'weight'
+    ].dtype = tf.float32.as_datatype_enum
+ graph_schema.node_sets['_readout'].features[ + '#source' + ].dtype = tf.string.as_datatype_enum + graph_schema.node_sets['_readout'].features[ + '#target' + ].dtype = tf.string.as_datatype_enum + expected_source_seeds = [ + ( + bytes('Samanatsu:Tdaidai', 'utf-8'), + [[ + np.array([b'amanatsu'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Samanatsu:Tlumia', 'utf-8'), + [[ + np.array([b'amanatsu'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Skiyomi:Tkomikan', 'utf-8'), + [[ + np.array([b'kiyomi'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Tkomikan', 'utf-8'), + [[ + np.array([b'mandora'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Ttangelo', 'utf-8'), + [[ + np.array([b'mandora'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ] + expected_target_seeds = [ + ( + bytes('Samanatsu:Tdaidai', 'utf-8'), + [[ + np.array([b'daidai'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Samanatsu:Tlumia', 'utf-8'), + [[ + np.array([b'lumia'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Skiyomi:Tkomikan', 'utf-8'), + [[ + np.array([b'komikan'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Tkomikan', 'utf-8'), + [[ + np.array([b'komikan'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Ttangelo', 'utf-8'), + [[ + np.array([b'tangelo'], dtype=np.object_), + np.array([1], dtype=np.int64), + ]], + ), + ] + expected_weight_seeds = [ + ( + bytes('Samanatsu:Tdaidai', 'utf-8'), + [[ + np.array([0.1], dtype=np.float32), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Samanatsu:Tlumia', 'utf-8'), + [[ + np.array([0.2], dtype=np.float32), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Skiyomi:Tkomikan', 'utf-8'), + [[ + np.array([0.3], dtype=np.float32), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Tkomikan', 'utf-8'), + [[ + np.array([0.4], dtype=np.float32), + np.array([1], dtype=np.int64), + ]], + ), + ( + bytes('Smandora:Ttangelo', 'utf-8'), + [[ + np.array([0.5], dtype=np.float32), + np.array([1], dtype=np.int64), + ]], + ), + ] + with test_pipeline.TestPipeline() as root: + link_seeds = root | 'ReadLinkSeeds' >> unigraph_utils.ReadLinkSeeds( + graph_schema, data_path + ) + util.assert_that( + link_seeds['SeedSource'], + util.equal_to(expected_source_seeds), + label='assert_source_seeds', + ) + util.assert_that( + link_seeds['SeedTarget'], + util.equal_to(expected_target_seeds), + label='assert_target_seeds', + ) + util.assert_that( + link_seeds['Seed/weight'], + util.equal_to(expected_weight_seeds), + label='assert_weight_seeds', + ) + root.run() + + # This function is needed because serialization to bytes in Python # is non-deterministic def _tf_example_from_bytes(s: bytes):
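
A minimal standalone sketch of the (key, value) seed encoding that _make_seed_feature produces, as documented by the test expectations above. The encode_link_seed helper is hypothetical and only restates the format visible in the diff; it is not part of the change.

import numpy as np

def encode_link_seed(source: bytes, target: bytes, values, dtype):
  # The key names the link as b'S<source>:T<target>'.
  key = f'S{source.decode("utf-8")}:T{target.decode("utf-8")}'.encode('utf-8')
  # The value is a ragged-tensor encoding: a flat-values array paired with a
  # trailing int64 array that appears to carry the row length (one row per
  # seed feature).
  return key, [[np.array(values, dtype=dtype), np.array([1], dtype=np.int64)]]

key, value = encode_link_seed(b'amanatsu', b'daidai', [0.1], np.float32)
assert key == b'Samanatsu:Tdaidai'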
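A sketch of how a caller might pick between the two model builders, mirroring the seed_op / symmetric_link_seed_op branching added in sampler.py. The graph_schema, sampling_spec, and factory arguments are assumed to already exist in the caller's scope.

from tensorflow_gnn.experimental import sampler

if sampling_spec.HasField('symmetric_link_seed_op'):
  # Link prediction: seeds are (source, target) node pairs.
  model = sampler.create_link_sampling_model_from_spec(
      graph_schema,
      sampling_spec,
      edge_sampler_factory=edge_sampler_factory,
      node_features_accessor_factory=node_features_accessor_factory,
  )
elif sampling_spec.HasField('seed_op'):
  # Node-centric sampling: seeds are individual node ids.
  model = sampler.create_sampling_model_from_spec(
      graph_schema,
      sampling_spec,
      edge_sampler_factory=edge_sampler_factory,
      node_features_accessor_factory=node_features_accessor_factory,
  )
else:
  raise ValueError('One of seed_op or symmetric_link_seed_op must be set.')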