Commit 52447f7

feat(upgrade): cache stale timestamp info; remove unnecessary checks to reduce latency
1 parent 7c2edc5 · commit 52447f7

File tree

7 files changed: +118 -90 lines changed

pychunkedgraph/graph/attributes.py

Lines changed: 6 additions & 0 deletions

@@ -160,6 +160,12 @@ class Hierarchy:
         serializer=serializers.NumPyValue(dtype=basetypes.NODE_ID),
     )

+    # track when nodes became stale, required for migration
+    # will be eventually deleted by GC rule for column family_id 3.
+    StaleTimeStamp = _Attribute(
+        key=b"stale_ts", family_id="3", serializer=serializers.Pickle()
+    )
+

 class GraphMeta:
     key = b"meta"
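The new attribute rides on the existing `_Attribute` plumbing, so staleness can be written and read with the same client calls this commit uses elsewhere. A minimal sketch, assuming `cg.client.read_nodes` returns cells shaped like the `range_read_chunk` responses used later in this commit; note that the cell's write timestamp, not its value, carries the staleness time:

import datetime

from pychunkedgraph.graph.attributes import Hierarchy
from pychunkedgraph.graph.utils import serializers


def mark_stale(cg, node_id: int, stale_ts: datetime.datetime) -> None:
    # The value is a sentinel (0); the cell's timestamp records when the
    # node became stale, mirroring the mutate_row calls in this commit.
    row_id = serializers.serialize_uint64(node_id)
    row = cg.client.mutate_row(row_id, {Hierarchy.StaleTimeStamp: 0}, time_stamp=stale_ts)
    cg.client.write([row])


def stale_timestamp(cg, node_id: int):
    # Returns when the node was marked stale, or None if it never was.
    response = cg.client.read_nodes(
        node_ids=[node_id], properties=[Hierarchy.StaleTimeStamp]
    )
    cells = response.get(node_id, {}).get(Hierarchy.StaleTimeStamp)
    return cells[0].timestamp if cells else None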

pychunkedgraph/graph/edges/__init__.py

Lines changed: 10 additions & 9 deletions

@@ -201,22 +201,23 @@ def get_edges(source: str, nodes: np.ndarray) -> Edges:


 def get_stale_nodes(
-    cg, edge_nodes: Iterable[basetypes.NODE_ID], parent_ts: datetime.datetime = None
+    cg, nodes: Iterable[basetypes.NODE_ID], parent_ts: datetime.datetime = None
 ):
     """
-    Checks to see if partner nodes in edges (edges[:,1]) are stale.
-    This is done by getting a supervoxel of the node and check
+    Checks to see if given nodes are stale.
+    This is done by getting a supervoxel of a node and checking
     if it has a new parent at the same layer as the node.
     """
-    edge_supervoxels = cg.get_single_leaf_multiple(edge_nodes)
+    nodes = np.array(nodes, dtype=basetypes.NODE_ID)
+    supervoxels = cg.get_single_leaf_multiple(nodes)
     # nodes can be at different layers due to skip connections
-    edge_nodes_layers = cg.get_chunk_layers(edge_nodes)
+    node_layers = cg.get_chunk_layers(nodes)
     stale_nodes = [types.empty_1d]
-    for layer in np.unique(edge_nodes_layers):
-        _mask = edge_nodes_layers == layer
-        layer_nodes = edge_nodes[_mask]
+    for layer in np.unique(node_layers):
+        _mask = node_layers == layer
+        layer_nodes = nodes[_mask]
         _nodes = cg.get_roots(
-            edge_supervoxels[_mask],
+            supervoxels[_mask],
             stop_layer=layer,
             ceil=False,
             time_stamp=parent_ts,
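The check reduces to one idea: a node is stale if one of its supervoxels has acquired a different parent at the node's own layer. The hunk above is truncated before the final comparison, so the masking step in this sketch is an assumption about how each loop iteration concludes:

import numpy as np


def stale_at_layer(cg, layer_nodes, layer_supervoxels, layer, parent_ts):
    # Current ID at `layer` for each node's supervoxel, as of parent_ts.
    current = cg.get_roots(
        layer_supervoxels,
        stop_layer=layer,
        ceil=False,
        time_stamp=parent_ts,
    )
    # Nodes whose supervoxel now resolves to a different ID are stale.
    return layer_nodes[current != layer_nodes]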

pychunkedgraph/ingest/cluster.py

Lines changed: 1 addition & 1 deletion

@@ -135,7 +135,7 @@ def upgrade_atomic_chunk(coords: Sequence[int]):
     redis = get_redis_connection()
     imanager = IngestionManager.from_pickle(redis.get(r_keys.INGESTION_MANAGER))
     coords = np.array(list(coords), dtype=int)
-    update_atomic_chunk(imanager.cg, coords, layer=2)
+    update_atomic_chunk(imanager.cg, coords)
     _post_task_completion(imanager, 2, coords)

Lines changed: 47 additions & 34 deletions

@@ -1,50 +1,53 @@
 # pylint: disable=invalid-name, missing-docstring, c-extension-no-member

+from collections import defaultdict
 from concurrent.futures import ThreadPoolExecutor, as_completed
+from datetime import timedelta
 import logging, math, time
+from copy import copy

 import fastremap
 import numpy as np
 from tqdm import tqdm
-from pychunkedgraph.graph import ChunkedGraph
+from pychunkedgraph.graph import ChunkedGraph, types
 from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
 from pychunkedgraph.graph.utils import serializers
 from pychunkedgraph.utils.general import chunked

-from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
+from .utils import get_end_timestamps, get_parent_timestamps

 CHILDREN = {}


 def update_cross_edges(
-    cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, node_end_ts, timestamps: set
+    cg: ChunkedGraph,
+    node,
+    cx_edges_d: dict,
+    node_ts,
+    node_end_ts,
+    timestamps_d: defaultdict[int, set],
 ) -> list:
     """
     Helper function to update a single L2 ID.
     Returns a list of mutations with given timestamps.
     """
     rows = []
     edges = np.concatenate(list(cx_edges_d.values()))
-    uparents = np.unique(cg.get_parents(edges[:, 0], time_stamp=node_ts))
-    assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
-    if uparents.size == 0 or node != uparents[0]:
-        # if node is not the parent at this ts, it must be invalid
-        assert not exists_as_parent(cg, node, edges[:, 0])
-        return rows
+    partners = np.unique(edges[:, 1])

-    partner_parent_ts_d = get_parent_timestamps(cg, np.unique(edges[:, 1]))
-    for v in partner_parent_ts_d.values():
-        timestamps.update(v)
+    timestamps = copy(timestamps_d[node])
+    for partner in partners:
+        timestamps.update(timestamps_d[partner])

     for ts in sorted(timestamps):
         if ts < node_ts:
             continue
         if ts > node_end_ts:
             break
+
         val_dict = {}
-        svs = edges[:, 1]
-        parents = cg.get_parents(svs, time_stamp=ts)
-        edge_parents_d = dict(zip(svs, parents))
+        parents = cg.get_parents(partners, time_stamp=ts)
+        edge_parents_d = dict(zip(partners, parents))
         for layer, layer_edges in cx_edges_d.items():
             layer_edges = fastremap.remap(
                 layer_edges, edge_parents_d, preserve_missing_labels=True
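`fastremap.remap` applies a dict as a lookup table over the whole array in one pass; `preserve_missing_labels=True` leaves IDs without an entry untouched rather than raising. A standalone example of the remapping pattern used above, with made-up IDs:

import numpy as np
import fastremap

# hypothetical cross-chunk edges as [node, partner] pairs
layer_edges = np.array([[101, 201], [101, 202], [102, 203]], dtype=np.uint64)

# partner -> parent at the timestamp being materialized
edge_parents_d = {201: 301, 202: 301}

remapped = fastremap.remap(layer_edges, edge_parents_d, preserve_missing_labels=True)
# array([[101, 301], [101, 301], [102, 203]]) -- 101, 102, and 203 have
# no entry in the table and are preserved as-is
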
@@ -61,20 +64,26 @@ def update_cross_edges(
 def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
     if children_map is None:
         children_map = CHILDREN
-    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map)
-    timestamps_d = get_parent_timestamps(cg, nodes)
+    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map, layer=2)
+
     cx_edges_d = cg.get_atomic_cross_edges(nodes)
+    all_cx_edges = [types.empty_2d]
+    for _cx_edges_d in cx_edges_d.values():
+        if _cx_edges_d:
+            all_cx_edges.append(np.concatenate(list(_cx_edges_d.values())))
+    all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
+    timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
+
     rows = []
     for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
-        if cg.get_parent(node) is None:
-            # invalid id caused by failed ingest task / edits
-            continue
+        end_ts -= timedelta(milliseconds=1)
         _cx_edges_d = cx_edges_d.get(node, {})
         if not _cx_edges_d:
             continue
-        _rows = update_cross_edges(
-            cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d[node]
-        )
+        _rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
+        row_id = serializers.serialize_uint64(node)
+        val_dict = {Hierarchy.StaleTimeStamp: 0}
+        _rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
         rows.extend(_rows)
     return rows
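This hunk carries the commit's main latency fix: `get_parent_timestamps` is now called once for the L2 IDs and every cross-edge partner together, and `update_cross_edges` works purely from the resulting in-memory dict instead of issuing its own read per node. Distilled, the per-node logic becomes a couple of set unions; this is a restatement of the new code above, not additional behavior:

from collections import defaultdict
from copy import copy


def relevant_timestamps(node, partners, timestamps_d: defaultdict) -> set:
    # All parent-change timestamps that can affect this node's cross edges:
    # its own, plus every partner's -- served from the prefetched cache.
    timestamps = copy(timestamps_d[node])
    for partner in partners:
        timestamps |= timestamps_d[partner]
    return timestamps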

@@ -84,9 +93,7 @@ def _update_nodes_helper(args):
     return update_nodes(cg, nodes, nodes_ts)


-def update_chunk(
-    cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2, debug: bool = False
-):
+def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
     """
     Iterate over all L2 IDs in a chunk and update their cross chunk edges,
     within the periods they were valid/active.
@@ -95,35 +102,41 @@ def update_chunk(

     start = time.time()
     x, y, z = chunk_coords
-    chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
+    chunk_id = cg.get_chunk_id(layer=2, x=x, y=y, z=z)
     cg.copy_fake_edges(chunk_id)
     rr = cg.range_read_chunk(chunk_id)

     nodes = []
     nodes_ts = []
     earliest_ts = cg.get_earliest_timestamp()
     for k, v in rr.items():
-        nodes.append(k)
-        CHILDREN[k] = v[Hierarchy.Child][0].value
-        ts = v[Hierarchy.Child][0].timestamp
-        nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
+        try:
+            _ = v[Hierarchy.Parent]
+            nodes.append(k)
+            CHILDREN[k] = v[Hierarchy.Child][0].value
+            ts = v[Hierarchy.Child][0].timestamp
+            nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
+        except KeyError:
+            # invalid nodes from failed tasks w/o parent column entry
+            continue

     if len(nodes) > 0:
-        logging.info(f"Processing {len(nodes)} nodes.")
+        logging.info(f"processing {len(nodes)} nodes.")
         assert len(CHILDREN) > 0, (nodes, CHILDREN)
     else:
         return

     if debug:
         rows = update_nodes(cg, nodes, nodes_ts)
     else:
-        task_size = int(math.ceil(len(nodes) / 64))
+        task_size = int(math.ceil(len(nodes) / 16))
         chunked_nodes = chunked(nodes, task_size)
         chunked_nodes_ts = chunked(nodes_ts, task_size)
         tasks = []
         for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
            args = (cg, chunk, ts_chunk)
             tasks.append(args)
+        logging.info(f"task size {task_size}, count {len(tasks)}.")

         rows = []
         with ThreadPoolExecutor(max_workers=8) as executor:
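The other removed round trip: instead of calling `cg.get_parent(node)` per ID inside `update_nodes`, validity is now decided from the columns already fetched by `range_read_chunk`. A distilled sketch of that filter, assuming (as the diff implies) that a row missing the Parent column raises KeyError on access:

from pychunkedgraph.graph.attributes import Hierarchy


def valid_nodes(rr, earliest_ts):
    # Rows written by failed ingest tasks have no Parent column entry;
    # skip them without issuing any additional reads.
    nodes, nodes_ts, children = [], [], {}
    for node, cells in rr.items():
        try:
            _ = cells[Hierarchy.Parent]
        except KeyError:
            continue
        nodes.append(node)
        children[node] = cells[Hierarchy.Child][0].value
        ts = cells[Hierarchy.Child][0].timestamp
        nodes_ts.append(max(ts, earliest_ts))
    return nodes, nodes_ts, children
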
@@ -132,4 +145,4 @@ def update_chunk(
         rows.extend(future.result())

     cg.client.write(rows)
-    print(f"total elaspsed time: {time.time() - start}")
+    logging.info(f"total elaspsed time: {time.time() - start}")

pychunkedgraph/ingest/upgrade/parent_layer.py

Lines changed: 26 additions & 30 deletions

@@ -16,7 +16,7 @@
 from pychunkedgraph.graph.types import empty_2d
 from pychunkedgraph.utils.general import chunked

-from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
+from .utils import get_end_timestamps, get_parent_timestamps


 CHILDREN = {}
@@ -64,7 +64,9 @@ def _populate_cx_edges_with_timestamps(
     all_children = np.concatenate(list(CHILDREN.values()))
     response = cg.client.read_nodes(node_ids=all_children, properties=attrs)
     timestamps_d = get_parent_timestamps(cg, nodes)
-    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN)
+    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, CHILDREN, layer=layer)
+
+    rows = []
     for node, node_ts, node_end_ts in zip(nodes, nodes_ts, end_timestamps):
         CX_EDGES[node] = {}
         timestamps = timestamps_d[node]
@@ -81,32 +83,18 @@
                 break
             CX_EDGES[node][ts] = _get_cx_edges_at_timestamp(node, response, ts)

+        row_id = serializers.serialize_uint64(node)
+        val_dict = {Hierarchy.StaleTimeStamp: 0}
+        rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=node_end_ts))
+    cg.client.write(rows)
+

-def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> list:
+def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts) -> list:
     """
     Helper function to update a single ID.
     Returns a list of mutations with timestamps.
     """
     rows = []
-    if node_ts > earliest_ts:
-        try:
-            cx_edges_d = CX_EDGES[node][node_ts]
-        except KeyError:
-            raise KeyError(f"{node}:{node_ts}")
-        edges = np.concatenate([empty_2d] + list(cx_edges_d.values()))
-        if edges.size:
-            parents = cg.get_roots(
-                edges[:, 0], time_stamp=node_ts, stop_layer=layer, ceil=False
-            )
-            uparents = np.unique(parents)
-            layers = cg.get_chunk_layers(uparents)
-            uparents = uparents[layers == layer]
-            assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
-            if uparents.size == 0 or node != uparents[0]:
-                # if node is not the parent at this ts, it must be invalid
-                assert not exists_as_parent(cg, node, edges[:, 0]), f"{node}, {node_ts}"
-                return rows
-
     row_id = serializers.serialize_uint64(node)
     for ts, cx_edges_d in CX_EDGES[node].items():
         if ts < node_ts:
@@ -132,12 +120,12 @@ def update_cross_edges(cg: ChunkedGraph, layer, node, node_ts, earliest_ts) -> l


 def _update_cross_edges_helper_thread(args):
-    cg, layer, node, node_ts, earliest_ts = args
-    return update_cross_edges(cg, layer, node, node_ts, earliest_ts)
+    cg, layer, node, node_ts = args
+    return update_cross_edges(cg, layer, node, node_ts)


 def _update_cross_edges_helper(args):
-    cg_info, layer, nodes, nodes_ts, earliest_ts = args
+    cg_info, layer, nodes, nodes_ts = args
     rows = []
     cg = ChunkedGraph(**cg_info)
     parents = cg.get_parents(nodes, fail_to_zero=True)
@@ -147,7 +135,7 @@ def _update_cross_edges_helper(args):
         if parent == 0:
             # invalid id caused by failed ingest task / edits
             continue
-        tasks.append((cg, layer, node, node_ts, earliest_ts))
+        tasks.append((cg, layer, node, node_ts))

     with ThreadPoolExecutor(max_workers=4) as executor:
         futures = [executor.submit(_update_cross_edges_helper_thread, task) for task in tasks]
@@ -163,10 +151,10 @@ def update_chunk(
     """
     Iterate over all layer IDs in a chunk and update their cross chunk edges.
     """
+    debug = nodes is not None
     start = time.time()
     x, y, z = chunk_coords
     chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
-    earliest_ts = cg.get_earliest_timestamp()
     _populate_nodes_and_children(cg, chunk_id, nodes=nodes)
     if not CHILDREN:
         return
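With `earliest_ts` dropped from the task tuples and debug mode inferred from `nodes`, the calling convention gets simpler: passing explicit IDs processes them serially in this process, which keeps stack traces readable. A hypothetical invocation; the `graph_id` constructor argument, the keyword names, and the example IDs are assumptions:

cg = ChunkedGraph(graph_id="example_graph")  # hypothetical graph

# normal path: fan out over a multiprocessing pool
update_chunk(cg, chunk_coords=[12, 7, 3], layer=3)

# debug path: only these IDs, computed serially in-process
update_chunk(cg, chunk_coords=[12, 7, 3], layer=3, nodes=[432345564227567621])
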
@@ -175,23 +163,31 @@ def update_chunk(
     nodes_ts = cg.get_node_timestamps(nodes, return_numpy=False, normalize=True)
     _populate_cx_edges_with_timestamps(cg, layer, nodes, nodes_ts)

+    if debug:
+        rows = []
+        for node, node_ts in zip(nodes, nodes_ts):
+            rows.extend(update_cross_edges(cg, layer, node, node_ts))
+        cg.client.write(rows)
+        logging.info(f"total elaspsed time: {time.time() - start}")
+        return
+
     task_size = int(math.ceil(len(nodes) / mp.cpu_count() / 2))
     chunked_nodes = chunked(nodes, task_size)
     chunked_nodes_ts = chunked(nodes_ts, task_size)
     cg_info = cg.get_serialized_info()

     tasks = []
     for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
-        args = (cg_info, layer, chunk, ts_chunk, earliest_ts)
+        args = (cg_info, layer, chunk, ts_chunk)
         tasks.append(args)

     processes = min(mp.cpu_count() * 2, len(tasks))
-    logging.info(f"Processing {len(nodes)} nodes with {processes} workers.")
+    logging.info(f"processing {len(nodes)} nodes with {processes} workers.")
     with mp.Pool(processes) as pool:
         _ = list(
             tqdm(
                 pool.imap_unordered(_update_cross_edges_helper, tasks),
                 total=len(tasks),
             )
         )
-    print(f"total elaspsed time: {time.time() - start}")
+    logging.info(f"total elaspsed time: {time.time() - start}")
