diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 9f8c1ced..1bf1d535 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -108,21 +108,24 @@ def progress(loaded_modules: int, total_modules: int,
 
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
-        if gpu_count > 1:
-            gpu_split = kwargs.get("gpu_split")
+        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
 
-            if gpu_split:
-                self.gpu_split = gpu_split
-            else:
-                # Auto GPU split parameters
-                self.gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
-                autosplit_reserve_megabytes = unwrap(
-                    kwargs.get("autosplit_reserve"), [96]
-                )
-                self.autosplit_reserve = list(
-                    map(lambda value: value * 1024**2, autosplit_reserve_megabytes)
-                )
+        if gpu_count > 1 and gpu_split_auto:
+            # Auto GPU split parameters
+            self.gpu_split_auto = gpu_split_auto
+
+            autosplit_reserve_megabytes = unwrap(
+                kwargs.get("autosplit_reserve"), [96]
+            )
+            self.autosplit_reserve = list(
+                map(lambda value: value * 1024**2, autosplit_reserve_megabytes)
+            )
+        elif gpu_count > 1:
+            # Manual GPU split
+            self.gpu_split = kwargs.get("gpu_split")
+            self.gpu_split_auto = False
         else:
+            # One GPU setup
             self.gpu_split_auto = False
             logger.info("Disabling GPU split because one GPU is in use.")
 
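
For reference, the control flow after this patch reduces to three mutually exclusive cases: auto split (the default), manual split, and single GPU. Below is a minimal sketch of that logic as a standalone helper; `resolve_gpu_split` is a hypothetical name used only for illustration, and `unwrap(value, default)` is assumed to return `value` when it is not None and `default` otherwise, matching how the patch uses it:

    def unwrap(value, default):
        # Assumed semantics of the project's unwrap helper.
        return value if value is not None else default

    def resolve_gpu_split(gpu_count: int, **kwargs):
        """Return (gpu_split_auto, autosplit_reserve_bytes, gpu_split)."""
        gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)

        if gpu_count > 1 and gpu_split_auto:
            # Auto split: reserve per-GPU headroom, converted from MB to bytes.
            reserve_mb = unwrap(kwargs.get("autosplit_reserve"), [96])
            return True, [mb * 1024**2 for mb in reserve_mb], None
        elif gpu_count > 1:
            # Manual split: use the caller-supplied per-GPU allocation, if any.
            return False, None, kwargs.get("gpu_split")
        else:
            # One GPU: splitting is disabled.
            return False, None, None

Note that because `gpu_split_auto` defaults to True, a manual `gpu_split` now takes effect only when `gpu_split_auto` is explicitly set to False, whereas the old code preferred a supplied `gpu_split` unconditionally.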