From 181ef8077314db4ea97848bf5f09d060255c614f Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Tue, 18 Jun 2024 10:17:46 +0100 Subject: [PATCH] Fix (core/float): add default for float_scaling_impl --- src/brevitas/core/quant/float.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/src/brevitas/core/quant/float.py b/src/brevitas/core/quant/float.py index 71f518bb5..5b582e195 100644 --- a/src/brevitas/core/quant/float.py +++ b/src/brevitas/core/quant/float.py @@ -52,6 +52,9 @@ def __init__( if scaling_impl is None: scaling_impl = ConstScaling(1., device=device, dtype=dtype) + if float_scaling_impl is None: + float_scaling_impl = ConstScaling(1., device=device, dtype=dtype) + # Zero-point is currently hardcoded to 0 self.zero_point_impl = StatelessBuffer(torch.tensor(0., device=device, dtype=dtype)) self.float_scaling_impl = float_scaling_impl