Fix compile
Giuseppe5 committed Dec 14, 2024
1 parent 68879ca commit dd80cf4
Showing 2 changed files with 5 additions and 6 deletions.
src/brevitas/export/inference/handler.py (5 additions, 5 deletions)
@@ -105,7 +105,7 @@ def prepare_for_export(self, module):
 
     def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling:
+        if is_dynamo_compiling():
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
         output_args = tuple([x] + list(other))
@@ -131,7 +131,7 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
         else:
             zero_point = self.zero_point
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling:
+        if is_dynamo_compiling():
             out = self.flattened_view(out)
         return out, scale, zero_point, self.bit_width
 
@@ -219,7 +219,7 @@ def prepare_for_export(self, module):
 
     def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
         x, *other = self.module_forward(x)
-        if is_dynamo_compiling:
+        if is_dynamo_compiling():
             start_dim = self.group_dim if self.group_dim != -1 else -2
             x = x.flatten(start_dim, start_dim + 1)
         output_args = tuple([x] + list(other))
@@ -245,6 +245,6 @@ def forward(self, x, unused_scale=None) -> Tuple[torch.Tensor]:
         else:
             zero_point = self.zero_point
         out = self.dequantize(self.quantize(x, scale, zero_point), scale, zero_point)
-        if is_dynamo_compiling:
+        if is_dynamo_compiling():
             out = self.flattened_view(out)
-        return out, scale, zero_point, self.bit_width
+        return out, scale, zero_point, self.exponent_bit_width, self.mantissa_bit_width, self.exponent_bias, self.saturating, self.inf_values, self.nan_values
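Why the parentheses matter: is_dynamo_compiling is a function, so the bare name is always truthy and the compile-only flattening branch would also run in eager mode; only the call reports whether Dynamo is actually tracing. A minimal sketch of the distinction, assuming the name resolves to something like torch.compiler.is_dynamo_compiling (the exact alias used by Brevitas may differ):

import torch
from torch.compiler import is_dynamo_compiling  # assumed import for illustration

def describe_mode(x: torch.Tensor) -> str:
    # Bug pattern: the bare name is a function object, which is always truthy,
    # so this branch would be taken even outside torch.compile.
    if is_dynamo_compiling:
        pass  # always reached
    # Fixed pattern: the call returns True only while Dynamo is tracing.
    if is_dynamo_compiling():
        return "compiling"
    return "eager"

print(describe_mode(torch.zeros(1)))  # prints "eager" when run normally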
src/brevitas/quant_tensor/groupwise_float_quant_tensor.py (0 additions, 1 deletion)
@@ -92,7 +92,6 @@ def expand(self):
         curr_shape = self.value_.shape
         start_dim = self.group_dim if self.group_dim != -1 else -2
         new_value = self.value_.flatten(start_dim, start_dim + 1)
-        new_value = self.value_.flatten(start_dim, start_dim + 1)
         if self.scale_.shape != ():
             new_scale = self.scale_.expand(curr_shape).flatten(start_dim, start_dim + 1)
         else:
