11# pylint: disable=invalid-name, missing-docstring, c-extension-no-member
22
3+ from collections import defaultdict
34from concurrent .futures import ThreadPoolExecutor , as_completed
5+ from datetime import timedelta
46import logging , math , time
7+ from copy import copy
58
69import fastremap
710import numpy as np
811from tqdm import tqdm
9- from pychunkedgraph .graph import ChunkedGraph
12+ from pychunkedgraph .graph import ChunkedGraph , types
1013from pychunkedgraph .graph .attributes import Connectivity , Hierarchy
1114from pychunkedgraph .graph .utils import serializers
1215from pychunkedgraph .utils .general import chunked
1316
14- from .utils import exists_as_parent , get_end_timestamps , get_parent_timestamps
17+ from .utils import get_end_timestamps , get_parent_timestamps
1518
1619CHILDREN = {}
1720
1821
1922def update_cross_edges (
20- cg : ChunkedGraph , node , cx_edges_d : dict , node_ts , node_end_ts , timestamps : set
23+ cg : ChunkedGraph ,
24+ node ,
25+ cx_edges_d : dict ,
26+ node_ts ,
27+ node_end_ts ,
28+ timestamps_d : defaultdict [int , set ],
2129) -> list :
2230 """
2331 Helper function to update a single L2 ID.
2432 Returns a list of mutations with given timestamps.
2533 """
2634 rows = []
2735 edges = np .concatenate (list (cx_edges_d .values ()))
28- uparents = np .unique (cg .get_parents (edges [:, 0 ], time_stamp = node_ts ))
29- assert uparents .size <= 1 , f"{ node } , { node_ts } , { uparents } "
30- if uparents .size == 0 or node != uparents [0 ]:
31- # if node is not the parent at this ts, it must be invalid
32- assert not exists_as_parent (cg , node , edges [:, 0 ])
33- return rows
36+ partners = np .unique (edges [:, 1 ])
3437
35- partner_parent_ts_d = get_parent_timestamps ( cg , np . unique ( edges [:, 1 ]) )
36- for v in partner_parent_ts_d . values () :
37- timestamps .update (v )
38+ timestamps = copy ( timestamps_d [ node ] )
39+ for partner in partners :
40+ timestamps .update (timestamps_d [ partner ] )
3841
3942 for ts in sorted (timestamps ):
4043 if ts < node_ts :
4144 continue
4245 if ts > node_end_ts :
4346 break
47+
4448 val_dict = {}
45- svs = edges [:, 1 ]
46- parents = cg .get_parents (svs , time_stamp = ts )
47- edge_parents_d = dict (zip (svs , parents ))
49+ parents = cg .get_parents (partners , time_stamp = ts )
50+ edge_parents_d = dict (zip (partners , parents ))
4851 for layer , layer_edges in cx_edges_d .items ():
4952 layer_edges = fastremap .remap (
5053 layer_edges , edge_parents_d , preserve_missing_labels = True
@@ -61,20 +64,26 @@ def update_cross_edges(
def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
    """
    Build the mutation rows needed to update cross-chunk edges for a batch of L2 nodes.

    For each node, collects its atomic cross edges, gathers parent-change
    timestamps for the node and all of its edge partners in one bulk call,
    and delegates per-node row construction to ``update_cross_edges``.
    A ``StaleTimeStamp`` marker row is appended for every processed node.

    Args:
        cg: the ChunkedGraph instance to read from / build mutations for.
        nodes: iterable of L2 node IDs.
        nodes_ts: per-node creation timestamps, parallel to ``nodes``.
        children_map: node -> children lookup; falls back to the module-level
            ``CHILDREN`` cache populated by ``update_chunk``.

    Returns:
        A list of mutation row objects to be written by the caller.
    """
    if children_map is None:
        children_map = CHILDREN
    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map, layer=2)

    # Collect every cross-edge partner across all nodes so parent timestamps
    # can be fetched in a single bulk call instead of once per node.
    cx_edges_d = cg.get_atomic_cross_edges(nodes)
    all_cx_edges = [types.empty_2d]
    for _cx_edges_d in cx_edges_d.values():
        if _cx_edges_d:
            all_cx_edges.append(np.concatenate(list(_cx_edges_d.values())))
    all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
    timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))

    rows = []
    for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
        # Shift the end timestamp back 1ms so the stale marker written below
        # does not collide with a mutation at the exact end timestamp.
        # NOTE(review): presumably keeps this node's rows strictly before the
        # successor's validity window — confirm against the edit timestamps.
        end_ts -= timedelta(milliseconds=1)
        _cx_edges_d = cx_edges_d.get(node, {})
        if not _cx_edges_d:
            # Node has no cross-chunk edges; nothing to update.
            continue
        _rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
        # Mark the node's cross edges as refreshed up to end_ts.
        row_id = serializers.serialize_uint64(node)
        val_dict = {Hierarchy.StaleTimeStamp: 0}
        _rows.append(cg.client.mutate_row(row_id, val_dict, time_stamp=end_ts))
        rows.extend(_rows)
    return rows
