
Multiple quantized parameters in the same layers #1122

Closed
balditommaso opened this issue Dec 9, 2024 · 3 comments
Labels
enhancement New feature or request

Comments

@balditommaso

I am working with custom layers that have two sets of trainable parameters. They follow two different distributions, so it would be useful to have a different quantization scheme for each of them in order to reduce the quantization error.

Is this already possible in Brevitas? Is there any workaround that I can adopt?

@balditommaso added the enhancement (New feature or request) label Dec 9, 2024
@Giuseppe5 (Collaborator) commented Dec 9, 2024

The answer provided by the bot is wrong (no surprise there).
If you have two sets of parameters, I would recommend splitting them into two layers, each with its own custom quantizer. The layers can be very minimal, for example:

import torch
import torch.nn as nn

from brevitas.nn.mixin import QuantWeightMixin


class GateWeight(QuantWeightMixin, nn.Module):

    def __init__(self, input_features, output_features, weight_quant, dtype, device, **kwargs):
        nn.Module.__init__(self)
        self.weight = nn.Parameter(
            torch.randn(output_features, input_features, dtype=dtype, device=device))
        QuantWeightMixin.__init__(self, weight_quant=weight_quant, **kwargs)

    @property
    def output_channel_dim(self):
        # Weight is stored as (output_features, input_features), so dim 0 is the output channel dim
        return 0

    @property
    def out_channels(self):
        return self.weight.size(self.output_channel_dim)

    def forward(self):
        # Returns the quantized view of the underlying parameter
        return self.weight_quant(self.weight)

Then you can have your custom layer as:

class NewLayer(torch.nn.Module):

    def __init__(self, weight_quant_1=Int8WeightPerTensorFloat, weight_quant_2=Int8WeightPerChannelFloat):
        super().__init__()
        self.parameter_1 = GateWeight(...., weight_quant_1)
        self.parameter_2 = GateWeight(...., weight_quant_2)

    def forward(self, x):
        quant_weight_1 = self.parameter_1.quant_weight()  # or self.parameter_1()
        quant_weight_2 = self.parameter_2.quant_weight()  # or self.parameter_2()
        # ... use quant_weight_1 and quant_weight_2 to compute the layer output ...

I didn't test this code, so take it more as pseudocode, but it should give you an idea of where to start.
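
As a rough usage sketch (also untested; the shapes, dtype/device arguments, and quantizer import path are just illustrative), the two parameters could be instantiated with different quantizers like this:

import torch

from brevitas.quant import Int8WeightPerTensorFloat, Int8WeightPerChannelFloat

# Two independent parameters, each with its own quantizer (illustrative shapes)
w1 = GateWeight(16, 32, Int8WeightPerTensorFloat, dtype=torch.float32, device='cpu')
w2 = GateWeight(16, 32, Int8WeightPerChannelFloat, dtype=torch.float32, device='cpu')

quant_w1 = w1()  # quantized view of w1.weight, single per-tensor scale
quant_w2 = w2()  # quantized view of w2.weight, one scale per output channel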

@balditommaso (Author)

Great! Thanks for your help!

However, the bot's solution at least looks cool ;)
