
Changing bit_width of an already created quantizer #702

Closed

AdrianGras opened this issue Sep 6, 2023 · 2 comments

@AdrianGras
Unable to Modify the Bit Width Without Re-instantiating the Quantizer

We have been working with the Brevitas library to quantize neural networks and have run into a challenge when attempting to change the bit width of a quantizer without re-instantiating it.

Current Approach:

# Imports assumed for a self-contained example (paths may vary slightly across Brevitas versions).
import brevitas.nn as qnn
from brevitas.inject import value
from brevitas.quant.fixed_point import Int8WeightPerTensorFixedPoint

# Maps each module to its desired bit width (defaults to 8).
map = {}

class bx_quantizer(Int8WeightPerTensorFixedPoint):

    @value
    def bit_width(module):
        if module in map:
            return map[module]
        else:
            map[module] = 8
            return 8

linear1 = qnn.QuantLinear(2, 3, bias=True, weight_quant=bx_quantizer)
map[linear1] = 7
# Re-initialize the weight quantizer so the new bit width takes effect.
linear1.weight_quant.init_tensor_quant()

Issue:
Using init_tensor_quant() causes discrepancies when summing two tensors within a model. Although both tensors should maintain identical scales, invoking init_tensor_quant() breaks this scale synchronization.
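
For context, the shape of the problem is a model with two quantized branches feeding an element-wise sum: the generic (graph) quantization flow inserts a shared quantizer at the add so both operands carry the same scale. A minimal sketch of such a model (plain PyTorch, names are ours; the Brevitas quantization step itself is omitted):

import torch
import torch.nn as nn

class TwoBranchAdd(nn.Module):
    def __init__(self, channels: int = 8):
        super().__init__()
        self.branch_a = nn.Linear(channels, channels)
        self.branch_b = nn.Linear(channels, channels)

    def forward(self, x):
        # After graph quantization, both operands of this add are expected to be
        # quantized with an identical scale (enforced via a shared QuantIdentity).
        return self.branch_a(x) + self.branch_b(x)

model = TwoBranchAdd()
out = model(torch.randn(4, 8))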

Switching the target backend from generic to layerwise eliminates the error, but the layerwise flow no longer enforces identical scales at the sum, which is not the desired behavior.

Attempted Solution:

We tried altering the bit width of the quantizer directly, but were unable to determine where the attribute is stored. We found the following:

class BitWidthConst(brevitas.jit.ScriptModule):
    ...
    def __init__(self, bit_width: int) -> None:
        super(BitWidthConst, self).__init__()
        assert isinstance(bit_width, int)
        self.bit_width = StatelessBuffer(torch.tensor(float(bit_width)))

    @brevitas.jit.script_method
    def forward(self) -> Tensor:
        return self.bit_width()

Nevertheless, upon encountering StatelessBuffer, we are unsure how to proceed.
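
For reference, a quick way to see where the bit width implementation lives is to list the quantizer's submodules (a sketch on a standard QuantLinear; the default weight quantizer uses a constant bit width):

import brevitas.nn as qnn

linear = qnn.QuantLinear(2, 3, bias=True)

# Print the submodule tree of the weight quantizer; for quantizers with a fixed
# bit width, one of the entries is a BitWidthConst holding the value.
for name, submodule in linear.weight_quant.named_modules():
    print(name, type(submodule).__name__)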

Request:
We would greatly appreciate any guidance or recommendations on resolving this issue. Thanks in advance!

@Giuseppe5
Collaborator

Let's address the two issues separately.
One is how to keep the scale factors aligned after changing the bit width; the other is how to change the bit width without re-initializing the quantizer (in general, re-initializing the quantizer could have unexpected side effects).

For the first issue, there are some modifications to the Brevitas source code that I can suggest you try out while we discuss and test them internally.
The scale factor alignment is handled mainly through this function.
If you replace it with the following, the bit width is also synced across the branches of an add/cat.

# Drop-in replacement for Brevitas' align_input_quant; it assumes the names used
# here (qnn, ParameterScaling, ConstScaling) are imported as in the original module.
def align_input_quant(
        module, shared_quant_identity, shared_quant_identity_name, quant_identity_map, align_sign):
    """
    Based on the input module, the function decides how to align its output.
    """
    # If it is a QuantIdentity already, simply modify tensor_quant or the scaling implementations
    # based on whether we need to align the sign or not
    if isinstance(module, qnn.QuantIdentity):
        if align_sign or module.is_quant_act_signed == shared_quant_identity.is_quant_act_signed:
            return shared_quant_identity
        else:
            assert not module.is_quant_act_signed and shared_quant_identity.is_quant_act_signed
            quant_module_class, quant_module_kwargs = quant_identity_map['unsigned']
            return (
                quant_module_class,
                {
                    **quant_module_kwargs,
                    'bit_width_impl': 
                        shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                        .msb_clamp_bit_width_impl,
                    'scaling_impl':
                        shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                        .scaling_impl,
                    'int_scaling_impl':
                        shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                        .int_scaling_impl})
    elif hasattr(module, 'output_quant'):
        return (type(module), {'output_quant': shared_quant_identity})
    # If it is a QuantAct where the scaling can be determined through stats (thus through calibration),
    # then adapt its act_quant according to align_sign.
    elif hasattr(module, 'act_quant') and not isinstance(
            module.act_quant.fused_activation_quant_proxy.tensor_quant.scaling_impl,
        (ParameterScaling, ConstScaling)):
        module_type = type(module)
        if align_sign:
            partial_config = {
                'signed':
                    shared_quant_identity.act_quant.is_signed,
                'tensor_quant':
                    shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant}
        else:
            partial_config = {
                'bit_width_impl': 
                    shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                    .msb_clamp_bit_width_impl,
                'scaling_impl':
                    shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                    .scaling_impl,
                'int_scaling_impl':
                    shared_quant_identity.act_quant.fused_activation_quant_proxy.tensor_quant
                    .int_scaling_impl}
        injector = module.act_quant.quant_injector.let(**partial_config)
        return module_type(act_quant=injector, return_quant_tensor=True)
    # In all other cases, return the name of the QuantIdentity that will be added at the output of
    # the module
    else:
        return shared_quant_identity_name

This means that when you change the bit width of one of the two branches, the other will be synced as well.
There could be other ways to align the scales without aligning the bit width, but they would be more complex to set up.
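
One way to try the replacement without editing the installed package is to monkey-patch it at import time. This is a sketch: it assumes the linked function lives in brevitas.graph.quantize, and it only takes effect if Brevitas looks the function up at call time (if it is captured e.g. as a default argument, edit the source file directly instead):

import brevitas.graph.quantize as bgq

# `align_input_quant` here is the replacement defined above.
bgq.align_input_quant = align_input_quant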

Regarding the second issue, how to change the bit width without re-initializing the quantizer, the idea is the following. Given a quant module with "standard" quantizers, the class holding the bit_width value is generally of type BitWidthConst. You can simply iterate through the modules, identify the BitWidthConst instances, and update their value:

import torch
import brevitas.nn as qnn
from brevitas.core.bit_width import BitWidthConst

linear1 = qnn.QuantLinear(2, 3, bias=True)
device = next(iter(linear1.parameters())).device
dtype = next(iter(linear1.parameters())).dtype

def update_bitwidth(module, new_value):
    # Walk the submodules and overwrite the tensor stored in every BitWidthConst.
    for name, submodule in module.named_modules():
        if isinstance(submodule, BitWidthConst):
            submodule.bit_width.value = new_value

update_bitwidth(linear1.weight_quant, torch.tensor(2., device=device, dtype=dtype))

Here I passed only the weight quantizer, because a quant layer can contain several quantizers (e.g. weight, bias, input, output for QuantLinear); if you pass the entire layer, all of its quantizers would be updated.
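
As a quick sanity check (a sketch; it assumes the quantized weight is returned as a QuantTensor whose bit_width field reflects the constant we just overwrote):

# After update_bitwidth, the reported weight bit width should be 2.
print(linear1.quant_weight().bit_width)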

Maybe @volcacius could chime in, especially regarding the second issue, in case there are cleaner ways to do this.

@Giuseppe5 Giuseppe5 self-assigned this Sep 19, 2023
@AdrianGras
Author

Thank you so much. This solved the problem.
