From fa34af9cbf24c69719d11464b62d0260e05b63ab Mon Sep 17 00:00:00 2001 From: Graph Learning Team Date: Thu, 25 Apr 2024 04:32:23 -0700 Subject: [PATCH] Adds interfaces for temporal graph sampling PiperOrigin-RevId: 628032219 --- .../experimental/sampler/interfaces.py | 90 +++++++++++++++++-- 1 file changed, 84 insertions(+), 6 deletions(-) diff --git a/tensorflow_gnn/experimental/sampler/interfaces.py b/tensorflow_gnn/experimental/sampler/interfaces.py index 598bb5a3..02f33764 100644 --- a/tensorflow_gnn/experimental/sampler/interfaces.py +++ b/tensorflow_gnn/experimental/sampler/interfaces.py @@ -52,6 +52,7 @@ class SamplingPrimitive(abc.ABC): """Base class for all sampling primitives.""" + pass @@ -129,8 +130,7 @@ def call(self, keys: tf.RaggedTensor) -> tf.RaggedTensor: class KeyToFeaturesAccessor(AccessorBase): - """Generic key to features dict accessor. - """ + """Generic key to features dict accessor.""" @abc.abstractmethod def call(self, keys: tf.RaggedTensor) -> Features: @@ -165,7 +165,6 @@ def call(self, source_node_ids: tf.RaggedTensor) -> Features: shape `[batch_size, (num_source_nodes)]` and tf.int32, tf.int64 or tf.string type. - Returns: `Features` containing the subset of all edges whose source nodes are in `source_node_ids`. All returned features must have shape @@ -182,6 +181,47 @@ def edge_set_name(self) -> str: raise NotImplementedError +class TemporalOutgoingEdgesSampler(SamplingPrimitive): + """Samples outgoing edges for given source nodes at a specific point in time. + + Used to create rooted subgraphs from temporal graphs as of a specific point in + time. + """ + + def __call__( + self, source_node_ids: tf.RaggedTensor, *, timestamps: tf.Tensor + ) -> Features: + return self.call(source_node_ids=source_node_ids, timestamps=timestamps) + + @abc.abstractmethod + def call( + self, source_node_ids: tf.RaggedTensor, *, timestamps: tf.Tensor + ) -> Features: + """Samples outgoing edges for the given source node ids and timestamps. 
+ + Args: + source_node_ids: node ids for sampling outgoing edges. Ragged tensor with + shape `[batch_size, (num_source_nodes)]` and tf.int32, tf.int64 or + tf.string dtype. + timestamps: points in time to filter edge states. Tensor with shape + `[batch_size]` and tf.int32 or tf.int64 dtype. + + Returns: + `Features` containing a subset of all edges whose source nodes are in + `source_node_ids`, in their state as of `timestamps`. All returned + features have shape `[batch_size, (num_edges), ...]`. The result includes + two special features "#source" and "#target" of rank 2 containing, + respectively, source node ids and target node ids of the sampled edges. + """ + raise NotImplementedError + + @property + @abc.abstractmethod + def edge_set_name(self) -> str: + """The edge set name.""" + raise NotImplementedError + + class UniformEdgesSampler(OutgoingEdgesSampler): """Samples up to the `sample_size` outgoing edges uniformly at random.""" @@ -198,6 +238,22 @@ def edge_target_feature_name(self) -> str: raise NotImplementedError +class TemporalUniformEdgesSampler(TemporalOutgoingEdgesSampler): + """Samples up to the `sample_size` outgoing edges uniformly at random.""" + + @property + @abc.abstractmethod + def sample_size(self) -> int: + """The maximum number of edges to sample.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def edge_target_feature_name(self) -> str: + """The input feature name containing edge target node ids.""" + raise NotImplementedError + + class TopKEdgesSampler(OutgoingEdgesSampler): """Samples up to the `sample_size` top weighted outgoing edges.""" @@ -220,6 +276,28 @@ def weight_feature_name(self) -> str: raise NotImplementedError +class TemporalTopKEdgesSampler(TemporalOutgoingEdgesSampler): + """Samples up to the `sample_size` top weighted outgoing edges.""" + + @property + @abc.abstractmethod + def sample_size(self) -> int: + """The maximum number of edges to sample.""" + raise NotImplementedError + + @property + 
@abc.abstractmethod + def edge_target_feature_name(self) -> str: + """The input feature name containing edge target node ids.""" + raise NotImplementedError + + @property + @abc.abstractmethod + def weight_feature_name(self) -> str: + """The input feature name containing edge weights.""" + raise NotImplementedError + + class ConnectingEdgesSampler(SamplingPrimitive): """Samples incident edges between given subsets of source and target nodes. @@ -229,9 +307,9 @@ class ConnectingEdgesSampler(SamplingPrimitive): """ @abc.abstractmethod - def call(self, - source_node_ids: tf.RaggedTensor, - target_node_ids: tf.RaggedTensor) -> Features: + def call( + self, source_node_ids: tf.RaggedTensor, target_node_ids: tf.RaggedTensor + ) -> Features: """Samples incident edges *from* source *on* target node ids. Each sampled edges has its source in the `source_node_ids` and its target in