Feat (mx): adding padding and transposed support #1007

Merged 9 commits on Aug 26, 2024
Changes from 8 commits
82 changes: 80 additions & 2 deletions notebooks/minifloat_mx_tutorial.ipynb
@@ -104,7 +104,8 @@
"o = ocp_fp8_model(x)\n",
"\n",
"intermediate_input = ocp_fp8_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, FloatQuantTensor)"
"assert isinstance(intermediate_input, FloatQuantTensor)\n",
"assert isinstance(ocp_fp8_model.conv.quant_weight(), FloatQuantTensor)"
]
},
{
@@ -180,7 +181,84 @@
"o = mx_model(x)\n",
"\n",
"intermediate_input = mx_model.conv.input_quant(x)\n",
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)"
"assert isinstance(intermediate_input, GroupwiseFloatQuantTensor)\n",
"assert isinstance(mx_model.conv.quant_weight(), GroupwiseFloatQuantTensor)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If the input channel dimension is not divisible by group size, padding will be applied."
]
},
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Non padding weights shape torch.Size([64, 8, 3, 3])\n",
"Padded weights shape torch.Size([64, 32, 3, 3])\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/giuseppe/miniconda3/envs/brevitas_dev/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Defining your `__torch_function__` as a plain method is deprecated and will be an error in future, please define it as a classmethod. (Triggered internally at /opt/conda/conda-bld/pytorch_1712608853099/work/torch/csrc/utils/python_arg_parser.cpp:294.)\n",
" return F.conv2d(input, weight, bias, self.stride,\n"
]
}
],
"source": [
"class MXFloat8WeightNoPadding(MXFloat8e4m3Weight, Fp8e4m3Mixin):\n",
" # The group dimension for the weights it is automatically identified based on the layer type\n",
" # If a new layer type is used, it can be manually specified\n",
" group_size = 8\n",
"\n",
"class MXFloat8ActNoPadding(MXFloat8e4m3Act, Fp8e4m3Mixin):\n",
" # It is necessary to specify the group dimension for the activation quantization\n",
" group_size = 8\n",
" group_dim = 1\n",
"\n",
"\n",
"class MXModelNoPadding(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8WeightNoPadding, input_quant=MXFloat8ActNoPadding)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"class MXModel(nn.Module):\n",
" def __init__(self):\n",
" super().__init__()\n",
" self.conv = qnn.QuantConv2d(8, 64, 3, weight_quant=MXFloat8Weight, input_quant=MXFloat8Act)\n",
" \n",
" def forward(self, x):\n",
" return self.conv(x)\n",
"\n",
"mx_model_no_padding = MXModelNoPadding()\n",
"mx_model = MXModel()\n",
"# Make sure that the modules are the same\n",
"mx_model_no_padding.load_state_dict(mx_model.state_dict())\n",
"\n",
"x = torch.randn(1, 8, 8, 8)\n",
"mx_model.eval()\n",
"mx_model_no_padding.eval()\n",
"o_no_padding = mx_model_no_padding(x)\n",
"o = mx_model(x)\n",
"\n",
"# The quant weight of the padded model is different from the non padding one\n",
"print(f\"Non padding weights shape {mx_model_no_padding.conv.quant_weight().value.shape}\")\n",
"print(f\"Padded weights shape {mx_model.conv.quant_weight().value.shape}\")\n",
"\n",
"# However, results are still the same \n",
"assert torch.allclose(o, o_no_padding)"
]
},
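The cell above shows the quantized weight growing from 8 to 32 input channels while the layer output stays the same. A rough standalone illustration of the padding rule it demonstrates (plain PyTorch with a hypothetical helper name, not the Brevitas API):

import torch
import torch.nn.functional as F

def pad_to_group_size(weight, group_dim, group_size):
    # Zero-pad the group dimension up to the next multiple of group_size.
    size = list(weight.shape)
    remainder = size[group_dim] % group_size
    if remainder == 0:
        return weight
    # F.pad expects (before, after) pairs ordered from the last dimension
    # backwards, so the flat list is reversed; the reversal also places the
    # padding at the end of the dimension.
    padding = [0, 0] * len(size)
    padding[2 * group_dim] = group_size - remainder
    return F.pad(weight, list(reversed(padding)), mode='constant', value=0)

w = torch.randn(64, 8, 3, 3)                      # 8 input channels
print(pad_to_group_size(w, 1, 32).shape)          # torch.Size([64, 32, 3, 3])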
19 changes: 13 additions & 6 deletions src/brevitas/core/function_wrapper/shape.py
@@ -156,17 +156,14 @@ def forward(self, x: torch.Tensor):
class OverSubChannelBlockView(brevitas.jit.ScriptModule):
__constants__ = ['expanded_scaling_shape']

-def __init__(self, expanded_scaling_shape, permute_dims: Optional[Tuple[int, ...]]) -> None:
+def __init__(self, expanded_scaling_shape, padding) -> None:
super(OverSubChannelBlockView, self).__init__()
self.expanded_scaling_shape = expanded_scaling_shape
-if permute_dims is not None:
-self.permute_impl = PermuteDims(permute_dims)
-else:
-self.permute_impl = torch.nn.Identity()
+self.padding = padding

@brevitas.jit.script_method
def forward(self, x: torch.Tensor):
-y = self.permute_impl(x)
+y = torch.nn.functional.pad(x, self.padding, mode='constant', value=0)
y = y.view(self.expanded_scaling_shape)
return y

@@ -181,6 +178,16 @@ def __init__(self, group_size, group_dim) -> None:

@brevitas.jit.script_method
def forward(self, x):

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
padding = [0, 0] * len(tensor_shape_list)
if tensor_shape_list[self.group_dim] % self.group_size != 0:
padding[2 * self.group_dim] = self.group_size - tensor_shape_list[
self.group_dim] % self.group_size
padding = list(reversed(padding))
x = torch.nn.functional.pad(x, padding, mode='constant', value=0)

tensor_shape = x.shape
tensor_shape_list = list(tensor_shape)
tensor_shape_list[self.group_dim] = int(tensor_shape_list[self.group_dim] / self.group_size)
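The updated forward above pads the group dimension on the fly before reshaping, so tensors whose channel count is not a multiple of the group size can still be split into full groups. A standalone sketch of the pad-then-view sequence (plain PyTorch, hypothetical function name), assuming the group axis is inserted right after group_dim as in the expanded_scaling_shape logic further below:

import torch
import torch.nn.functional as F

def pad_and_group(x, group_dim, group_size):
    # Pad the group dimension to a multiple of group_size, using the same
    # padding construction as the forward above.
    shape = list(x.shape)
    remainder = shape[group_dim] % group_size
    if remainder != 0:
        padding = [0, 0] * len(shape)
        padding[2 * group_dim] = group_size - remainder
        x = F.pad(x, list(reversed(padding)), mode='constant', value=0)
    # View the padded tensor as (..., num_groups, group_size, ...).
    shape = list(x.shape)
    shape[group_dim] = shape[group_dim] // group_size
    shape.insert(group_dim + 1, group_size)
    return x.view(shape)

x = torch.randn(1, 8, 8, 8)
print(pad_and_group(x, group_dim=1, group_size=32).shape)   # torch.Size([1, 1, 32, 8, 8])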
4 changes: 2 additions & 2 deletions src/brevitas/quant/solver/common.py
@@ -172,13 +172,13 @@ def int_scaling_impl(restrict_scaling_type):
class SolveStatsReduceDimFromEnum(ExtendedInjector):

@value
-def stats_reduce_dim(scaling_stats_op, scaling_per_output):
+def stats_reduce_dim(scaling_stats_op, scaling_per_output, group_dim=None):
if scaling_per_output == ScalingPerOutputType.CHANNEL or scaling_stats_op == StatsOp.MAX_AVE:
return SCALING_STATS_REDUCE_DIM
elif scaling_per_output == ScalingPerOutputType.TENSOR:
return None
elif scaling_per_output == ScalingPerOutputType.GROUP:
-return SCALING_STATS_REDUCE_DIM + 1
+return group_dim + 1

@value
def keepdim(scaling_per_output):
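For groupwise scaling the statistics are now reduced over group_dim + 1 because the groupwise views above insert the group axis immediately after group_dim. A rough illustration, assuming an absmax-style stats op with keepdim=True (tensor shapes follow the tutorial, names are illustrative):

import torch

# Activation of shape (1, 8, 8, 8), padded and viewed groupwise along dim 1
# with group_size = 32: (batch, num_groups, group_size, H, W).
grouped = torch.randn(1, 1, 32, 8, 8)
group_dim = 1

# Reducing over group_dim + 1 collapses the group_size axis, leaving one
# statistic (and hence one scale) per group.
stats = grouped.abs().amax(dim=group_dim + 1, keepdim=True)
print(stats.shape)   # torch.Size([1, 1, 1, 8, 8])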
26 changes: 17 additions & 9 deletions src/brevitas/quant/solver/parameter.py
@@ -111,36 +111,44 @@ def scaling_impl(scaling_impl_type):
class SolveParameterScalingShape(ExtendedInjector):

@value
-def scaling_shape(module, group_size=None, scaling_per_output=None):
+def scaling_shape(module, group_dim, group_size=None, scaling_per_output=None):
if scaling_per_output == ScalingPerOutputType.TENSOR:
return SCALAR_SHAPE
elif scaling_per_output == ScalingPerOutputType.CHANNEL:
return this.scaling_per_output_channel_shape
elif scaling_per_output == ScalingPerOutputType.GROUP:
assert group_size is not None, "Per Group scaling requires group size"
assert group_dim is not None, "Per Group scaling requires group dim"
size = list(module.weight.shape)
-assert size[1] % group_size == 0, 'Input channel is not divisible by group size'
-size[1] = size[1] // group_size
-size.insert(2, 1)
+size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+size.insert(group_dim + 1, 1)
return size

@value
def reshaped_scaling_shape(module):
return module.weight.shape

@value
-def expanded_scaling_shape(module, group_size=None):
+def expanded_scaling_shape(module, group_dim, group_size=None):
assert group_size is not None, "Per Group scaling requires group size"
size = list(module.weight.shape)
-assert size[1] % group_size == 0, 'Input channel is not divisible by group size'
-size[1] = size[1] // group_size
-size.insert(2, group_size)
+size[group_dim] = (size[group_dim] + group_size - 1) // group_size
+size.insert(group_dim + 1, group_size)
return size

@value
def padding(module, group_dim, group_size):
padding = [0, 0] * len(module.weight.shape)
size = list(module.weight.shape)
if size[group_dim] % group_size != 0:
padding[2 * group_dim] = group_size - size[group_dim] % group_size
padding = list(reversed(padding))
return padding

@value
def group_dim(module, group_size=None):
if group_size is not None:
-return 1
+return 1 if not hasattr(module, 'transposed') or not module.transposed else 0


class SolveInputViewImpl(ExtendedInjector):
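The new group_dim value function above returns 0 for transposed layers because PyTorch stores transposed-convolution weights as (in_channels, out_channels // groups, kH, kW), so the input channels sit on dim 0 rather than dim 1. A small standalone sketch of the resulting shapes (plain PyTorch, illustrative names), using the same group_size = 32 as the tutorial:

import torch.nn as nn

def group_shapes(module, group_size=32):
    # Mirror the logic above: group along dim 1, or dim 0 for transposed layers.
    group_dim = 0 if getattr(module, 'transposed', False) else 1
    size = list(module.weight.shape)
    scaling = list(size)
    scaling[group_dim] = (scaling[group_dim] + group_size - 1) // group_size
    scaling.insert(group_dim + 1, 1)
    expanded = list(size)
    expanded[group_dim] = (expanded[group_dim] + group_size - 1) // group_size
    expanded.insert(group_dim + 1, group_size)
    return group_dim, scaling, expanded

conv = nn.Conv2d(8, 64, 3)               # weight: (64, 8, 3, 3)
deconv = nn.ConvTranspose2d(8, 64, 3)    # weight: (8, 64, 3, 3)
print(group_shapes(conv))                # (1, [64, 1, 1, 3, 3], [64, 1, 32, 3, 3])
print(group_shapes(deconv))              # (0, [1, 1, 64, 3, 3], [1, 32, 64, 3, 3])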