Skip to content

Commit

Permalink
Support .join().build() without having to sample() between join
Browse files Browse the repository at this point in the history
… and `build` -- calling `build()` directly on the result of `join()` is equivalent to calling `build()` on any of the sampling-step nodes.

PiperOrigin-RevId: 626234505
  • Loading branch information
samihaija authored and tensorflower-gardener committed Apr 19, 2024
1 parent f93fc98 commit 6818eb2
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
6 changes: 5 additions & 1 deletion tensorflow_gnn/sampler/sampling_spec_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -446,11 +446,12 @@ def merge_then_sample(

return new_step

def join(self, other_steps):
def join(self, other_steps) -> 'Join':
  """Groups this step with `other_steps` into a single `Join`.

  Args:
    other_steps: Iterable of further steps to join with this one.

  Returns:
    A `Join` over this step followed by `other_steps`, in the given order.
  """
  joined_steps = [self]
  joined_steps.extend(other_steps)
  return Join(joined_steps)


class Join:
"""Joins multiple steps. Acts like _SamplingStep due to `sample` & `build`."""

def __init__(self, steps: List[_SamplingStep]):
if not steps:
Expand All @@ -461,6 +462,9 @@ def sample(self, *sample_args, **sample_kwargs) -> '_SamplingStep':
return self._steps[0].merge_then_sample(
self._steps[1:], *sample_args, **sample_kwargs)

def build(self) -> sampling_spec_pb2.SamplingSpec:
  """Returns the sampling spec by delegating to the first joined step."""
  first_step = self._steps[0]
  return first_step.build()


def _edge_set_names_by_source(
graph: Union[schema_pb2.GraphSchema, GraphTensor, Any]
Expand Down
52 changes: 47 additions & 5 deletions tensorflow_gnn/sampler/sampling_spec_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,12 +51,12 @@ def get_schema(edge_sets=('AA', 'AB', 'AC', 'BC', 'CD'),

class SamplingSpecBuilderTest(parameterized.TestCase):

def test_line_to_sampling_spec(self):
def test_line_build(self):
schema = get_schema()
builder = sampling_spec_builder.SamplingSpecBuilder(
schema, sampling_spec_pb2.SamplingStrategy.RANDOM_UNIFORM)
proto = (builder.seed('A').sample(5, 'AB').sample(5, 'BC').sample(5, 'CD')
.to_sampling_spec())
.build())

expected_proto = text_format.Parse(
"""
Expand Down Expand Up @@ -88,12 +88,12 @@ def test_line_to_sampling_spec(self):
""", sampling_spec_pb2.SamplingSpec())
self.assertEqual(expected_proto, proto)

def test_dag_to_sampling_spec(self):
def test_dag_build(self):
schema = get_schema()
builder = sampling_spec_builder.SamplingSpecBuilder(schema).seed('A')
path1 = builder.sample(5, 'AB').sample(4, 'BC', op_name='A-B-C')
path2 = builder.sample(7, 'AC', op_name='A-C')
proto = (path1.join([path2]).sample(10, 'CD').to_sampling_spec())
proto = path1.join([path2]).sample(10, 'CD').build()

expected_proto = text_format.Parse(
"""
Expand Down Expand Up @@ -133,6 +133,48 @@ def test_dag_to_sampling_spec(self):
""", sampling_spec_pb2.SamplingSpec())
self.assertEqual(expected_proto, proto)

def test_build_right_after_join(self):
"""Checks that `join(...).build()` works without a `sample()` in between.

The spec built from the join must equal the spec built from either of the
joined paths, and must match the expected proto spelled out below.
"""
schema = get_schema()
builder = sampling_spec_builder.SamplingSpecBuilder(schema).seed('A')
# Two paths branching from the same seed: A -> B -> C and A -> C.
path1 = builder.sample(5, 'AB').sample(4, 'BC', op_name='A-B-C')
path2 = builder.sample(7, 'AC', op_name='A-C')
# Build each path on its own before joining, to compare against the join.
path1_build_proto = path1.build()
path2_build_proto = path2.build()
# Both paths are expected to build the same spec.
self.assertEqual(path1.build(), path2.build())
# Building straight from the join must agree with building from either path.
join_build_proto = path1.join([path2]).build()
self.assertEqual(join_build_proto, path1_build_proto)
self.assertEqual(join_build_proto, path2_build_proto)

# The expected spec holds the seed op plus the sampling ops of both paths.
# NOTE(review): no strategy is passed to the builder above, so TOP_K is
# presumably the builder's default -- confirm against SamplingSpecBuilder.
expected_proto = text_format.Parse(
"""
seed_op {
op_name: "SEED->A"
node_set_name: "A"
}
sampling_ops {
op_name: "A-C"
input_op_names: "SEED->A"
edge_set_name: "AC"
sample_size: 7
strategy: TOP_K
}
sampling_ops {
op_name: "A->B"
input_op_names: "SEED->A"
edge_set_name: "AB"
sample_size: 5
strategy: TOP_K
}
sampling_ops {
op_name: "A-B-C"
input_op_names: "A->B"
edge_set_name: "BC"
sample_size: 4
strategy: TOP_K
}
""", sampling_spec_pb2.SamplingSpec())
self.assertEqual(expected_proto, join_build_proto)

def test_sample_with_list_of_sizes(self):
schema = get_schema()
proto = (sampling_spec_builder.SamplingSpecBuilder(schema).seed('A')
Expand Down Expand Up @@ -170,7 +212,7 @@ def test_sample_with_list_of_sizes(self):
def test_no_required_edgeset_or_nodeset_names_for_homogeneous_graph(self):
schema = get_schema(edge_sets=['AA']) # Homogeneous graph.
proto = (sampling_spec_builder.SamplingSpecBuilder(schema)
.seed().sample([10, 5]).sample([2, 1]).to_sampling_spec())
.seed().sample([10, 5]).sample([2, 1]).build())
# # ^ could be combined with previous sample.

expected_proto = text_format.Parse(
Expand Down

0 comments on commit 6818eb2

Please sign in to comment.