Model: Fix paged and FA2 checks
If a user is using a manual GPU split, check compute capability on only those
GPUs. Autosplit assumes that all GPUs will be used.

Signed-off-by: kingbri <[email protected]>
bdashore3 committed May 26, 2024
1 parent 9fbbc5a commit 094c7b1
Showing 1 changed file with 28 additions and 25 deletions.
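
Before the diff, a minimal sketch of the check as the commit restructures it, pulled out of the class for readability. The function names here are invented for illustration; the nonzero gpu_split filter, the all-GPU fallback for autosplit, the ROCm guard, and the compute-capability-8 threshold follow the diff below.

    import torch

    def selected_devices(gpu_split, gpu_count):
        # Manual split: only GPUs that were actually assigned memory matter.
        # Autosplit (no explicit split): every visible GPU may be used.
        if gpu_split:
            return [idx for idx, mem in enumerate(gpu_split) if mem > 0]
        return list(range(gpu_count))

    def supports_paged_fa2(device_list):
        # FA2 (and therefore paged mode) needs Ampere or newer, i.e. a compute
        # capability major version of 8+. ROCm builds are excluded outright.
        if torch.version.hip:
            return False
        min_major = min(
            torch.cuda.get_device_capability(device=idx)[0] for idx in device_list
        )
        return min_major >= 8

Deriving the device list from the split configuration rather than from the loaded model also lets no_flash_attn be set on the config before the weights load, which the removed post-load check (bottom hunk) did not do.
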
53 changes: 28 additions & 25 deletions backends/exllamav2/model.py
@@ -128,6 +128,7 @@ def progress(loaded_modules: int, total_modules: int,
# Turn off GPU split if the user is using 1 GPU
gpu_count = torch.cuda.device_count()
gpu_split_auto = unwrap(kwargs.get("gpu_split_auto"), True)
gpu_device_list = list(range(0, gpu_count))

if gpu_count > 1 and gpu_split_auto:
# Auto GPU split parameters
@@ -141,6 +142,12 @@ def progress(loaded_modules: int, total_modules: int,
# Manual GPU split
self.gpu_split = kwargs.get("gpu_split")
self.gpu_split_auto = False

gpu_device_list = [
device_idx
for device_idx, memory in enumerate(self.gpu_split)
if memory > 0
]
else:
# One GPU setup
self.gpu_split_auto = False
@@ -185,6 +192,27 @@ def progress(loaded_modules: int, total_modules: int,
# Enable fasttensors loading if present
self.config.fasttensors = unwrap(kwargs.get("fasttensors"), False)

# Disable paged mode if the user's min GPU isn't supported (ampere and up)
min_compute_capability = min(
torch.cuda.get_device_capability(device=device_idx)[0]
for device_idx in gpu_device_list
)

# Compute capability < 8 is not supported by FA2
# AMD is also unsupported until ROCm updates its FA2 fork
if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
self.config.no_flash_attn = True
self.paged = False
self.max_batch_size = 1

# Try to set prompt template
self.prompt_template = self.find_prompt_template(
kwargs.get("prompt_template"), model_directory
@@ -397,31 +425,6 @@ async def load_gen(self, progress_callback=None, **kwargs):
async for value in iterate_in_threadpool(model_load_generator):
yield value

# Disable paged mode if the user's min GPU isn't supported (ampere and up)
device_list = {
module.device_idx
for module in self.model.modules
if module.device_idx >= 0
}
min_compute_capability = min(
torch.cuda.get_device_capability(device=device)[0]
for device in device_list
)

# Compute capability < 8 is not supported by FA2
# AMD is also unsupported until ROCm updates its FA2 fork
if torch.version.hip or min_compute_capability < 8:
logger.warning(
"An unsupported GPU is found in this configuration. "
"Switching to compatibility mode. \n"
"This disables parallel batching "
"and features that rely on it (ex. CFG). \n"
"To disable compatability mode, all GPUs must be ampere "
"(30 series) or newer. AMD GPUs are not supported."
)
self.paged = False
self.max_batch_size = 1

# Create async generator
self.generator = ExLlamaV2DynamicGeneratorAsync(
model=self.model,
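
To make the commit message concrete (compute-capability values invented for illustration): on a three-GPU host whose device 0 is pre-Ampere, the check now depends on how the model is split.

    capability = {0: 7, 1: 8, 2: 8}   # invented compute capability majors

    # Autosplit assumes all GPUs may be used, so the weakest card decides:
    min(capability[idx] for idx in range(3))                         # 7 -> compatibility mode

    # A manual split of [0, 20, 20] excludes device 0, so only 1 and 2 are probed:
    used = [idx for idx, mem in enumerate([0, 20, 20]) if mem > 0]   # [1, 2]
    min(capability[idx] for idx in used)                             # 8 -> paged/FA2 stay enabled
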
