Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
Giuseppe5 committed Dec 19, 2024
1 parent 33bd3f7 commit ae52c79
Show file tree
Hide file tree
Showing 4 changed files with 11 additions and 17 deletions.
10 changes: 5 additions & 5 deletions notebooks/minifloat_mx_tutorial.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -206,15 +206,15 @@
},
{
"cell_type": "code",
"execution_count": 16,
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Non padding weights shape torch.Size([64, 8, 3, 3])\n",
"Padded weights shape torch.Size([64, 32, 3, 3])\n"
"Non padding weights shape torch.Size([64, 1, 8, 3, 3])\n",
"Padded weights shape torch.Size([64, 1, 32, 3, 3])\n"
]
}
],
Expand Down Expand Up @@ -257,8 +257,8 @@
"o = mx_model(x)\n",
"\n",
"# The quant weight of the padded model is different from the non padding one\n",
"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
"print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value_.shape}\")\n",
"print(f\"Padded weights shape {mx_model.conv.quant_weight().value_.shape}\")\n",
"\n",
"# However, results are still the same \n",
"assert torch.allclose(o, o_no_padding)"
Expand Down
6 changes: 0 additions & 6 deletions src/brevitas/graph/gpxq.py
Original file line number Diff line number Diff line change
Expand Up @@ -187,12 +187,6 @@ def __init__(
self.layer = layer
self.name = name
self.act_order = act_order
if self.layer.weight_quant.is_groupwise:
weight = self.layer.weight_quant.apply_input_view(self.layer.weight)
weight = weight.view(self.layer.weight_quant.quant_injector.reshaped_groupwise_shape)
self.layer.weight.data = weight.data
self.layer.in_channels = weight.shape[1] if is_conv_transposed(
self.layer) else weight.shape[0]

weight_shape = torch.tensor(layer.weight.shape)

Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/quant_tensor/groupwise_float_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,12 +113,12 @@ def expand(self):

if residual > 0:
new_value = torch.stack(
torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)

return new_value, new_scale, new_zp

Expand Down
6 changes: 3 additions & 3 deletions src/brevitas/quant_tensor/groupwise_int_quant_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,12 +99,12 @@ def expand(self):

if residual > 0:
new_value = torch.stack(
torch.unbind(new_value, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_value, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
new_scale = torch.stack(
torch.unbind(new_scale, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_scale, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)
if self.zero_point_.shape != ():
new_zp = torch.stack(
torch.unbind(new_zp, dim=self.group_dim)[residual:], dim=self.group_dim)
torch.unbind(new_zp, dim=self.group_dim)[:unpadding_shape], dim=self.group_dim)

return new_value, new_scale, new_zp

Expand Down

0 comments on commit ae52c79

Please sign in to comment.