Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Support for BitNet Architecture Inference #2664

Open
wants to merge 8 commits into
base: main
Choose a base branch
from

Conversation

JoseCarlosGarcia95
Copy link

Introduction

Hello! My name is José Carlos, and I hold a degree in Mathematics. Alongside other talented individuals, I co-founded a company where we focus on AI and infrastructure solutions.

I’ve always believed that BitNet is one of the most significant breakthroughs of the past year. I have a personal obsession with this architecture, as it embodies the potential to balance performance and efficiency in language models—a pursuit that deeply motivates me.


Changes Made

  1. Added Support for BitNet Architecture Inference:

    • Implemented initial support to infer the BitNet architecture within Candle.
    • Note: This implementation does not yet include quantization support for BitNet.
  2. New Example Added:

    • Introduced a new example to test and demonstrate the newly added BitNet inference functionality.
  3. Supported HF Models:
    The following models were tested successfully:

    • "1bitLLM/bitnet_b1_58-large": BitNet B1 58 Large.
    • "1bitLLM/bitnet_b1_58-3B": BitNet B1 58 3B.

    Future support will be added for:

    • "HF1BitLLM/Llama3-8B-1.58-100B-tokens": Llama 3 (8B, 1.58).

Known Limitations

  • Current implementation does not support quantization for BitNet.
  • Matrix multiplication methods in this implementation are not optimized yet.

I plan to address these in future updates and prepare a PR with:

  1. Support for quantizations tailored to BitNet.
  2. Optimized methods for matrix multiplication.

Roadmap

  • Add support for Llama 3 (8B, 1.58).
  • Explore and implement methods to enhance matrix multiplication performance.

Thank you for considering this PR! Feedback is welcome, and I’m excited to continue contributing to this amazing project. 😊

Signed-off-by: José Carlos García <[email protected]>
Signed-off-by: José Carlos García <[email protected]>
@JoseCarlosGarcia95 JoseCarlosGarcia95 marked this pull request as draft December 9, 2024 16:10
@JoseCarlosGarcia95 JoseCarlosGarcia95 marked this pull request as ready for review December 9, 2024 16:11
@LaurentMazare
Copy link
Collaborator

The bit-linear tests that you added seem to be broken, would you mind having a look?

Signed-off-by: José Carlos García <[email protected]>
@JoseCarlosGarcia95
Copy link
Author

JoseCarlosGarcia95 commented Dec 9, 2024

The bit-linear tests that you added seem to be broken, would you mind having a look?

Done! @LaurentMazare

@LaurentMazare
Copy link
Collaborator

Thanks, could you also provide details on how the model results were lined up with the python implementation? Did you ensure that the logits generated by the candle version are somewhat in line?

@JoseCarlosGarcia95
Copy link
Author

@LaurentMazare Everything stems from: The Era of 1-bit LLMs - Training Tips, Code, FAQ

Since the inference for Llama and Linear was already implemented in this project, I used the existing Llama implementation as a foundation and applied the following changes, based on the paper and as seen here:

  • Replace Linear with BitLinear in MLP and Attention.
  • Add weights and layers for RMSNorm before calling BitLinear, as shown here:
  • Implement activation_quant and weight_quant equivalently using Candle.

In principle, unless I’m mistaken, everything seems consistent with the Python implementation. You can test it in situ as follows:

cargo run --example llama-bitnet --features metal

Thank you so much for reviewing!

@LaurentMazare
Copy link
Collaborator

In order to check the consistency, the best would be to generate the logits for the same prompt on the candle and python side and check that they are reasonably close. That's what we do for most models before adding them, would you mind giving it a try?

@JoseCarlosGarcia95
Copy link
Author

@LaurentMazare

Since this is my first contribution, I don't have much experience with this, but it seems that the logits from the Python implementation and my implementation are similar:

I used the code from this link in Python, and in the code I created, I added a print of the logits (in Candle) and a print of the output.score in Python.

Here’s what I observed:

Candle:
[-3.4375, -10.9765625, 1.171875, -0.4970703, -0.88427734, 2.0449219, 3.1738281, -0.15246582, 0.5522461, 1.4648438, 2.34375, 0.8886719, 4.7539063]

Python:
[-3.458984375, -11.28125, 1.1884765625, -0.52587890625, -0.81103515625, 2.021484375, 3.17578125, -0.173828125, 0.568359375, 1.5283203125, 2.455078125, 0.96044921875, 4.84375]

They seem equivalent, except for numerical error.

@JoseCarlosGarcia95
Copy link
Author

I have been checking and the transformers library does not have support for these models, however, using the code from the model repository, it generates the same logits.

https://huggingface.co/1bitLLM/bitnet_b1_58-xl/tree/main

@codesoda
Copy link

I'd love to see this support the just released Falcon3 1.58bit model :)

https://huggingface.co/tiiuae/Falcon3-7B-Instruct-1.58bit

@JoseCarlosGarcia95
Copy link
Author

Not working with this model! I'll check asap :) Thank you for reporting it! @codesoda

Signed-off-by: José Carlos García <[email protected]>
@noppej
Copy link

noppej commented Dec 17, 2024

@JoseCarlosGarcia95 Great work, and thank you for this. I don't want to slow down the completion of this PR, but it should be noted somewhere that this PR does not work when loading the model from a .gguf file. The data types for the BitNet tensors are not supported by the existing Gguf enums.
IMHO your PR is still very useful, and the GGUF support can be tackled as a separate PR / TODO.

@JoseCarlosGarcia95
Copy link
Author

@noppej Thank you! I’m currently working on adding support for quantized models in parallel.

The models available at https://huggingface.co/tiiuae/Falcon3-1B-Instruct-1.58bit and https://huggingface.co/HF1BitLLM/Llama3-8B-1.58-100B-tokens/tree/main are already quantized, so this PR isn’t compatible with them.

My goal is to create a separate PR to support both models and also include support for the methods outlined here: https://github.com/microsoft/BitNet/tree/main.

@JoseCarlosGarcia95
Copy link
Author

Hi @LaurentMazare

I hope you’re doing well. I wanted to kindly follow up on this PR to see if there’s anything else needed from my side to facilitate the review or integration.

Thank you again for your time and support, and I look forward to hearing back!

@JoseCarlosGarcia95
Copy link
Author

Fyi #2683

@JoseCarlosGarcia95
Copy link
Author

Added support for Falcon3 @codesoda and added support for loading GGUF @noppej (#2683) :)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

4 participants