diff --git a/nnunetv2/inference/predict_from_raw_data.py b/nnunetv2/inference/predict_from_raw_data.py index aa2fceea0..0d6aa327e 100644 --- a/nnunetv2/inference/predict_from_raw_data.py +++ b/nnunetv2/inference/predict_from_raw_data.py @@ -882,6 +882,9 @@ def predict_entry_point(): help="Use this to set the device the inference should run with. Available options are 'cuda' " "(GPU), 'cpu' (CPU) and 'mps' (Apple M1/M2). Do NOT use this to set which GPU ID! " "Use CUDA_VISIBLE_DEVICES=X nnUNetv2_predict [...] instead!") + parser.add_argument('--not_on_device', action='store_true', required=False, default=False, + help="Set this flag to not keep the entire case on device. Recommended for large cases that " + "occupy more VRAM than available") parser.add_argument('--disable_progress_bar', action='store_true', required=False, default=False, help='Set this flag to disable progress bar. Recommended for HPC environments (non interactive ' 'jobs)') @@ -922,7 +925,7 @@ def predict_entry_point(): predictor = nnUNetPredictor(tile_step_size=args.step_size, use_gaussian=True, use_mirroring=not args.disable_tta, - perform_everything_on_device=True, + perform_everything_on_device=not args.not_on_device, device=device, verbose=args.verbose, verbose_preprocessing=args.verbose,