diff --git a/geo_inference/geo_inference.py b/geo_inference/geo_inference.py index 0671a3b..50946e9 100644 --- a/geo_inference/geo_inference.py +++ b/geo_inference/geo_inference.py @@ -2,6 +2,7 @@ import gc import re import sys +import platform import time import torch import pystac @@ -182,8 +183,13 @@ async def async_run_inference(self, """ - # configuring dask - num_workers = len(os.sched_getaffinity(0)) - 1 if workers == 0 else workers + # configuring dask with proper number of workers, alternatively we could also use os.getenv('SLURM_CPUS_PER_TASK') + if workers != 0: + num_workers = workers + elif 'linux' in platform.uname().system.lower(): + num_workers = len(os.sched_getaffinity(0)) - 1 + else: + num_workers = os.cpu_count() - 1 print(f"running dask with {num_workers} workers") config.set(scheduler='threads', num_workers=num_workers) config.set(pool=ThreadPool(num_workers))