Skip to content

Commit 52708ab

Browse files
committed
fix(upgrade): remove unnecessary checks to reduce latency
1 parent 7c2edc5 commit 52708ab

File tree

5 files changed

+43
-62
lines changed

5 files changed

+43
-62
lines changed

pychunkedgraph/ingest/cluster.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -135,7 +135,7 @@ def upgrade_atomic_chunk(coords: Sequence[int]):
135135
redis = get_redis_connection()
136136
imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER))
137137
coords = np.array(list(coords), dtype=int)
138-
update_atomic_chunk(imanager.cg, coords, layer=2)
138+
update_atomic_chunk(imanager.cg, coords)
139139
_post_task_completion(imanager, 2, coords)
140140

141141

Lines changed: 40 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -1,50 +1,52 @@
11
# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22

3+
from collections import defaultdict
34
from concurrent.futures import ThreadPoolExecutor, as_completed
45
import logging, math, time
6+
from copy import copy
57

68
import fastremap
79
import numpy as np
810
from tqdm import tqdm
9-
from pychunkedgraph.graph import ChunkedGraph
11+
from pychunkedgraph.graph import ChunkedGraph, types
1012
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
1113
from pychunkedgraph.graph.utils import serializers
1214
from pychunkedgraph.utils.general import chunked
1315

14-
from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
16+
from .utils import get_end_timestamps, get_parent_timestamps
1517

1618
CHILDREN = {}
1719

1820

1921
def update_cross_edges(
20-
cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, node_end_ts, timestamps: set
22+
cg: ChunkedGraph,
23+
node,
24+
cx_edges_d: dict,
25+
node_ts,
26+
node_end_ts,
27+
timestamps_d: defaultdict[int, set],
2128
) -> list:
2229
"""
2330
Helper function to update a single L2 ID.
2431
Returns a list of mutations with given timestamps.
2532
"""
2633
rows = []
2734
edges = np.concatenate(list(cx_edges_d.values()))
28-
uparents = np.unique(cg.get_parents(edges[:, 0], time_stamp=node_ts))
29-
assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
30-
if uparents.size == 0 or node != uparents[0]:
31-
# if node is not the parent at this ts, it must be invalid
32-
assert not exists_as_parent(cg, node, edges[:, 0])
33-
return rows
35+
partners = np.unique(edges[:, 1])
3436

35-
partner_parent_ts_d = get_parent_timestamps(cg, np.unique(edges[:, 1]))
36-
for v in partner_parent_ts_d.values():
37-
timestamps.update(v)
37+
timestamps = copy(timestamps_d[node])
38+
for partner in partners:
39+
timestamps.update(timestamps_d[partner])
3840

3941
for ts in sorted(timestamps):
4042
if ts < node_ts:
4143
continue
4244
if ts > node_end_ts:
4345
break
4446
val_dict = {}
45-
svs = edges[:, 1]
46-
parents = cg.get_parents(svs, time_stamp=ts)
47-
edge_parents_d = dict(zip(svs, parents))
47+
48+
parents = cg.get_parents(partners, time_stamp=ts)
49+
edge_parents_d = dict(zip(partners, parents))
4850
for layer, layer_edges in cx_edges_d.items():
4951
layer_edges = fastremap.remap(
5052
layer_edges, edge_parents_d, preserve_missing_labels=True
@@ -62,19 +64,21 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
6264
if children_map is None:
6365
children_map = CHILDREN
6466
end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map)
65-
timestamps_d = get_parent_timestamps(cg, nodes)
67+
6668
cx_edges_d = cg.get_atomic_cross_edges(nodes)
69+
all_cx_edges = [types.empty_2d]
70+
for _cx_edges_d in cx_edges_d.values():
71+
if _cx_edges_d:
72+
all_cx_edges.append(np.concatenate(list(_cx_edges_d.values())))
73+
all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
74+
timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
75+
6776
rows = []
6877
for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
69-
if cg.get_parent(node) is None:
70-
# invalid id caused by failed ingest task / edits
71-
continue
7278
_cx_edges_d = cx_edges_d.get(node, {})
7379
if not _cx_edges_d:
7480
continue
75-
_rows = update_cross_edges(
76-
cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d[node]
77-
)
81+
_rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
7882
rows.extend(_rows)
7983
return rows
8084

@@ -84,9 +88,7 @@ def _update_nodes_helper(args):
8488
return update_nodes(cg, nodes, nodes_ts)
8589

8690

