55from pathos .multiprocessing import ThreadPool as Pool
66from tqdm import tqdm
77
8+ from cellseg_models_pytorch .inference import BaseInferer
9+
810from ..metrics import (
911 accuracy_multiclass ,
1012 aggregated_jaccard_index ,
4345
4446class BenchMarker :
4547 def __init__ (
46- self , pred_dir : str , true_dir : str , classes : Dict [str , int ] = None
48+ self ,
49+ true_dir : str ,
50+ pred_dir : str = None ,
51+ inferer : BaseInferer = None ,
52+ type_classes : Dict [str , int ] = None ,
53+ sem_classes : Dict [str , int ] = None ,
4754 ) -> None :
4855 """Run benchmarking, given prediction and ground truth mask folders.
4956
57+ NOTE: Can also take in an Inferer object.
58+
5059 Parameters
5160 ----------
52- pred_dir : str
53- Path to the prediction .mat files. The pred files have to have matching
54- names to the gt filenames.
5561 true_dir : str
5662 Path to the ground truth .mat files. The gt files have to have matching
5763 names to the pred filenames.
58- classes : Dict[str, int], optional
59- Class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2}
64+ pred_dir : str, optional
65+ Path to the prediction .mat files. The pred files have to have matching
66+ names to the gt filenames. If None, the inferer object storing the
67+ predictions will be used instead.
68+ inferer : BaseInferer, optional
69+ Infere object storing predictions of a model. If None, the `pred_dir`
70+ will be used to load the predictions instead.
71+ type_classes : Dict[str, int], optional
72+ Cell type class dict. E.g. {"bg": 0, "epithelial": 1, "immmune": 2}
73+ sem_classes : Dict[str, int], optional
74+ Tissue type class dict. E.g. {"bg": 0, "epithel": 1, "stroma": 2}
6075 """
61- self .pred_dir = Path (pred_dir )
76+ if pred_dir is None and inferer is None :
77+ raise ValueError (
78+ "Both `inferer` and `pred_dir` cannot be set to None at the same time."
79+ )
80+
6281 self .true_dir = Path (true_dir )
63- self .classes = classes
82+ self .type_classes = type_classes
83+ self .sem_classes = sem_classes
84+
85+ if pred_dir is not None :
86+ self .pred_dir = Path (pred_dir )
87+ else :
88+ self .pred_dir = None
89+
90+ self .inferer = inferer
91+
92+ if inferer is not None and pred_dir is None :
93+ try :
94+ self .inferer .out_masks
95+ self .inferer .soft_masks
96+ except AttributeError :
97+ raise AttributeError (
98+ "Did not find `out_masks` or `soft_masks` attributes. "
99+ "To get these, run inference with `inferer.infer()`. "
100+ "Remember to set `save_intermediate` to True for the inferer.`"
101+ )
64102
65103 @staticmethod
66104 def compute_inst_metrics (
@@ -100,16 +138,16 @@ def compute_inst_metrics(
100138 f"An illegal metric was given. Got: { metrics } , allowed: { allowed } "
101139 )
102140
103- # Skip empty GTs
104- if len (np .unique (true )) > 1 :
141+ # Skip the metrics computation when neither mask contains any instances
142+ res = {}
143+ if len (np .unique (true )) > 1 or len (np .unique (pred )) > 1 :
105144 true = remap_label (true )
106145 pred = remap_label (pred )
107146
108147 met = {}
109148 for m in metrics :
110149 met [m ] = INST_METRIC_LOOKUP [m ]
111150
112- res = {}
113151 for k , m in met .items ():
114152 score = m (true , pred )
115153
@@ -121,8 +159,19 @@ def compute_inst_metrics(
121159
122160 res ["name" ] = name
123161 res ["type" ] = type
162+ else :
163+ res ["name" ] = name
164+ res ["type" ] = type
124165
125- return res
166+ for m in metrics :
167+ if m == "pq" :
168+ res ["pq" ] = - 1.0
169+ res ["sq" ] = - 1.0
170+ res ["dq" ] = - 1.0
171+ else :
172+ res [m ] = - 1.0
173+
174+ return res
126175
127176 @staticmethod
128177 def compute_sem_metrics (
@@ -158,6 +207,9 @@ def compute_sem_metrics(
158207 A dictionary where metric names are mapped to metric values.
159208 e.g. {"iou": 0.5, "f1score": 0.55, "name": "sample1"}
160209 """
210+ if not isinstance (metrics , tuple ) and not isinstance (metrics , list ):
211+ raise ValueError ("`metrics` must be either a list or tuple of values." )
212+
161213 allowed = list (SEM_METRIC_LOOKUP .keys ())
162214 if not all ([m in allowed for m in metrics ]):
163215 raise ValueError (
@@ -227,20 +279,6 @@ def run_metrics(
227279
228280 return metrics
229281
230- def _read_files (self ) -> List [Tuple [np .ndarray , np .ndarray , str ]]:
231- """Read in the files from the input folders."""
232- preds = sorted (self .pred_dir .glob ("*" ))
233- trues = sorted (self .true_dir .glob ("*" ))
234-
235- masks = []
236- for truef , predf in zip (trues , preds ):
237- true = FileHandler .read_mat (truef , return_all = True )
238- pred = FileHandler .read_mat (predf , return_all = True )
239- name = truef .name
240- masks .append ((true , pred , name ))
241-
242- return masks
243-
244282 def run_inst_benchmark (
245283 self , how : str = "binary" , metrics : Tuple [str , ...] = ("pq" ,)
246284 ) -> None :
@@ -268,17 +306,32 @@ def run_inst_benchmark(
268306 if how not in allowed :
269307 raise ValueError (f"Illegal arg `how`. Got: { how } , Allowed: { allowed } " )
270308
271- masks = self ._read_files ()
309+ trues = sorted (self .true_dir .glob ("*" ))
310+
311+ preds = None
312+ if self .pred_dir is not None :
313+ preds = sorted (self .pred_dir .glob ("*" ))
314+
315+ ik = "inst" if self .pred_dir is None else "inst_map"
316+ tk = "type" if self .pred_dir is None else "type_map"
272317
273318 res = []
274- if how == "multi" and self .classes is not None :
275- for c , i in list (self .classes .items ())[1 :]:
319+ if how == "multi" and self .type_classes is not None :
320+ for c , i in list (self .type_classes .items ())[1 :]:
276321 args = []
277- for true , pred , name in masks :
322+ for j , true_fn in enumerate (trues ):
323+ name = true_fn .name
324+ true = FileHandler .read_mat (true_fn , return_all = True )
325+
326+ if preds is None :
327+ pred = self .inferer .out_masks [name [:- 4 ]]
328+ else :
329+ pred = FileHandler .read_mat (preds [j ], return_all = True )
330+
278331 true_inst = true ["inst_map" ]
279- pred_inst = pred ["inst_map" ]
280332 true_type = true ["type_map" ]
281- pred_type = pred ["type_map" ]
333+ pred_inst = pred [ik ]
334+ pred_type = pred [tk ]
282335
283336 pred_type = get_type_instances (pred_inst , pred_type , i )
284337 true_type = get_type_instances (true_inst , true_type , i )
@@ -287,9 +340,17 @@ def run_inst_benchmark(
287340 res .extend ([metric for metric in met if metric ])
288341 else :
289342 args = []
290- for true , pred , name in masks :
343+ for i , true_fn in enumerate (trues ):
344+ name = true_fn .name
345+ true = FileHandler .read_mat (true_fn , return_all = True )
346+
347+ if preds is None :
348+ pred = self .inferer .out_masks [name [:- 4 ]]
349+ else :
350+ pred = FileHandler .read_mat (preds [i ], return_all = True )
351+
291352 true = true ["inst_map" ]
292- pred = pred ["inst_map" ]
353+ pred = pred [ik ]
293354 args .append ((true , pred , name , metrics ))
294355 met = self .run_metrics (args , "inst" , "binary instance seg" )
295356 res .extend ([metric for metric in met if metric ])
@@ -310,14 +371,40 @@ def run_sem_benchmark(self, metrics: Tuple[str, ...] = ("iou",)) -> Dict[str, An
310371 Dict[str, Any]:
311372 Dictionary mapping the metrics to values + metadata.
312373 """
313- masks = self ._read_files ()
374+ trues = sorted (self .true_dir .glob ("*" ))
375+
376+ preds = None
377+ if self .pred_dir is not None :
378+ preds = sorted (self .pred_dir .glob ("*" ))
379+
380+ sk = "sem" if self .pred_dir is None else "sem_map"
314381
315382 args = []
316- for true , pred , name in masks :
383+ for i , true_fn in enumerate (trues ):
384+ name = true_fn .name
385+ true = FileHandler .read_mat (true_fn , return_all = True )
386+
387+ if preds is None :
388+ pred = self .inferer .out_masks [name [:- 4 ]]
389+ else :
390+ pred = FileHandler .read_mat (preds [i ], return_all = True )
317391 true = true ["sem_map" ]
318- pred = pred ["sem_map" ]
319- args .append ((true , pred , name , len (self .classes ), metrics ))
392+ pred = pred [sk ]
393+ args .append ((true , pred , name , len (self .sem_classes ), metrics ))
394+
320395 met = self .run_metrics (args , "sem" , "semantic seg" )
321- res = [metric for metric in met if metric ]
396+ ires = [metric for metric in met if metric ]
397+
398+ # re-format
399+ res = []
400+ for r in ires :
401+ for k , val in self .sem_classes .items ():
402+ cc = {
403+ "name" : r ["name" ],
404+ "type" : k ,
405+ }
406+ for m in metrics :
407+ cc [m ] = r [m ][val ]
408+ res .append (cc )
322409
323410 return res
0 commit comments