BrevitasQuantProxyHandler args passing may be sometimes wrong #986

Closed · andrei-stoian-zama opened this issue on Jul 15, 2024 · 9 comments · Labels: bug (Something isn't working)

@andrei-stoian-zama (Contributor)

I have a QuantLinear that I instantiate like so:

layer = qnn.QuantLinear(
    in_features,
    out_features,
    True,
    weight_bit_width=n_w_bits,
    bias_quant=None,
    weight_narrow_range=False,
    narrow_range=False,
    signed=True,
    weight_quant=Int8WeightPerTensorFloat
)

I export the model with BrevitasONNXManager:

BrevitasONNXManager.export(
    self.base_module,
    input_shape=X[[0], ::].shape,
    export_path=str(output_onnx_file_path),
    keep_initializers_as_inputs=False,
    opset_version=OPSET_VERSION_FOR_ONNX_EXPORT,
)

When the weights are exported, BrevitasQuantProxyHandler is called:

class BrevitasQuantProxyHandler(ONNXBaseHandler, ABC):
    def symbolic_execution(self, x: Tensor):
        scale = self.symbolic_kwargs['scale']
        zero_point = self.symbolic_kwargs['zero_point']
        bit_width = self.symbolic_kwargs['bit_width']
        if bit_width == 1:
            x = BrevitasBinaryQuantFn.apply(x, *self.symbolic_kwargs.values())
        else:
            x = BrevitasQuantFn.apply(x, *self.symbolic_kwargs.values())
        return x, scale, zero_point, bit_width

This in turn calls:

class BrevitasQuantFn(Function):

    @staticmethod
    def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
        float_to_int_impl = solve_float_to_int_impl_from_enum(rounding_mode)
        quant = IntQuant(
            float_to_int_impl=float_to_int_impl(),
            tensor_clamp_impl=TensorClamp(),
            narrow_range=narrow_range,
            signed=signed)
        y = quant(scale, zero_point, bit_width, x)
        return y

In BrevitasQuantFn the arguments to forward are positional, but they are passed as *self.symbolic_kwargs.values(), where symbolic_kwargs is a dictionary.

I think there is an issue with the order of the arguments: even though I set signed=True and narrow_range=False everywhere, the exported ONNX sometimes shows the opposite:

(screenshot: the exported ONNX quantization attributes showing the opposite values)

It would seem to be related to the ordering of the dictionary's .values(), as symbolic_kwargs is a dict.

Do you know of any workaround?

@andrei-stoian-zama changed the title from "BrevitasQuantProxyHandler" to "BrevitasQuantProxyHandler args passing is wrong" on Jul 15, 2024
@andrei-stoian-zama changed the title from "BrevitasQuantProxyHandler args passing is wrong" to "BrevitasQuantProxyHandler args passing may be sometimes wrong" on Jul 16, 2024
@nickfraser self-assigned this on Jul 16, 2024
@nickfraser (Collaborator)

Quick response while I investigate further. Firstly, if you want a kwarg to apply to the weight quantizer, you need to prefix it with weight_<quantizer_arg>. I.e., your narrow_range and signed kwargs are not being applied to the weight quantizer, so narrow_range=False and signed=True are being ignored.

So before we discuss export, could you please check that the quantizer is being instantiated as you intend? You can quickly check with some code like this:

import brevitas.nn as qnn
from brevitas.quant import Int8WeightPerTensorFloat

layer = qnn.QuantLinear(
    2,
    3,
    True,
    weight_bit_width=7, # weight_ prefix
    bias_quant=None,
    weight_narrow_range=False, # weight_ prefix
    weight_signed=True, # weight_ prefix
    weight_quant=Int8WeightPerTensorFloat
)

print(layer.weight_quant.tensor_quant.int_quant.signed)
print(layer.weight_quant.tensor_quant.int_quant.narrow_range)
print(layer.weight_quant.tensor_quant.msb_clamp_bit_width_impl())

With your chosen values inserted.

However, I would've expected signed=True to be the default, so I'm not sure that this is the only issue - I will continue to investigate. Please close the issue if the export issue disappears once you've applied these arguments with the correct prefix.

@nickfraser (Collaborator)

Also, thanks for the detailed issue description!

@andrei-stoian-zama (Contributor, Author) commented Jul 16, 2024

I tried changing signed to weight_signed and the same problem occurs. Indeed, signed=True is the default for weights in Int8WeightPerTensorFloat -> NarrowIntQuant:

    narrow_range = True
    signed = True

but weight_narrow_range in my layer definition overrides this narrow_range.

I checked with the debugger, so I'm pretty sure the problem is the one I detailed above. If you cannot reproduce it, I can try to provide minimal code, but the order of the symbolic_kwargs dictionary may differ on your side, so you might not be able to reproduce it.

A solution could be to use introspection to list the arguments of the quantization function and make sure the caller passes the symbolic kwargs in that order.
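
A minimal sketch of that idea (the helper name is hypothetical and this is untested against Brevitas internals):

import inspect

def values_in_signature_order(fn, symbolic_kwargs):
    # Take fn's positional parameters in declaration order, skip ctx and x,
    # and pull the corresponding values out of the dict in that order.
    param_names = list(inspect.signature(fn).parameters)[2:]
    return [symbolic_kwargs[name] for name in param_names]

# Hypothetical call site, replacing *self.symbolic_kwargs.values():
# BrevitasQuantFn.apply(x, *values_in_signature_order(BrevitasQuantFn.forward, self.symbolic_kwargs))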

Another approach would be to use kwargs in BrevitasQuantFn, but since this function needs to be traced by PyTorch, I'm not sure that's doable.

@nickfraser (Collaborator)

Thanks for confirming. I haven't been able to reproduce it yet... What Python version are you using? It seems that dictionary order for dict.values() should be maintained for Python >= 3.6; however, I see the example code in the documentation doesn't show the updated order until Python >= 3.7. The latest documentation also shows that insertion order is maintained.
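
For reference, a quick self-contained illustration of that guarantee:

# dict.values() yields values in insertion order (a language guarantee since Python 3.7).
d = {'scale': 0.1, 'zero_point': 0, 'bit_width': 8}
print(list(d.values()))  # [0.1, 0, 8] -- insertion order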

Sources: the Python documentation on dict ordering.
@andrei-stoian-zama (Contributor, Author)

I am using 3.8, so you're right: the order of definition should be preserved. But maybe something weird is going on, because a copy is made and some .pop calls are executed on symbolic_kwargs:

    def symbolic_execution(self, x: Tensor, input_scale=None, input_bit_width=None):
        # avoid in place pop in case the proxy is shared
        symbolic_kwargs = copy(self.symbolic_kwargs)
        scale = symbolic_kwargs.pop('scale')
        bit_width = symbolic_kwargs.pop('bit_width')
        zero_point = symbolic_kwargs.pop('zero_point')
        if scale is None:
            assert input_scale is not None, 'Input scale required for bias export'
            scale = input_scale
        if bit_width is None:
            assert input_bit_width is not None, 'Input bit_width required for bias export'
            bit_width = input_bit_width
        y = BrevitasQuantFn.apply(x, scale, zero_point, bit_width, *symbolic_kwargs.values())
        return y, scale, zero_point, bit_width

while the original self.symbolic_kwargs has the right order:


class BrevitasWeightQuantProxyHandler(BrevitasQuantProxyHandler):
 ....
    def prepare_for_export(self, module: WeightQuantProxyFromInjector):
        if module.is_quant_enabled:
            first_qweight = module.tracked_module_list[0].quant_weight()
            self.validate(first_qweight.bit_width, first_qweight.zero_point)
            self.quant_weight_values = {
                tm.weight.data_ptr(): tm.quant_weight().value for tm in module.tracked_module_list}
            self.symbolic_kwargs = {
                'scale': first_qweight.scale,
                'zero_point': first_qweight.zero_point,
                'bit_width': first_qweight.bit_width,
                'signed': first_qweight.signed,
                # narrow_range is not a property of the QuantTensor, take it from the proxy instead
                'narrow_range': module.is_narrow_range,
                # override rounding mode since quantization has been pre-applied
                'rounding_mode': 'ROUND'}

@nickfraser (Collaborator)

Good spotting - I'll continue to investigate and get back to you.

@nickfraser (Collaborator)

while the original self.symbolic_kwargs has the right order:

Actually, a quick inspection tells me that the keys of self.symbolic_kwargs are defined in the wrong order: signed and narrow_range must be swapped. Can you test #988 and see if it solves your issue?
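
To spell out the mismatch (with made-up values): forward expects narrow_range before signed, while the dict yields signed first.

def forward(ctx, x, scale, zero_point, bit_width, narrow_range, signed, rounding_mode):
    ...

symbolic_kwargs = {
    'scale': 0.1, 'zero_point': 0, 'bit_width': 8,
    'signed': True,         # 4th value yielded
    'narrow_range': False,  # 5th value yielded
    'rounding_mode': 'ROUND'}
# *symbolic_kwargs.values() unpacks positionally, so forward receives
# narrow_range=True (the signed value) and signed=False (the narrow_range value).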

@nickfraser added the "bug (Something isn't working)" label on Jul 16, 2024
@andrei-stoian-zama (Contributor, Author)

That fixes the bug!

I'm not sure there's a way to add a test for it, since you could not reproduce it in the first place. Maybe an idea would be a test that looks at self.symbolic_kwargs and asserts that the keys are in the right order? Something like the sketch below.
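
A minimal sketch of such a test (the handler fixture is hypothetical; BrevitasQuantFn as in the snippets above):

import inspect

def test_symbolic_kwargs_order(handler):
    # The dict keys must match forward's positional parameters (after ctx and x),
    # because symbolic_execution unpacks them with *values().
    expected = list(inspect.signature(BrevitasQuantFn.forward).parameters)[2:]
    assert list(handler.symbolic_kwargs) == expected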

@nickfraser (Collaborator)

The QONNX tests need a revamp. I'm going to open that as a separate issue and address it when we revamp those tests.
