Skip to content

Commit e8519a6

Browse files
committed
added split batch option that relies on cell-eval having the split batch option
1 parent e01fd45 commit e8519a6

File tree

2 files changed

+95
-20
lines changed

2 files changed

+95
-20
lines changed

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ dependencies = [
2727
"geomloss>=0.2.6",
2828
"transformers>=4.52.3",
2929
"peft>=0.11.0",
30-
"cell-eval>=0.5.45",
30+
"cell-eval>=0.5.46",
3131
"ipykernel>=6.30.1",
3232
]
3333

src/state/_cli/_tx/_predict.py

Lines changed: 94 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -40,6 +40,12 @@ def add_arguments_predict(parser: ap.ArgumentParser):
4040
help="If set, only run prediction without evaluation metrics.",
4141
)
4242

43+
parser.add_argument(
44+
"--split-batch",
45+
action="store_true",
46+
help="If set, compute metrics separately for each (cell type, batch) pair.",
47+
)
48+
4349
parser.add_argument(
4450
"--shared-only",
4551
action="store_true",
@@ -67,7 +73,7 @@ def run_tx_predict(args: ap.ArgumentParser):
6773

6874
# Cell-eval for metrics computation
6975
from cell_eval import MetricsEvaluator
70-
from cell_eval.utils import split_anndata_on_celltype
76+
from cell_eval.utils import build_celltype_split_specs
7177
from cell_load.data_modules import PerturbationDataModule
7278
from tqdm import tqdm
7379

@@ -288,17 +294,70 @@ def load_config(cfg_path: str) -> dict:
288294
else:
289295
all_celltypes.append(batch_preds["celltype_name"])
290296

291-
# Handle gem_group
292-
if isinstance(batch_preds["batch"], list):
293-
all_gem_groups.extend([str(x) for x in batch_preds["batch"]])
294-
elif isinstance(batch_preds["batch"], torch.Tensor):
295-
all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()])
296-
else:
297-
all_gem_groups.append(str(batch_preds["batch"]))
297+
batch_size = batch_preds["preds"].shape[0]
298+
299+
# Handle gem_group - prefer human-readable batch names when available
300+
def normalize_batch_labels(values):
301+
if values is None:
302+
return None
303+
if isinstance(values, torch.Tensor):
304+
values = values.detach().cpu().numpy()
305+
if isinstance(values, np.ndarray):
306+
if values.ndim == 2:
307+
if values.shape[0] != batch_size:
308+
return None
309+
if values.shape[1] == 1:
310+
flat = values.reshape(batch_size)
311+
return [str(x) for x in flat.tolist()]
312+
indices = values.argmax(axis=1)
313+
return [str(int(x)) for x in indices.tolist()]
314+
if values.ndim == 1:
315+
if values.shape[0] != batch_size:
316+
return None
317+
return [str(x) for x in values.tolist()]
318+
if values.ndim == 0:
319+
return [str(values.item())] * batch_size
320+
return None
321+
if isinstance(values, (list, tuple)):
322+
if len(values) != batch_size:
323+
return None
324+
normalized = []
325+
for item in values:
326+
if isinstance(item, torch.Tensor):
327+
item = item.detach().cpu().numpy()
328+
if isinstance(item, np.ndarray):
329+
if item.ndim == 0:
330+
normalized.append(str(item.item()))
331+
continue
332+
if item.ndim == 1:
333+
if item.size == 1:
334+
normalized.append(str(item.item()))
335+
elif np.count_nonzero(item) == 1:
336+
normalized.append(str(int(item.argmax())))
337+
else:
338+
normalized.append(str(item.tolist()))
339+
continue
340+
normalized.append(str(item))
341+
return normalized
342+
return [str(values)] * batch_size
343+
344+
batch_name_candidates = (
345+
batch.get("batch_name"),
346+
batch_preds.get("batch_name"),
347+
batch_preds.get("batch"),
348+
)
349+
350+
batch_labels = None
351+
for candidate in batch_name_candidates:
352+
batch_labels = normalize_batch_labels(candidate)
353+
if batch_labels is not None:
354+
break
355+
if batch_labels is None:
356+
batch_labels = ["None"] * batch_size
357+
all_gem_groups.extend(batch_labels)
298358

299359
batch_pred_np = batch_preds["preds"].cpu().numpy().astype(np.float32)
300360
batch_real_np = batch_preds["pert_cell_emb"].cpu().numpy().astype(np.float32)
301-
batch_size = batch_pred_np.shape[0]
302361
final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np
303362
final_reals[current_idx : current_idx + batch_size, :] = batch_real_np
304363
current_idx += batch_size
@@ -408,25 +467,41 @@ def load_config(cfg_path: str) -> dict:
408467

409468
control_pert = data_module.get_control_pert()
410469

411-
ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key)
412-
ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key)
470+
batch_key = data_module.batch_col if args.split_batch else None
471+
if args.split_batch:
472+
if not batch_key:
473+
raise ValueError("--split-batch requested but no batch column is configured on the data module.")
474+
logger.info(
475+
"Splitting evaluation by cell type and batch column '%s'", batch_key
476+
)
413477

414-
assert len(ct_split_real) == len(ct_split_pred), (
415-
f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}"
478+
split_specs = build_celltype_split_specs(
479+
real=adata_real,
480+
pred=adata_pred,
481+
celltype_col=data_module.cell_type_key,
482+
batch_key=batch_key,
416483
)
417484

418485
pdex_kwargs = dict(exp_post_agg=True, is_log1p=True)
419-
for ct in ct_split_real.keys():
420-
real_ct = ct_split_real[ct]
421-
pred_ct = ct_split_pred[ct]
486+
for split in split_specs:
487+
batch_suffix = (
488+
f", batch={split.batch}"
489+
if split.batch is not None and not pd.isna(split.batch)
490+
else ""
491+
)
492+
logger.info(
493+
"Evaluating metrics for celltype=%s%s",
494+
split.celltype,
495+
batch_suffix,
496+
)
422497

423498
evaluator = MetricsEvaluator(
424-
adata_pred=pred_ct,
425-
adata_real=real_ct,
499+
adata_pred=split.pred,
500+
adata_real=split.real,
426501
control_pert=control_pert,
427502
pert_col=data_module.pert_col,
428503
outdir=results_dir,
429-
prefix=ct,
504+
prefix=split.label,
430505
pdex_kwargs=pdex_kwargs,
431506
batch_size=2048,
432507
)

0 commit comments

Comments
 (0)