Skip to content

Commit 72cb8fd

Browse files
Neural-Link Team authored and tensorflow-copybara committed
Fixes the NSL graph builder to ignore lsh_rounds when lsh_splits < 1.
This fixes a regression introduced in NSL v1.3.0 in which twice the work was being performed in the default case. As a workaround, just specify lsh_rounds=1 when lsh_splits=0. PiperOrigin-RevId: 327063409
1 parent b67b210 commit 72cb8fd

File tree

2 files changed

+12
-1
lines changed

2 files changed

+12
-1
lines changed

neural_structured_learning/tools/build_graph.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -219,7 +219,10 @@ def _generate_edges(self, embeddings):
219219
A tuple (source, target, weight) denoting a (directed) edge from 'source'
220220
to 'target' with the given 'weight'.
221221
"""
222-
for lsh_round in range(max(1, self.config.lsh_rounds)):
222+
# If lsh_splits < 1, we ignore lsh_rounds and always perform 1 round, since
223+
# performing multiple rounds in the case of no splits does not help.
224+
rounds = self.config.lsh_rounds if self.config.lsh_splits > 0 else 1
225+
for lsh_round in range(rounds):
223226
start_time = time.time()
224227
edge_cnt = 0
225228
bucket_map = self._generate_lsh_buckets(embeddings)

neural_structured_learning/tools/build_graph_test.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,10 @@ def _create_embedding_file(self):
8686
def _create_graph_file(self):
  """Creates a temp file named 'graph.tsv' and returns its full path."""
  graph_file = self.create_tempfile('graph.tsv')
  return graph_file.full_path
8888

89+
def _num_file_lines(self, graph_path):
  """Returns the number of lines in the text file at `graph_path`.

  Args:
    graph_path: Path of the file whose lines are counted.

  Returns:
    The line count as an int; 0 for an empty file. A final line without a
    trailing newline still counts as one line.
  """
  # Plain 'r' already uses universal-newline handling in Python 3. The legacy
  # 'U' flag was a deprecated no-op and raises ValueError from Python 3.11 on.
  with open(graph_path, 'r') as f:
    return sum(1 for _ in f)
92+
8993
def testBuildGraphInvalidLshBitsValue(self):
9094
with self.assertRaises(ValueError):
9195
build_graph_lib.build_graph([], None, lsh_splits=-1)
@@ -103,6 +107,7 @@ def testBuildGraphNoThresholdingNoLSH(self):
103107
build_graph_lib.build_graph([embedding_path],
104108
graph_path,
105109
similarity_threshold=0)
110+
self.assertEqual(self._num_file_lines(graph_path), 6)
106111
g_actual = graph_utils.read_tsv_graph(graph_path)
107112
self.assertDictEqual(
108113
g_actual, {
@@ -129,6 +134,7 @@ def testBuildGraphWithThresholdingNoLSH(self):
129134
build_graph_lib.build_graph([embedding_path],
130135
graph_path,
131136
similarity_threshold=0.51)
137+
self.assertEqual(self._num_file_lines(graph_path), 0)
132138
g_actual = graph_utils.read_tsv_graph(graph_path)
133139
self.assertDictEqual(g_actual, {})
134140

@@ -178,6 +184,7 @@ def testBuildGraphWithThresholdWithLSHInsufficientLSHRounds(self):
178184
lsh_splits=2,
179185
lsh_rounds=1,
180186
random_seed=12345)
187+
self.assertEqual(self._num_file_lines(graph_path), num_points * 2 - 8)
181188
g_actual = graph_utils.read_tsv_graph(graph_path)
182189

183190
# Check that the graph contains fewer than 2 * N edges
@@ -203,6 +210,7 @@ def testBuildGraphWithThresholdWithLSHSufficientLSHRounds(self):
203210
lsh_splits=2,
204211
lsh_rounds=4,
205212
random_seed=12345)
213+
self.assertEqual(self._num_file_lines(graph_path), num_points * 2)
206214
g_actual = graph_utils.read_tsv_graph(graph_path)
207215

208216
# Construct the expected graph: each point should be a neighbor of the

0 commit comments

Comments
 (0)