From 7073d1dd103291a6bb1bf00fb73285a3c57169a8 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 31 Dec 2024 11:49:56 +0000 Subject: [PATCH] More delay --- src/brevitas/proxy/runtime_quant.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/brevitas/proxy/runtime_quant.py b/src/brevitas/proxy/runtime_quant.py index 64a8faefe..76da00d83 100644 --- a/src/brevitas/proxy/runtime_quant.py +++ b/src/brevitas/proxy/runtime_quant.py @@ -255,11 +255,14 @@ class ClampQuantProxyFromInjector(QuantProxyFromInjector, AccQuantProxyProtocol) def __init__(self): super().__init__() self.skip_create_quant_tensor = False + quant_delay_steps = self.quant_injector.quant_delay_steps if 'quant_delay_steps' in self.quant_injector else None + self.delay_wrapper = DelayWrapper(quant_delay_steps) def forward(self, x: IntQuantTensor) -> Union[Tensor, IntQuantTensor]: if self.is_quant_enabled: out_tuple = self.tensor_quant(x.value, x.scale, x.bit_width) out_value, out_scale, out_zp, out_bit_width = out_tuple + out_value = self.delay_wrapper(x, out_value) if self.skip_create_quant_tensor: return out_value return IntQuantTensor(