From 64c05b969aa8174a6fc4d37fe98054670f357fb1 Mon Sep 17 00:00:00 2001 From: AlpinDale <52078762+AlpinDale@users.noreply.github.com> Date: Sun, 15 Dec 2024 18:31:14 -0800 Subject: [PATCH] fix: `ShardedStateLoader` with fp8 quant (#900) --- aphrodite/modeling/model_loader/loader.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/aphrodite/modeling/model_loader/loader.py b/aphrodite/modeling/model_loader/loader.py index f63279610..17947aa32 100644 --- a/aphrodite/modeling/model_loader/loader.py +++ b/aphrodite/modeling/model_loader/loader.py @@ -585,6 +585,10 @@ def load_model(self, *, model_config: ModelConfig, with torch.device(device_config.device): model = _initialize_model(model_config, self.load_config, lora_config, cache_config) + for _, module in model.named_modules(): + quant_method = getattr(module, "quant_method", None) + if quant_method is not None: + quant_method.process_weights_after_loading(module) rank = get_tensor_model_parallel_rank() pattern = os.path.join( local_model_path,