@@ -40,6 +40,12 @@ def add_arguments_predict(parser: ap.ArgumentParser):
4040 help = "If set, only run prediction without evaluation metrics." ,
4141 )
4242
43+ parser .add_argument (
44+ "--split-batch" ,
45+ action = "store_true" ,
46+ help = "If set, compute metrics separately for each (cell type, batch) pair." ,
47+ )
48+
4349 parser .add_argument (
4450 "--shared-only" ,
4551 action = "store_true" ,
@@ -67,7 +73,7 @@ def run_tx_predict(args: ap.ArgumentParser):
6773
6874 # Cell-eval for metrics computation
6975 from cell_eval import MetricsEvaluator
70- from cell_eval .utils import split_anndata_on_celltype
76+ from cell_eval .utils import build_celltype_split_specs
7177 from cell_load .data_modules import PerturbationDataModule
7278 from tqdm import tqdm
7379
@@ -288,17 +294,70 @@ def load_config(cfg_path: str) -> dict:
288294 else :
289295 all_celltypes .append (batch_preds ["celltype_name" ])
290296
291- # Handle gem_group
292- if isinstance (batch_preds ["batch" ], list ):
293- all_gem_groups .extend ([str (x ) for x in batch_preds ["batch" ]])
294- elif isinstance (batch_preds ["batch" ], torch .Tensor ):
295- all_gem_groups .extend ([str (x ) for x in batch_preds ["batch" ].cpu ().numpy ()])
296- else :
297- all_gem_groups .append (str (batch_preds ["batch" ]))
297+ batch_size = batch_preds ["preds" ].shape [0 ]
298+
299+ # Handle gem_group - prefer human-readable batch names when available
300+ def normalize_batch_labels (values ):
301+ if values is None :
302+ return None
303+ if isinstance (values , torch .Tensor ):
304+ values = values .detach ().cpu ().numpy ()
305+ if isinstance (values , np .ndarray ):
306+ if values .ndim == 2 :
307+ if values .shape [0 ] != batch_size :
308+ return None
309+ if values .shape [1 ] == 1 :
310+ flat = values .reshape (batch_size )
311+ return [str (x ) for x in flat .tolist ()]
312+ indices = values .argmax (axis = 1 )
313+ return [str (int (x )) for x in indices .tolist ()]
314+ if values .ndim == 1 :
315+ if values .shape [0 ] != batch_size :
316+ return None
317+ return [str (x ) for x in values .tolist ()]
318+ if values .ndim == 0 :
319+ return [str (values .item ())] * batch_size
320+ return None
321+ if isinstance (values , (list , tuple )):
322+ if len (values ) != batch_size :
323+ return None
324+ normalized = []
325+ for item in values :
326+ if isinstance (item , torch .Tensor ):
327+ item = item .detach ().cpu ().numpy ()
328+ if isinstance (item , np .ndarray ):
329+ if item .ndim == 0 :
330+ normalized .append (str (item .item ()))
331+ continue
332+ if item .ndim == 1 :
333+ if item .size == 1 :
334+ normalized .append (str (item .item ()))
335+ elif np .count_nonzero (item ) == 1 :
336+ normalized .append (str (int (item .argmax ())))
337+ else :
338+ normalized .append (str (item .tolist ()))
339+ continue
340+ normalized .append (str (item ))
341+ return normalized
342+ return [str (values )] * batch_size
343+
344+ batch_name_candidates = (
345+ batch .get ("batch_name" ),
346+ batch_preds .get ("batch_name" ),
347+ batch_preds .get ("batch" ),
348+ )
349+
350+ batch_labels = None
351+ for candidate in batch_name_candidates :
352+ batch_labels = normalize_batch_labels (candidate )
353+ if batch_labels is not None :
354+ break
355+ if batch_labels is None :
356+ batch_labels = ["None" ] * batch_size
357+ all_gem_groups .extend (batch_labels )
298358
299359 batch_pred_np = batch_preds ["preds" ].cpu ().numpy ().astype (np .float32 )
300360 batch_real_np = batch_preds ["pert_cell_emb" ].cpu ().numpy ().astype (np .float32 )
301- batch_size = batch_pred_np .shape [0 ]
302361 final_preds [current_idx : current_idx + batch_size , :] = batch_pred_np
303362 final_reals [current_idx : current_idx + batch_size , :] = batch_real_np
304363 current_idx += batch_size
@@ -408,25 +467,41 @@ def load_config(cfg_path: str) -> dict:
408467
409468 control_pert = data_module .get_control_pert ()
410469
411- ct_split_real = split_anndata_on_celltype (adata = adata_real , celltype_col = data_module .cell_type_key )
412- ct_split_pred = split_anndata_on_celltype (adata = adata_pred , celltype_col = data_module .cell_type_key )
470+ batch_key = data_module .batch_col if args .split_batch else None
471+ if args .split_batch :
472+ if not batch_key :
473+ raise ValueError ("--split-batch requested but no batch column is configured on the data module." )
474+ logger .info (
475+ "Splitting evaluation by cell type and batch column '%s'" , batch_key
476+ )
413477
414- assert len (ct_split_real ) == len (ct_split_pred ), (
415- f"Number of celltypes in real and pred anndata must match: { len (ct_split_real )} != { len (ct_split_pred )} "
478+ split_specs = build_celltype_split_specs (
479+ real = adata_real ,
480+ pred = adata_pred ,
481+ celltype_col = data_module .cell_type_key ,
482+ batch_key = batch_key ,
416483 )
417484
418485 pdex_kwargs = dict (exp_post_agg = True , is_log1p = True )
419- for ct in ct_split_real .keys ():
420- real_ct = ct_split_real [ct ]
421- pred_ct = ct_split_pred [ct ]
486+ for split in split_specs :
487+ batch_suffix = (
488+ f", batch={ split .batch } "
489+ if split .batch is not None and not pd .isna (split .batch )
490+ else ""
491+ )
492+ logger .info (
493+ "Evaluating metrics for celltype=%s%s" ,
494+ split .celltype ,
495+ batch_suffix ,
496+ )
422497
423498 evaluator = MetricsEvaluator (
424- adata_pred = pred_ct ,
425- adata_real = real_ct ,
499+ adata_pred = split . pred ,
500+ adata_real = split . real ,
426501 control_pert = control_pert ,
427502 pert_col = data_module .pert_col ,
428503 outdir = results_dir ,
429- prefix = ct ,
504+ prefix = split . label ,
430505 pdex_kwargs = pdex_kwargs ,
431506 batch_size = 2048 ,
432507 )
0 commit comments