diff --git a/sdcflows/workflows/apply/correction.py b/sdcflows/workflows/apply/correction.py
index 2ec7de8555..728b026af1 100644
--- a/sdcflows/workflows/apply/correction.py
+++ b/sdcflows/workflows/apply/correction.py
@@ -26,7 +26,7 @@
 from niworkflows.engine.workflows import LiterateWorkflow as Workflow
 
 
-def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
+def init_unwarp_wf(*, free_mem=None, omp_nthreads=1, debug=False, name="unwarp_wf"):
     r"""
     Set up a workflow that unwarps the input :abbr:`EPI (echo-planar imaging)` dataset.
 
@@ -103,9 +103,24 @@ def init_unwarp_wf(omp_nthreads=1, debug=False, name="unwarp_wf"):
 
     rotime = pe.Node(GetReadoutTime(), name="rotime")
     rotime.interface._always_run = debug
-    resample = pe.Node(ApplyCoeffsField(
-        num_threads=omp_nthreads if not debug else 1
-    ), name="resample")
+
+    # resample is memory-hungry; choose a smaller number of threads
+    # if we know how much memory we have to work with
+    mem_per_thread = 5  # True for a 128x128x84 image; should generalize
+    if debug:
+        num_threads = 1
+    elif free_mem is not None:
+        mem_gb = min(0.9 * free_mem, mem_per_thread * omp_nthreads)
+        num_threads = max(int(mem_gb // mem_per_thread), 1)
+    else:
+        num_threads = omp_nthreads
+
+    resample = pe.Node(
+        ApplyCoeffsField(num_threads=num_threads),
+        mem_gb=mem_per_thread * num_threads,
+        name="resample",
+    )
+
     merge = pe.Node(MergeSeries(), name="merge")
     average = pe.Node(RobustAverage(mc_method=None), name="average")
 
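
A quick way to sanity-check the new heuristic outside of Nipype is to replicate it in a small script. The sketch below is not part of the patch; it assumes `psutil` is available to estimate free memory, and it reuses the 5 GB-per-thread figure from the in-diff comment (stated there as valid for a 128x128x84 image).

    # Sketch only: mirrors the thread-capping logic added to init_unwarp_wf.
    import psutil

    MEM_PER_THREAD_GB = 5  # per the patch comment; tuned for a 128x128x84 image

    def pick_num_threads(omp_nthreads, free_mem=None, debug=False):
        """Return the thread count the patched workflow would request."""
        if debug:
            return 1
        if free_mem is None:
            return omp_nthreads
        # Budget 90% of free memory, capped at what omp_nthreads could consume
        mem_gb = min(0.9 * free_mem, MEM_PER_THREAD_GB * omp_nthreads)
        return max(int(mem_gb // MEM_PER_THREAD_GB), 1)

    free_gb = psutil.virtual_memory().available / (1024 ** 3)  # bytes -> GiB
    print(pick_num_threads(omp_nthreads=8, free_mem=free_gb))

For example, with 8 CPUs and roughly 20 GB free, this yields 3 threads (int(0.9 * 20 // 5)), so the resample node stays within the memory budget instead of spawning one thread per CPU.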