8089
@@ -84,9 +93,7 @@ def _update_nodes_helper(args):
8493 return update_nodes (cg , nodes , nodes_ts )
8594
8695
87- def update_chunk (
88- cg : ChunkedGraph , chunk_coords : list [int ], layer : int = 2 , debug : bool = False
89- ):
96+ def update_chunk (cg : ChunkedGraph , chunk_coords : list [int ], debug : bool = False ):
9097 """
9198 Iterate over all L2 IDs in a chunk and update their cross chunk edges,
9299 within the periods they were valid/active.
@@ -95,35 +102,41 @@ def update_chunk(
95102
96103 start = time .time ()
97104 x , y , z = chunk_coords
98- chunk_id = cg .get_chunk_id (layer = layer , x = x , y = y , z = z )
105+ chunk_id = cg .get_chunk_id (layer = 2 , x = x , y = y , z = z )
99106 cg .copy_fake_edges (chunk_id )
100107 rr = cg .range_read_chunk (chunk_id )
101108
102109 nodes = []
103110 nodes_ts = []
104111 earliest_ts = cg .get_earliest_timestamp ()
105112 for k , v in rr .items ():
106- nodes .append (k )
107- CHILDREN [k ] = v [Hierarchy .Child ][0 ].value
108- ts = v [Hierarchy .Child ][0 ].timestamp
109- nodes_ts .append (earliest_ts if ts < earliest_ts else ts )
113+ try :
114+ _ = v [Hierarchy .Parent ]
115+ nodes .append (k )
116+ CHILDREN [k ] = v [Hierarchy .Child ][0 ].value
117+ ts = v [Hierarchy .Child ][0 ].timestamp
118+ nodes_ts .append (earliest_ts if ts < earliest_ts else ts )
119+ except KeyError :
120+ # invalid nodes from failed tasks w/o parent column entry
121+ continue
110122
111123 if len (nodes ) > 0 :
112- logging .info (f"Processing { len (nodes )} nodes." )
124+ logging .info (f"processing { len (nodes )} nodes." )
113125 assert len (CHILDREN ) > 0 , (nodes , CHILDREN )
114126 else :
115127 return
116128
117129 if debug :
118130 rows = update_nodes (cg , nodes , nodes_ts )
119131 else :
120- task_size = int (math .ceil (len (nodes ) / 64 ))
132+ task_size = int (math .ceil (len (nodes ) / 16 ))
121133 chunked_nodes = chunked (nodes , task_size )
122134 chunked_nodes_ts = chunked (nodes_ts , task_size )
123135 tasks = []
124136 for chunk , ts_chunk in zip (chunked_nodes , chunked_nodes_ts ):
125137 args = (cg , chunk , ts_chunk )
126138 tasks .append (args )
139+ logging .info (f"task size { task_size } , count { len (tasks )} ." )
127140
128141 rows = []
129142 with ThreadPoolExecutor (max_workers = 8 ) as executor :
@@ -132,4 +145,4 @@ def update_chunk(
132145 rows .extend (future .result ())
133146
134147 cg .client .write (rows )
135- print (f"total elaspsed time: { time .time () - start } " )
148+ logging . info (f"total elaspsed time: { time .time () - start } " )
0 commit comments