87-
def update_chunk(
88-
cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2, debug: bool = False
89-
):
91+
def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
9092
"""
9193
Iterate over all L2 IDs in a chunk and update their cross chunk edges,
9294
within the periods they were valid/active.
@@ -95,35 +97,40 @@ def update_chunk(
9597

9698
start = time.time()
9799
x, y, z = chunk_coords
98-
chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
100+
chunk_id = cg.get_chunk_id(layer=2, x=x, y=y, z=z)
99101
cg.copy_fake_edges(chunk_id)
100102
rr = cg.range_read_chunk(chunk_id)
101103

102104
nodes = []
103105
nodes_ts = []
104106
earliest_ts = cg.get_earliest_timestamp()
105107
for k, v in rr.items():
106-
nodes.append(k)
107-
CHILDREN[k] = v[Hierarchy.Child][0].value
108-
ts = v[Hierarchy.Child][0].timestamp
109-
nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
108+
try:
109+
_ = v[Hierarchy.Parent]
110+
nodes.append(k)
111+
CHILDREN[k] = v[Hierarchy.Child][0].value
112+
ts = v[Hierarchy.Child][0].timestamp
113+
nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
114+
except KeyError:
115+
continue
110116

111117
if len(nodes) > 0:
112-
logging.info(f"Processing {len(nodes)} nodes.")
118+
logging.info(f"processing {len(nodes)} nodes.")
113119
assert len(CHILDREN) > 0, (nodes, CHILDREN)
114120
else:
115121
return
116122

117123
if debug:
118124
rows = update_nodes(cg, nodes, nodes_ts)
119125
else:
120-
task_size = int(math.ceil(len(nodes) / 64))
126+
task_size = int(math.ceil(len(nodes) / 16))
121127
chunked_nodes = chunked(nodes, task_size)
122128
chunked_nodes_ts = chunked(nodes_ts, task_size)
123129
tasks = []
124130
for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
125131
args = (cg, chunk, ts_chunk)
126132
tasks.append(args)
133+
logging.info(f"task size {task_size}, count {len(tasks)}.")
127134

128135
rows = []
129136
with ThreadPoolExecutor(max_workers=8) as executor:
@@ -132,4 +139,4 @@ def update_chunk(
132139
rows.extend(future.result())
133140

134141
cg.client.write(rows)
135-
print(f"total elapsed time: {time.time() - start}")
142+
logging.info(f"total elapsed time: {time.time() - start}")

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -88,25 +88,6 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l
8888
Returns a list of mutations with timestamps.
8989
"""
9090
rows = []
91-
if node_ts > earliest_ts:
92-
try:
93-
cx_edges_d = CX_EDGES[node][node_ts]
94-
except KeyError:
95-
raise KeyError(f"{node}:{node_ts}")
96-
edges = np.concatenate([empty_2d] + list(cx_edges_d.values()))
97-
if edges.size:
98-
parents = cg.get_roots(
99-
edges[:, 0], time_stamp=node_ts, stop_layer=layer, ceil=False
100-
)
101-
uparents = np.unique(parents)
102-
layers = cg.get_chunk_layers(uparents)
103-
uparents = uparents[layers == layer]
104-
assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
105-
if uparents.size == 0 or node != uparents[0]:
106-
# if node is not the parent at this ts, it must be invalid
107-
assert not exists_as_parent(cg, node, edges[:, 0]), f"{node}, {node_ts}"
108-
return rows
109-
11091
row_id = serializers.serialize_uint64(node)
11192
for ts, cx_edges_d in CX_EDGES[node].items():
11293
if ts < node_ts:
@@ -186,7 +167,7 @@ def update_chunk(
186167
tasks.append(args)
187168

188169
processes = min(mp.cpu_count() * 2, len(tasks))
189-
logging.info(f"Processing {len(nodes)} nodes with {processes} workers.")
170+
logging.info(f"processing {len(nodes)} nodes with {processes} workers.")
190171
with mp.Pool(processes) as pool:
191172
_ = list(
192173
tqdm(

pychunkedgraph/ingest/upgrade/utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -61,7 +61,7 @@ def get_end_timestamps(cg: ChunkedGraph, nodes, nodes_ts, children_map):
6161

6262
def get_parent_timestamps(
6363
cg: ChunkedGraph, nodes, start_time=None, end_time=None
64-
) -> dict[int, set]:
64+
) -> defaultdict[int, set]:
6565
"""
6666
Timestamps of when the given nodes were edited.
6767
"""

pychunkedgraph/utils/general.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ def reverse_dictionary(dictionary):
2626

2727

2828
def chunked(l: Sequence, n: int):
29-
"""
30-
Yield successive n-sized chunks from l.
31-
NOTE: Use itertools.batched from python 3.12
32-
"""
3329
"""
3430
Yield successive n-sized chunks from l.
3531
NOTE: Use itertools.batched from python 3.12
@@ -39,9 +35,6 @@ def chunked(l: Sequence, n: int):
3935
it = iter(l)
4036
while batch := tuple(islice(it, n)):
4137
yield batch
42-
it = iter(l)
43-
while batch := tuple(islice(it, n)):
44-
yield batch
4538

4639

4740
def in2d(arr1: np.ndarray, arr2: np.ndarray) -> np.ndarray:

0 commit comments

Comments
 (0)