# pylint: disable=invalid-name, missing-docstring, c-extension-no-member

+from collections import defaultdict
from concurrent.futures import ThreadPoolExecutor, as_completed
import logging, math, time
+from copy import copy

import fastremap
import numpy as np
from tqdm import tqdm
-from pychunkedgraph.graph import ChunkedGraph
+from pychunkedgraph.graph import ChunkedGraph, types
from pychunkedgraph.graph.attributes import Connectivity, Hierarchy
from pychunkedgraph.graph.utils import serializers
from pychunkedgraph.utils.general import chunked

-from .utils import exists_as_parent, get_end_timestamps, get_parent_timestamps
+from .utils import get_end_timestamps, get_parent_timestamps

CHILDREN = {}


def update_cross_edges(
-    cg: ChunkedGraph, node, cx_edges_d: dict, node_ts, node_end_ts, timestamps: set
+    cg: ChunkedGraph,
+    node,
+    cx_edges_d: dict,
+    node_ts,
+    node_end_ts,
+    timestamps_d: defaultdict[int, set],
) -> list:
    """
    Helper function to update a single L2 ID.
    Returns a list of mutations with given timestamps.
    """
    rows = []
    edges = np.concatenate(list(cx_edges_d.values()))
-    uparents = np.unique(cg.get_parents(edges[:, 0], time_stamp=node_ts))
-    assert uparents.size <= 1, f"{node}, {node_ts}, {uparents}"
-    if uparents.size == 0 or node != uparents[0]:
-        # if node is not the parent at this ts, it must be invalid
-        assert not exists_as_parent(cg, node, edges[:, 0])
-        return rows
+    partners = np.unique(edges[:, 1])

-    partner_parent_ts_d = get_parent_timestamps(cg, np.unique(edges[:, 1]))
-    for v in partner_parent_ts_d.values():
-        timestamps.update(v)
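+    # union the node's own parent timestamps with those of all its cross-edge
+    # partners; copy first so the shared entry in timestamps_d is not mutated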
+    timestamps = copy(timestamps_d[node])
+    for partner in partners:
+        timestamps.update(timestamps_d[partner])

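+    # replay each relevant timestamp within the node's validity window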
    for ts in sorted(timestamps):
        if ts < node_ts:
            continue
        if ts > node_end_ts:
            break
        val_dict = {}
-        svs = edges[:, 1]
-        parents = cg.get_parents(svs, time_stamp=ts)
-        edge_parents_d = dict(zip(svs, parents))
+
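+        # map each cross-edge partner to its parent at this timestamp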
+        parents = cg.get_parents(partners, time_stamp=ts)
+        edge_parents_d = dict(zip(partners, parents))
        for layer, layer_edges in cx_edges_d.items():
            layer_edges = fastremap.remap(
                layer_edges, edge_parents_d, preserve_missing_labels=True
@@ -62,19 +64,21 @@ def update_nodes(cg: ChunkedGraph, nodes, nodes_ts, children_map=None) -> list:
    if children_map is None:
        children_map = CHILDREN
    end_timestamps = get_end_timestamps(cg, nodes, nodes_ts, children_map)
-    timestamps_d = get_parent_timestamps(cg, nodes)
+
    cx_edges_d = cg.get_atomic_cross_edges(nodes)
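+    # collect every cross-edge partner so parent timestamps can be fetched for
+    # nodes and partners in a single batched lookup; types.empty_2d seeds the
+    # list so np.concatenate succeeds even when no node has cross edges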
+    all_cx_edges = [types.empty_2d]
+    for _cx_edges_d in cx_edges_d.values():
+        if _cx_edges_d:
+            all_cx_edges.append(np.concatenate(list(_cx_edges_d.values())))
+    all_partners = np.unique(np.concatenate(all_cx_edges)[:, 1])
+    timestamps_d = get_parent_timestamps(cg, np.concatenate([nodes, all_partners]))
+
    rows = []
    for node, node_ts, end_ts in zip(nodes, nodes_ts, end_timestamps):
-        if cg.get_parent(node) is None:
-            # invalid id caused by failed ingest task / edits
-            continue
        _cx_edges_d = cx_edges_d.get(node, {})
        if not _cx_edges_d:
            continue
-        _rows = update_cross_edges(
-            cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d[node]
-        )
+        _rows = update_cross_edges(cg, node, _cx_edges_d, node_ts, end_ts, timestamps_d)
        rows.extend(_rows)
    return rows

@@ -84,9 +88,7 @@ def _update_nodes_helper(args):
    return update_nodes(cg, nodes, nodes_ts)


-def update_chunk(
-    cg: ChunkedGraph, chunk_coords: list[int], layer: int = 2, debug: bool = False
-):
+def update_chunk(cg: ChunkedGraph, chunk_coords: list[int], debug: bool = False):
    """
    Iterate over all L2 IDs in a chunk and update their cross chunk edges,
    within the periods they were valid/active.
@@ -95,35 +97,40 @@ def update_chunk(

    start = time.time()
    x, y, z = chunk_coords
-    chunk_id = cg.get_chunk_id(layer=layer, x=x, y=y, z=z)
+    chunk_id = cg.get_chunk_id(layer=2, x=x, y=y, z=z)
    cg.copy_fake_edges(chunk_id)
    rr = cg.range_read_chunk(chunk_id)

    nodes = []
    nodes_ts = []
    earliest_ts = cg.get_earliest_timestamp()
    for k, v in rr.items():
-        nodes.append(k)
-        CHILDREN[k] = v[Hierarchy.Child][0].value
-        ts = v[Hierarchy.Child][0].timestamp
-        nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
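+        # skip invalid ids (no parent column) caused by failed ingest tasks / edits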
+        try:
+            _ = v[Hierarchy.Parent]
+            nodes.append(k)
+            CHILDREN[k] = v[Hierarchy.Child][0].value
+            ts = v[Hierarchy.Child][0].timestamp
+            nodes_ts.append(earliest_ts if ts < earliest_ts else ts)
+        except KeyError:
+            continue

    if len(nodes) > 0:
-        logging.info(f"Processing {len(nodes)} nodes.")
+        logging.info(f"processing {len(nodes)} nodes.")
        assert len(CHILDREN) > 0, (nodes, CHILDREN)
    else:
        return

    if debug:
        rows = update_nodes(cg, nodes, nodes_ts)
    else:
-        task_size = int(math.ceil(len(nodes) / 64))
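+        # split the nodes into roughly 16 batches for the thread pool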
+        task_size = int(math.ceil(len(nodes) / 16))
        chunked_nodes = chunked(nodes, task_size)
        chunked_nodes_ts = chunked(nodes_ts, task_size)
        tasks = []
        for chunk, ts_chunk in zip(chunked_nodes, chunked_nodes_ts):
            args = (cg, chunk, ts_chunk)
            tasks.append(args)
+        logging.info(f"task size {task_size}, count {len(tasks)}.")

    rows = []
    with ThreadPoolExecutor(max_workers=8) as executor:
@@ -132,4 +139,4 @@ def update_chunk(
            rows.extend(future.result())

    cg.client.write(rows)
-    print(f"total elaspsed time: {time.time() - start}")
+    logging.info(f"total elapsed time: {time.time() - start}")