diff --git a/hsf/factory.py b/hsf/factory.py index 737436f..5a2468f 100644 --- a/hsf/factory.py +++ b/hsf/factory.py @@ -189,9 +189,10 @@ def main(cfg: DictConfig) -> None: bs = cfg.hardware.engine_settings.batch_size multispectral = 2 if cfg.multispectrality.pattern else 1 - assert multispectral * ( + if multispectral * ( tta + 1 - ) % bs == 0, "test_time_num_aug+1 must be a multiple of batch_size for deepsparse" + ) % bs != 0: + raise AssertionError("test_time_num_aug+1 must be a multiple of batch_size for deepsparse") mris = load_from_config(cfg.files.path, cfg.files.pattern) _n = len(mris)