From dd80cf40186d84ec5d0faf4dd0d93d5ae50251e7 Mon Sep 17 00:00:00 2001 From: Giuseppe Franco Date: Sat, 14 Dec 2024 22:12:05 +0000 Subject: [PATCH] Fix compile --- src/brevitas/export/inference/handler.py | 10 +++++----- .../quant_tensor/groupwise_float_quant_tensor.py | 1 - 2 files changed, 5 insertions(+), 6 deletions(-) diff --git a/src/brevitas/export/inference/handler.py b/src/brevitas/export/inference/handler.py index 3ab6ba0a0..a582d95a9 100644 --- a/src/brevitas/export/inference/handler.py +++ b/src/brevitas/export/inference/handler.py @@ -105,7 +105,7 @@ def prepare_for_export(self, module): def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling: + if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) @@ -131,7 +131,7 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: else: zero_point = self.zero_point out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling: + if is_dynamo_compiling(): out = self.flattened_view(out) return out, scale, zero_point, self.bit_width @@ -219,7 +219,7 @@ def prepare_for_export(self, module): def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: x, *other = self.module_forward(x) - if is_dynamo_compiling: + if is_dynamo_compiling(): start_dim = self.group_dim if self.group_dim != -1 else -2 x = x.flatten(start_dim, start_dim + 1) output_args = tuple([x] + list(other)) @@ -245,6 +245,6 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]: else: zero_point = self.zero_point out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point) - if is_dynamo_compiling: + if is_dynamo_compiling(): out = self.flattened_view(out) - return out, scale, zero_point, self.bit_width + return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values diff --git a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py index 16f75c49e..b507d3fe3 100644 --- a/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py +++ b/src/brevitas/quant_tensor/groupwise_float_quant_tensor.py @@ -92,7 +92,6 @@ def expand(self): curr_shape = self.value_.shape start_dim = self.group_dim if self.group_dim != -1 else -2 new_value = self.value_.flatten(start_dim, start_dim + 1) - new_value = self.value_.flatten(start_dim, start_dim + 1) if self.scale_.shape != (): new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1) else: