diff --git a/backends/exllamav2/model.py b/backends/exllamav2/model.py
index 989c80a9..987da9ae 100644
--- a/backends/exllamav2/model.py
+++ b/backends/exllamav2/model.py
@@ -128,6 +128,7 @@ def progress(loaded_modules: int, total_modules: int,
         # Turn off GPU split if the user is using 1 GPU
         gpu_count = torch.cuda.device_count()
         gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
+        gpu_device_list = list(range(0, gpu_count))
 
         if gpu_count > 1 and gpu_split_auto:
             # Auto GPU split parameters
@@ -141,6 +142,12 @@ def progress(loaded_modules: int, total_modules: int,
             # Manual GPU split
             self.gpu_split = kwargs.get("gpu_split")
             self.gpu_split_auto = False
+
+            gpu_device_list = [
+                device_idx
+                for device_idx, memory in enumerate(self.gpu_split)
+                if memory > 0
+            ]
         else:
             # One GPU setup
             self.gpu_split_auto = False
@@ -185,6 +192,27 @@ def progress(loaded_modules: int, total_modules: int,
         # Enable fasttensors loading if present
         self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)
 
+        # Disable paged mode if the user's min GPU isn't supported (Ampere and up)
+        min_compute_capability = min(
+            torch.cuda.get_device_capability(device=device_idx)[0]
+            for device_idx in gpu_device_list
+        )
+
+        # Compute capability < 8 is not supported by FA2
+        # AMD is also unsupported until ROCm updates its FA2 fork
+        if torch.version.hip or min_compute_capability < 8:
+            logger.warning(
+                "An unsupported GPU was found in this configuration. "
+                "Switching to compatibility mode. \n"
+                "This disables parallel batching "
+                "and features that rely on it (ex. CFG). \n"
+                "To disable compatibility mode, all GPUs must be Ampere "
+                "(30 series) or newer. AMD GPUs are not supported."
+            )
+            self.config.no_flash_attn = True
+            self.paged = False
+            self.max_batch_size = 1
+
         # Try to set prompt template
         self.prompt_template = self.find_prompt_template(
             kwargs.get("prompt_template"), model_directory
@@ -397,31 +425,6 @@ async def load_gen(self, progress_callback=None, **kwargs):
             async for value in iterate_in_threadpool(model_load_generator):
                 yield value
 
-        # Disable paged mode if the user's min GPU isn't supported (ampere and up)
-        device_list = {
-            module.device_idx
-            for module in self.model.modules
-            if module.device_idx >= 0
-        }
-        min_compute_capability = min(
-            torch.cuda.get_device_capability(device=device)[0]
-            for device in device_list
-        )
-
-        # Compute capability < 8 is not supported by FA2
-        # AMD is also unsupported until ROCm updates its FA2 fork
-        if torch.version.hip or min_compute_capability < 8:
-            logger.warning(
-                "An unsupported GPU is found in this configuration. "
-                "Switching to compatibility mode. \n"
-                "This disables parallel batching "
-                "and features that rely on it (ex. CFG). \n"
-                "To disable compatability mode, all GPUs must be ampere "
-                "(30 series) or newer. AMD GPUs are not supported."
-            )
-            self.paged = False
-            self.max_batch_size = 1
-
         # Create async generator
         self.generator = ExLlamaV2DynamicGeneratorAsync(
             model=self.model,
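
Reviewer note on the first two hunks: with a manual gpu_split, only devices
given a nonzero memory reservation take part in the later capability check;
otherwise every visible device is checked. A small sketch with hypothetical
split values (the numbers are illustrative, not part of this patch):

    # Hypothetical manual split: leave GPU 0 free, reserve ~20 GB on GPUs 1 and 2
    gpu_split = [0, 20.5, 20.5]
    gpu_device_list = [
        device_idx
        for device_idx, memory in enumerate(gpu_split)
        if memory > 0
    ]
    # gpu_device_list == [1, 2]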
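
Reviewer note on the third and fourth hunks: the compute-capability check now
runs at configure time against gpu_device_list (derived from the split
settings) instead of walking self.model.modules after load, so no_flash_attn
can be set on the config before any weights are placed. A minimal standalone
sketch of the same gate, assuming a stock PyTorch build; the function name and
the empty-list guard are illustrative, not part of this patch:

    import torch

    def paged_mode_supported(gpu_device_list: list[int]) -> bool:
        # ROCm's FA2 fork does not support paged attention yet
        if torch.version.hip:
            return False
        if not gpu_device_list:
            return False  # no usable devices; stay in compatibility mode
        # get_device_capability() returns (major, minor); FA2 needs
        # compute capability 8.x or higher (Ampere and newer)
        min_major = min(
            torch.cuda.get_device_capability(device=idx)[0]
            for idx in gpu_device_list
        )
        return min_major >= 8

When the gate fails, the loader falls back exactly as in the hunk above:
config.no_flash_attn = True, self.paged = False, self.max_batch_size = 1.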