diff --git a/elk/run.py b/elk/run.py index fb8903cc..64df03dd 100644 --- a/elk/run.py +++ b/elk/run.py @@ -97,6 +97,8 @@ def execute( ) devices = select_usable_devices(self.num_gpus, min_memory=self.min_gpu_mem) + if devices == ['mps']: + devices = ['cpu'] num_devices = len(devices) func: Callable[[int], dict[str, pd.DataFrame]] = partial( self.apply_to_layer, devices=devices, world_size=num_devices