93 changes: 68 additions & 25 deletions src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,5 +1,6 @@
import contextlib
import warnings
from concurrent.futures import ThreadPoolExecutor
from typing import Dict, List, Optional, Tuple, Union

import torch
@@ -241,34 +242,76 @@ def compress_modules(self):
"""
Quantize modules which have been calibrated
"""
for module in list(self._num_samples.keys()):
name = self._module_names[module]
num_samples = self._num_samples[module]
quant_args = getattr_chain(module, "quantization_scheme.weights")

logger.info(f"Quantizing {name} using {num_samples} samples")
with torch.no_grad(), align_module_device(
module
), self._maybe_onload_hessian(module), CompressionLogger(
module
) as comp_logger:
loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
module=module,
quant_args=quant_args,
hessians_dict=self._hessians,
blocksize=self.block_size,
percdamp=self.dampening_frac,
modules_to_quantize = list(self._num_samples.keys())
if not modules_to_quantize:
return

gpu_ids = list(range(torch.cuda.device_count()))
if not gpu_ids:
logger.info("No CUDA devices found. Running GPTQ sequentially on CPU.")
for module in modules_to_quantize:
self._quantize_and_update_module(module)
return

logger.info(f"Starting parallel GPTQ on {len(gpu_ids)} GPUs...")
original_device = get_execution_device(modules_to_quantize[0])

with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
futures = []
for i, module in enumerate(modules_to_quantize):
target_gpu_id = gpu_ids[i % len(gpu_ids)]
target_device = f"cuda:{target_gpu_id}"
future = executor.submit(
self._quantize_and_update_module,
module,
target_device,
original_device
)
comp_logger.set_loss(loss)
futures.append(future)
for future in futures:
future.result()
Comment on lines +257 to +272
Contributor

high

The current implementation assumes that all modules to be quantized reside on the same device initially. It determines the original_device from the first module in the list and uses it for all modules. This can cause problems in a model-parallel setup where modules might be distributed across different devices. If that's the case, after quantization, all modules would be incorrectly moved to the device of the first module.

To make this more robust, you should determine and store the original device for each module individually and use that to move the module back after processing.

Suggested change
-        original_device = get_execution_device(modules_to_quantize[0])
-        with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
-            futures = []
-            for i, module in enumerate(modules_to_quantize):
-                target_gpu_id = gpu_ids[i % len(gpu_ids)]
-                target_device = f"cuda:{target_gpu_id}"
-                future = executor.submit(
-                    self._quantize_and_update_module,
-                    module,
-                    target_device,
-                    original_device
-                )
-                comp_logger.set_loss(loss)
-                futures.append(future)
-            for future in futures:
-                future.result()
+        original_devices = {m: get_execution_device(m) for m in modules_to_quantize}
+        with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
+            futures = []
+            for i, module in enumerate(modules_to_quantize):
+                target_gpu_id = gpu_ids[i % len(gpu_ids)]
+                target_device = f"cuda:{target_gpu_id}"
+                future = executor.submit(
+                    self._quantize_and_update_module,
+                    module,
+                    target_device,
+                    original_devices[module],
+                )
+                futures.append(future)
+            for future in futures:
+                future.result()
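
To make the suggestion concrete, a minimal sketch (not part of the PR) of capturing and restoring each module's own device; next(m.parameters()).device stands in for get_execution_device, and the two-device layout is hypothetical, with a CPU fallback so the snippet runs anywhere:

import torch

# Toy stand-ins for two quantizable modules that may start on different devices
# in a model-parallel load; device placement here is illustrative only.
dev_a = torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
dev_b = torch.device("cuda:1") if torch.cuda.device_count() > 1 else dev_a
modules = [torch.nn.Linear(8, 8).to(dev_a), torch.nn.Linear(8, 8).to(dev_b)]

# Capture each module's own starting device rather than only modules[0]'s.
original_devices = {m: next(m.parameters()).device for m in modules}

for m in modules:
    # ... quantize m on whatever worker GPU it was assigned to ...
    m.to(original_devices[m])  # restore this module's placement, not modules[0]'s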

Collaborator

This might make sense: if we are allowing multi-GPU compression, modules could also be spread out across different GPUs.
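
For context, a rough, self-contained sketch (not from the PR) of multi-GPU dispatch with per-module home devices, heavily simplified (no Hessians, offloading, or CompressionLogger) and falling back to CPU when no CUDA devices are present:

from concurrent.futures import ThreadPoolExecutor

import torch


def compress_on(module: torch.nn.Module, work_device: str, home_device: torch.device) -> None:
    # Move the module to its assigned worker device, "compress" it, and send it home.
    module.to(work_device)
    with torch.no_grad():
        module.weight.round_()  # placeholder for the real GPTQ weight update
    module.to(home_device)  # each module returns to its own original device


# In a model-parallel load, modules may already sit on different GPUs.
devices = [f"cuda:{i}" for i in range(torch.cuda.device_count())] or ["cpu"]
modules = [torch.nn.Linear(8, 8).to(devices[i % len(devices)]) for i in range(4)]
home = {m: next(m.parameters()).device for m in modules}

# Round-robin dispatch with one worker thread per device, mirroring the
# i % len(gpu_ids) assignment in compress_modules.
with ThreadPoolExecutor(max_workers=len(devices)) as pool:
    futures = [
        pool.submit(compress_on, m, devices[i % len(devices)], home[m])
        for i, m in enumerate(modules)
    ]
    for f in futures:
        f.result()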

def _quantize_and_update_module(
self,
module: torch.nn.Module,
target_device: Optional[str] = None,
original_device: Optional[str] = None
):
"""
Helper function to quantize a single module. This function is called by the
ThreadPoolExecutor.
"""
if target_device:
module.to(target_device)

name = self._module_names[module]
num_samples = self._num_samples[module]
quant_args = getattr_chain(module, "quantization_scheme.weights")

logger.info(f"Quantizing {name} on device {target_device} using {num_samples} samples")
with torch.no_grad(), align_module_device(
module
), self._maybe_onload_hessian(module), CompressionLogger(
module
) as comp_logger:
loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
module=module,
quant_args=quant_args,
hessians_dict=self._hessians,
blocksize=self.block_size,
percdamp=self.dampening_frac,
)
comp_logger.set_loss(loss)

update_offload_parameter(module, "weight", quantized_weight)
update_offload_parameter(module, "weight_scale", scale)
update_offload_parameter(module, "weight_zero_point", zero_point)
if g_idx is not None:
update_offload_parameter(module, "weight_g_idx", g_idx)

update_offload_parameter(module, "weight", quantized_weight)
update_offload_parameter(module, "weight_scale", scale)
update_offload_parameter(module, "weight_zero_point", zero_point)
if g_idx is not None:
update_offload_parameter(module, "weight_g_idx", g_idx)
del self._num_samples[module]

# self._hessians[module] already deleted by quantize_weight
del self._num_samples[module]
if original_device:
module.to(original_device)

def on_end(self, state: State, event: Event, **kwargs):
"""