Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: set config into weights for quantization feature support more easily #400

Merged
merged 3 commits into from
Apr 10, 2024

Conversation

thincal
Copy link
Contributor

@thincal thincal commented Apr 9, 2024

What does this PR do?

Fixes #399

Before submitting

  • This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
  • Was this discussed/approved via a Github issue or the discord / slack channel? Please add a link
    to it if that's the case.
  • Did you write any new necessary tests?

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.

@tgaddair, thanks.

@thincal thincal changed the title refactor: set config into weights for better quantization feature su… refactor: set config into weights for better quantization feature support Apr 9, 2024
@thincal thincal changed the title refactor: set config into weights for better quantization feature support refactor: set config into weights for quantization feature support more easily Apr 9, 2024
@tgaddair
Copy link
Contributor

tgaddair commented Apr 9, 2024

Thanks for the PR @thincal! I noticed there are a couple of issues I'm getting when attempting to test this with AWQ and GPTQ quants. Do you know what's going on here?

TheBloke/Mistral-7B-Instruct-v0.1-AWQ --quantize awq

Output:

File "/data/lorax/server/lorax_server/models/__init__.py", line 181, in get_model                                                                                                                        
    return FlashMistral(                                                                                                                                                                                   
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/flash_mistral.py", line 76, in __init__                                                                                                                     
    model = FlashMistralForCausalLM(config, weights)                                                                                                                                                       
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 540, in __init__                                                                                           
    self.model = MistralModel(config, weights)                                                                                                                                                             
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 478, in __init__                                                                                           
    [                                                                                                                                                                                                      
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 479, in <listcomp>                                                                                         
    MistralLayer(                                                                                                                                                                                          
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 412, in __init__                                                                                           
    self.self_attn = MistralAttention(                                                                                                                                                                     
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 250, in __init__                                                                                           
    self.query_key_value = load_attention(config, prefix, weights, layer_id)                                                                                                                               
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 166, in load_attention                                                                                     
    base_layer = load_attention_multi(config, prefix, weights)                                                                                                                                             
                                                                                                                                                                                                           
  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 183, in load_attention_multi                                                                               
    return _load_gqa(config, prefix, weights)                                                                                                                                                              

  File "/data/lorax/server/lorax_server/models/custom_modeling/flash_mistral_modeling.py", line 198, in _load_gqa
    weight = weights.get_multi_weights_col(

  File "/data/lorax/server/lorax_server/utils/weights.py", line 227, in get_multi_weights_col
    bits, groupsize = self._get_bits_and_groupsize()

  File "/data/lorax/server/lorax_server/utils/weights.py", line 323, in _get_bits_and_groupsize
    bits = self.get_tensor("gptq_bits").item()

  File "/data/lorax/server/lorax_server/utils/weights.py", line 137, in get_tensor
    filename, tensor_name = self.get_filename(tensor_name)

  File "/data/lorax/server/lorax_server/utils/weights.py", line 124, in get_filename
    raise RuntimeError(f"weight {tensor_name} does not exist")

RuntimeError: weight gptq_bits does not exist

Same error with TheBloke/Mistral-7B-Instruct-v0.2-GPTQ --quantize gptq.

@thincal
Copy link
Contributor Author

thincal commented Apr 10, 2024

@tgaddair thanks for the test, I have checked the field name from config.json is wrongly specified, already fixed now.

Copy link
Contributor

@tgaddair tgaddair left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! Verified it works with AWQ, GTPQ, and unquantized models.

@tgaddair tgaddair merged commit 70db455 into predibase:main Apr 10, 2024
1 check failed
@thincal thincal deleted the fix/refactor-quantize-config branch April 11, 2024 04:48
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Refactor the quantization config for weights
2 participants