
Custom bit-width and QuantTensor Operation #982

Closed

balditommaso opened this issue Jul 9, 2024 · 4 comments

@balditommaso

Hi,

I have two questions for this issue:

  1. I wonder if it is possible to overwrite the bit_width field of a quantization strategy without creating a new class that inherits from the selected strategy. I tried to specify it in the arguments of the layer, but the layer still uses the bit-width of the strategy.

Here is an example of the workaround I would like to avoid:

from brevitas.quant.scaled_int import Uint8ActPerTensorFloat

# One subclass per bit-width, just to change a single field:
class Uint4ActPerTensorFloat(Uint8ActPerTensorFloat):
    bit_width = 4

class Uint6ActPerTensorFloat(Uint8ActPerTensorFloat):
    bit_width = 6
  2. If I want to multiply a QuantTensor by a scalar, I get a plain Tensor as a result, but I think this case could be handled so that a QuantTensor is returned too. Indeed, to multiply a QuantTensor by a scalar, you can basically just update the scale field of the QuantTensor:

r = (q - Z) * s  ->  a * r = (q - Z) * s * a  ->  r' = (q - Z) * s'    where s' = s * a

It may be tricky to handle situations where the scale is forced to be something other than a floating-point value, but I think it would be a useful feature.
I tried to use set to update the scale manually, but the operation is not tracked in the PyTorch graph. Can you help me with that?
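
As a quick numeric sanity check of that identity (plain Python with made-up values, assuming the zero-point Z is zero):

q, Z, s, a = 17, 0, 0.05, 3.0
r = (q - Z) * s            # dequantized value
s_prime = s * a            # fold the scalar into the scale
assert abs((q - Z) * s_prime - a * r) < 1e-12   # r' equals a * r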

Kind regards,

Tommaso

@Giuseppe5
Collaborator

Giuseppe5 commented Jul 9, 2024

  1. I'm not sure I can reproduce what you are seeing:

>>> import brevitas.nn as qnn
>>> from brevitas.quant.scaled_int import Uint8ActPerTensorFloat
>>> class NewUint4ActPerTensorFloat(Uint8ActPerTensorFloat):
...     bit_width = 4
... 
>>> a = qnn.QuantReLU(NewUint4ActPerTensorFloat)
>>> a.act_quant.quant_injector.bit_width
4
  2. I see your point, but as you said it might not work in all cases, and in general we always try to stay on the safe side.
     When you say PyTorch graph, how are you capturing the graph?
     Another option would be to inherit from IntQuantTensor and override the mul handler to do what you need, but I'm still not sure whether it is going to work (it depends on how you capture the graph).

@balditommaso
Author

  1. This is exactly what I am doing and what I would like to avoid. Let's say I have a NN with many layers that follow the same strategy but with different precisions. What I would like to do is specify, in the arguments of each layer, just the strategy and the bit-width.

For example, what I am trying to do is something like this:

from brevitas.inject.enum import *
from brevitas.core.zero_point import ZeroZeroPoint
from brevitas.quant.scaled_int import Int8ActPerTensorFloat, Int16Bias
from brevitas.quant.solver import WeightQuantSolver
import brevitas.nn as qnn

class SymmWeightPerChannelFloat(WeightQuantSolver):
    quant_type = QuantType.INT  # integer quantization
    bit_width_impl_type = BitWidthImplType.STATEFUL_CONST   # constant bit-width (saved in the model ckpt)
    bit_width = 8
    float_to_int_impl_type = FloatToIntImplType.ROUND   # round to the nearest integer
    scaling_impl_type = ScalingImplType.STATS   # scale computed from statistics
    scaling_stats_op = StatsOp.MAX  # scale statistic is the abs max value (good for symmetric quant)
    restrict_scaling_type = RestrictValueType.FP    # the scale can be a floating-point value
    scaling_per_output_channel = True   # channel-wise quantization
    signed = True    # quantization range is signed
    narrow_range = True # quantization range is [-127, 127] rather than [-128, 127]
    zero_point_impl = ZeroZeroPoint # Z = 0 (symmetric quant)

conv = qnn.QuantConv2d(
    1, 16,
    kernel_size=3,
    stride=2,
    padding=1,
    bias_quant=Int16Bias,
    input_quant=Int8ActPerTensorFloat,
    weight_quant=SymmWeightPerChannelFloat,
    bit_width=4,
    return_quant_tensor=True)

print(conv.weight_quant.quant_injector.bit_width)

4 should overwrite 8; otherwise I have to prepare a strategy for each possible bit-width. Is there a way to do so?

  2. I see. However, I tried to change the implementation of __mul__:

# Modified QuantTensor.__mul__ (assumes torch's Tensor and numpy as np are in scope):
def __mul__(self, other):
    if isinstance(other, QuantTensor) and self.is_not_none and other.is_not_none:
        # QuantTensor * QuantTensor: scales multiply, bit-widths add
        output_value = self.value * other.value
        output_scale = self.scale * other.scale
        output_bit_width = self.bit_width + other.bit_width
        output_signed = self.signed or other.signed
        output_training = self.training or other.training
        if self.is_zero_zero_point(self) and self.is_zero_zero_point(other):
            output_zero_point = self.zero_point * other.zero_point
        else:
            raise RuntimeError("Zero-points of mul operands are non-zero, not supported.")
        output = QuantTensor(
            value=output_value,
            scale=output_scale,
            zero_point=output_zero_point,
            bit_width=output_bit_width,
            signed=output_signed,
            training=output_training)
    elif isinstance(other, QuantTensor):
        # Metadata is incomplete: fall back to a plain Tensor
        output = self.value * other.value
    elif (isinstance(other, Tensor) and other.dim() == 0) or np.isscalar(other):   # scalar case
        # Fold the scalar into the scale so the result stays a QuantTensor
        output_value = self.value * other
        output_scale = self.scale * other
        output_signed = self.signed or (other < 0)
        output = QuantTensor(
            value=output_value,
            scale=output_scale,
            zero_point=self.zero_point,
            bit_width=self.bit_width,
            signed=output_signed,
            training=self.training)
    else:
        output = self.value * other
    return output

It's working fine; let me know if you think I might face some problems with that approach.

@Giuseppe5
Collaborator

  1. Apologies, I did not understand the first time. Since a QuantLayer has several quantization elements (input, weight, bias, output), we use prefixes to disentangle which keyword argument refers to which part of the quantization process. To override anything related to weight quantization, you add the weight_ prefix, so in your case it would be weight_bit_width (see the sketch after this list). The same applies to input/output/bias.
  2. For experimental purposes I think it should be fine, but let me give it some extra thought and come back to you (and feel free to ping me in case I forget).
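
For reference, a minimal sketch of what that would look like with the layer from your earlier snippet (untested here; it assumes the same imports and the SymmWeightPerChannelFloat quantizer defined above):

conv = qnn.QuantConv2d(
    1, 16,
    kernel_size=3,
    stride=2,
    padding=1,
    bias_quant=Int16Bias,
    input_quant=Int8ActPerTensorFloat,
    weight_quant=SymmWeightPerChannelFloat,
    weight_bit_width=4,  # the weight_ prefix routes this override to the weight quantizer
    return_quant_tensor=True)

print(conv.weight_quant.quant_injector.bit_width)  # expected to print 4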

@Giuseppe5
Collaborator

Still unsure whether 2. is solved; it depends on what you're doing with tracing. If you still have issues, feel free to re-open this and share a minimal script to reproduce.

I am going to close this for now :)
