From ddc04ebf83399aef31e2ee381dce0810a808dc29 Mon Sep 17 00:00:00 2001
From: aashvgit <167199295+aashvgit@users.noreply.github.com>
Date: Wed, 17 Sep 2025 23:48:12 +0530
Subject: [PATCH] Update base.py (parallel calibration and model #1809)

I've added a ThreadPoolExecutor to manage parallel worker threads, added a
new helper method (_quantize_and_update_module), and simplified
compress_modules accordingly.
---
 .../modifiers/quantization/gptq/base.py | 93 ++++++++++++++-----
 1 file changed, 68 insertions(+), 25 deletions(-)

diff --git a/src/llmcompressor/modifiers/quantization/gptq/base.py b/src/llmcompressor/modifiers/quantization/gptq/base.py
index bee22fe6e..c85d4d853 100644
--- a/src/llmcompressor/modifiers/quantization/gptq/base.py
+++ b/src/llmcompressor/modifiers/quantization/gptq/base.py
@@ -1,5 +1,6 @@
 import contextlib
 import warnings
+from concurrent.futures import ThreadPoolExecutor
 from typing import Dict, List, Optional, Tuple, Union
 
 import torch
@@ -241,34 +242,76 @@ def compress_modules(self):
         """
         Quantize modules which have been calibrated
         """
-        for module in list(self._num_samples.keys()):
-            name = self._module_names[module]
-            num_samples = self._num_samples[module]
-            quant_args = getattr_chain(module, "quantization_scheme.weights")
-
-            logger.info(f"Quantizing {name} using {num_samples} samples")
-            with torch.no_grad(), align_module_device(
-                module
-            ), self._maybe_onload_hessian(module), CompressionLogger(
-                module
-            ) as comp_logger:
-                loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
-                    module=module,
-                    quant_args=quant_args,
-                    hessians_dict=self._hessians,
-                    blocksize=self.block_size,
-                    percdamp=self.dampening_frac,
+        modules_to_quantize = list(self._num_samples.keys())
+        if not modules_to_quantize:
+            return
+
+        gpu_ids = list(range(torch.cuda.device_count()))
+        if not gpu_ids:
+            logger.info("No CUDA devices found. Running GPTQ sequentially on CPU.")
+            for module in modules_to_quantize:
+                self._quantize_and_update_module(module)
+            return
+
+        logger.info(f"Starting parallel GPTQ on {len(gpu_ids)} GPUs...")
+        original_device = get_execution_device(modules_to_quantize[0])
+
+        with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
+            futures = []
+            for i, module in enumerate(modules_to_quantize):
+                target_gpu_id = gpu_ids[i % len(gpu_ids)]
+                target_device = f"cuda:{target_gpu_id}"
+                future = executor.submit(
+                    self._quantize_and_update_module,
+                    module,
+                    target_device,
+                    original_device
                 )
-                comp_logger.set_loss(loss)
+                futures.append(future)
+            for future in futures:
+                future.result()
+    def _quantize_and_update_module(
+        self,
+        module: torch.nn.Module,
+        target_device: Optional[str] = None,
+        original_device: Optional[str] = None
+    ):
+        """
+        Helper function to quantize a single module. This function is called by the
+        ThreadPoolExecutor.
+        """
+        if target_device:
+            module.to(target_device)
+
+        name = self._module_names[module]
+        num_samples = self._num_samples[module]
+        quant_args = getattr_chain(module, "quantization_scheme.weights")
+
+        logger.info(f"Quantizing {name} on device {target_device} using {num_samples} samples")
+        with torch.no_grad(), align_module_device(
+            module
+        ), self._maybe_onload_hessian(module), CompressionLogger(
+            module
+        ) as comp_logger:
+            loss, quantized_weight, scale, zero_point, g_idx = quantize_weight(
+                module=module,
+                quant_args=quant_args,
+                hessians_dict=self._hessians,
+                blocksize=self.block_size,
+                percdamp=self.dampening_frac,
+            )
+            comp_logger.set_loss(loss)
+
+        update_offload_parameter(module, "weight", quantized_weight)
+        update_offload_parameter(module, "weight_scale", scale)
+        update_offload_parameter(module, "weight_zero_point", zero_point)
+        if g_idx is not None:
+            update_offload_parameter(module, "weight_g_idx", g_idx)
 
-            update_offload_parameter(module, "weight", quantized_weight)
-            update_offload_parameter(module, "weight_scale", scale)
-            update_offload_parameter(module, "weight_zero_point", zero_point)
-            if g_idx is not None:
-                update_offload_parameter(module, "weight_g_idx", g_idx)
+        del self._num_samples[module]
 
-            # self._hessians[module] already deleted by quantize_weight
-            del self._num_samples[module]
+        if original_device:
+            module.to(original_device)
 
     def on_end(self, state: State, event: Event, **kwargs):
         """
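
Reviewer note: below is a minimal, standalone sketch of the round-robin
ThreadPoolExecutor fan-out that the compress_modules hunk above uses, one
worker thread per visible GPU with module i pinned to cuda:(i % num_gpus).
The fake_quantize worker and the module-name list are illustrative
assumptions for demonstration only, not llm-compressor APIs; the patch
itself dispatches _quantize_and_update_module instead.

# round_robin_sketch.py -- not part of the patch
from concurrent.futures import ThreadPoolExecutor

import torch


def fake_quantize(name: str, target_device: str) -> str:
    # Placeholder for the per-module GPTQ work; the real patch moves the
    # module to target_device and runs quantize_weight there.
    return f"{name} -> {target_device}"


def main() -> None:
    modules = [f"layer{i}.weight" for i in range(8)]  # hypothetical module names
    gpu_ids = list(range(torch.cuda.device_count()))

    if not gpu_ids:
        # Sequential CPU fallback, mirroring the patch's no-CUDA branch.
        for name in modules:
            print(fake_quantize(name, "cpu"))
        return

    # One worker thread per GPU; module i is assigned to GPU i % len(gpu_ids).
    with ThreadPoolExecutor(max_workers=len(gpu_ids)) as executor:
        futures = [
            executor.submit(fake_quantize, name, f"cuda:{i % len(gpu_ids)}")
            for i, name in enumerate(modules)
        ]
        for future in futures:
            print(future.result())  # result() re-raises any worker exception


if __name__ == "__main__":
    main()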