
[Kernel] Support Microsoft Runtime Kernel Lib for our Low Precision Computation #6036

Open · LeiWang1999 wants to merge 46 commits into main
Conversation

@LeiWang1999 commented Jul 1, 2024

Hi all, this PR introduces support for the Microsoft Runtime Kernel Library (BitBLAS) to enhance our low-precision computation capabilities.

Brief Introduction to BitBLAS

BitBLAS is a library to support mixed-precision BLAS operations on GPUs, for example, the $W_{wdtype}A_{adtype}$ mixed-precision matrix multiplication where $C_{cdtype}[M, N] = A_{adtype}[M, K] \times W_{wdtype}[N, K]$.
BitBLAS aims to support efficient mixed-precision DNN model deployment, especially the $W_{wdtype}A_{adtype}$ quantization in large language models (LLMs), for example, the $W_{UINT4}A_{FP16}$ in GPTQ, the $W_{INT2}A_{FP16}$ in BitDistiller, the $W_{INT2}A_{INT8}$ in BitNet-b1.58.
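
For readers less familiar with the notation, the sketch below spells out the semantics that a $W_{UINT4}A_{FP16}$ kernel computes, as a plain PyTorch reference (dequantize, then matmul). It is purely illustrative: the shapes, group size, zero point, and fp32 accumulation are simplifying assumptions, and a real BitBLAS kernel fuses dequantization into the GEMM instead of materializing the fp16 weight.

```python
import torch

# Reference semantics of C[M, N] = A_fp16[M, K] x W_uint4[N, K]^T.
M, N, K, group_size = 16, 4096, 4096, 128  # assumed shapes

A = torch.randn(M, K).half()                             # fp16 activations
W_q = torch.randint(0, 16, (N, K), dtype=torch.uint8)    # 4-bit weights, stored unpacked here
scales = torch.rand(N, K // group_size).half()           # per-group scales
zeros = torch.full((N, K // group_size), 8.0, dtype=torch.float16)  # assumed zero point

# Dequantize per group: w = (q - zero) * scale.
scales_full = scales.repeat_interleave(group_size, dim=1)
zeros_full = zeros.repeat_interleave(group_size, dim=1)
W = (W_q.half() - zeros_full) * scales_full

# Accumulate in fp32 for the CPU reference, then cast the output back to fp16.
C = (A.float() @ W.float().T).to(torch.float16)
print(C.shape)  # torch.Size([16, 4096])
```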

PR Overview

This PR integrates BitBLAS into vLLM and adds examples of its usage. We provide two forms (a usage sketch follows the list):

  1. Load from GPTQ checkpoints: models stored in the GPTQ format are loaded and repacked into the BitBLAS layout at load time.
  2. Load from BitBLAS-format checkpoints: models already converted to the BitBLAS format are loaded directly for further optimized performance.
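
For concreteness, here is a minimal sketch of the two forms using vLLM's offline `LLM` API. The checkpoint names are taken from the benchmarks later in this thread, and the `quantization="bitblas"` flag is the one wired up by this PR; treat this as illustrative rather than the final user-facing interface.

```python
from vllm import LLM, SamplingParams

sampling = SamplingParams(temperature=0.0, max_tokens=64)
prompt = ["Hi, tell me about Microsoft?"]

# Form 1: GPTQ-format checkpoint, repacked into the BitBLAS layout at load time.
gptq_llm = LLM(model="hxbgsyxh/llama-13b-4bit-g-1", quantization="bitblas")
print(gptq_llm.generate(prompt, sampling)[0].outputs[0].text)

# Form 2: checkpoint already stored in the BitBLAS format
# (run separately to avoid holding two 13B models in GPU memory at once).
bitblas_llm = LLM(model="hxbgsyxh/llama-13b-4bit-g-1-bitblas", quantization="bitblas")
print(bitblas_llm.generate(prompt, sampling)[0].outputs[0].text)
```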

Below are the benchmarking results that we evaluated several months ago:

TODO ITEMS

  • Update and provide the latest benchmarking results.
  • Support the 1.58-bit (BitNet) model.
  • Provide benchmark/test scripts.

Any feedback and suggestions to improve this integration are appreciated.

@robertgshaw2-neuralmagic (Collaborator) commented:
Nice!

@LeiWang1999 (Author) commented Jul 1, 2024

BTW, are there any tools available that can automatically resolve lint issues?

vllm/model_executor/layers/quantization/gptq_bitblas.py:28:1: E402 Module level import not at top of file
vllm/model_executor/layers/quantization/gptq_bitblas.py:28:8: F811 Redefinition of unused `bitblas` from line 21
vllm/model_executor/layers/quantization/gptq_bitblas.py:29:1: E402 Module level import not at top of file
vllm/model_executor/layers/quantization/gptq_bitblas.py:66:81: E501 Line too long (107 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:172:81: E501 Line too long (85 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:222:81: E501 Line too long (105 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:230:81: E501 Line too long (89 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:233:81: E501 Line too long (110 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:236:81: E501 Line too long (99 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:242:81: E501 Line too long (84 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:253:81: E501 Line too long (94 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:414:81: E501 Line too long (86 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:417:29: G004 Logging statement uses f-string
vllm/model_executor/layers/quantization/gptq_bitblas.py:420:17: G004 Logging statement uses f-string
vllm/model_executor/layers/quantization/gptq_bitblas.py:427:81: E501 Line too long (103 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:433:81: E501 Line too long (116 > 80)
vllm/model_executor/layers/quantization/gptq_bitblas.py:454:81: E501 Line too long (82 > 80)

@robertgshaw2-neuralmagic (Collaborator) commented:
> BTW, are there any tools available that can automatically resolve lint issues?

./format.sh fixes whatever it can, but not everything can be fixed automatically (especially line length).

@mgoin (Collaborator) commented Jul 1, 2024

@LeiWang1999 thanks for the WIP, very cool interface with bitblas as a package. Can you explain whether the GPTQ benchmarking results in vLLM were run with the base "gptq" kernels or with the "gptq_marlin" interface that takes advantage of the Marlin kernels? This will be important for comparison with the current baseline we use for GPTQ models in vLLM.

@LeiWang1999 (Author) commented:
> Can you explain whether the GPTQ benchmarking results in vLLM were run with the base "gptq" kernels or with the "gptq_marlin" interface that takes advantage of the Marlin kernels?

Thanks! The benchmarks at that time used the exllamav2 kernel; we will examine the comparison with the Marlin kernel.

@LeiWang1999 (Author) commented Jul 19, 2024

Hi all, I recently updated this PR with support for the 1.58-bit model and the related BitBLAS inference kernel for vLLM.

Throughput in tokens per second (tok/s):

| Model | Framework | BS16, IN32, OUT128 | BS1, IN512, OUT1024 | BS32, IN32, OUT128 |
| --- | --- | --- | --- | --- |
| bitnet-3b-1.58bits | pytorch | 106.83 | 49.34 | 209.03 |
| bitnet-3b-1.58bits | pytorch-bitblas | 240.33 | 103.09 | 493.31 |
| bitnet-3b-1.58bits | vllm-bitblas | 379.25 | 117.43 | 752.55 |
| bitnet-3b-1.58bits | vllm-bitblas-cuda-graph | 2543.58 | 1621.08 | 2731.79 |

@LeiWang1999 marked this pull request as ready for review July 19, 2024 04:23
@LeiWang1999 (Author) commented:
We will benchmark against Marlin soon. Also, it looks like the docs build failed because of the bitblas dependency. Do you have any ideas on how to fix this? Should we add bitblas to the docs requirements, or is there an option to skip this dependency? @mgoin

@LeiWang1999 (Author) commented Aug 20, 2024

I think this PR is ready for review. Here is a summary of this update:

We now support BitBLAS as a quantized backend and can use vLLM to serve pretrained models from Hugging Face (in GPTQ, BitNet, or BitBLAS format) with the BitBLAS inference kernel.

We briefly compared performance against Marlin using the throughput benchmark script provided by vLLM, on an A100:

```bash
python benchmark_throughput.py --backend vllm --num-prompts 1 --input-len 32 --output-len 512 --max-model-len 1024 --model "hxbgsyxh/llama-13b-4bit-g-1-bitblas" --quantization "bitblas"

python benchmark_throughput.py --backend vllm --num-prompts 1 --input-len 32 --output-len 512 --max-model-len 1024 --model "hxbgsyxh/llama-13b-4bit-g-1-marlin" --quantization "marlin"
```

The performance results are:

  • Marlin: 122.67 toks/s
  • BitBLAS: 127.11 toks/s

Some notes:

  • Marlin requires a workspace buffer (used as a spin lock) to perform its global reduction, while BitBLAS does not.
  • BitBLAS supports more complex cases compared to Marlin, such as sym=False or 2-bit formats.

This PR also adds support for the 1.58-bit BitNet model.

Throughput in tokens per second (tok/s):

| Model | Framework | BS16, IN32, OUT128 | BS1, IN512, OUT1024 | BS32, IN32, OUT128 |
| --- | --- | --- | --- | --- |
| bitnet-3b-1.58bits | PyTorch | 106.83 | 49.34 | 209.03 |
| bitnet-3b-1.58bits | PyTorch-BitBLAS | 240.33 | 103.09 | 493.31 |
| bitnet-3b-1.58bits | vLLM-BitBLAS | 379.25 | 117.43 | 752.55 |
| bitnet-3b-1.58bits | vLLM-BitBLAS-CUDA-Graph | 2543.58 | 1621.08 | 2731.79 |

All correctness checks have been evaluated with the following:

```python
from conftest import VllmRunner
import torch

# BitNet model loaded from its original checkpoint via the bitnet_bitblas backend.
with VllmRunner(
    "hxbgsyxh/bitnet_b1_58-3B",
    dtype="half",
    quantization="bitnet_bitblas",
    enforce_eager=True,
    gpu_memory_utilization=0.5,
) as bitnet_model:
    bitnet_outputs = bitnet_model.generate_greedy(
        ["Hi, tell me about Microsoft?"], max_tokens=128
    )
    print("bitnet_bitblas:")
    print(bitnet_outputs[0][0])
    print(bitnet_outputs[0][1])

# The same BitNet model already converted to the BitBLAS format.
with VllmRunner(
    "hxbgsyxh/bitnet_b1_58-3B_bitblas",
    dtype="half",
    quantization="bitblas",
    enforce_eager=True,
) as bitblas_model:
    torch.cuda.profiler.start()
    bitblas_outputs = bitblas_model.generate_greedy(
        ["Hi, tell me about Microsoft?"], max_tokens=128
    )
    torch.cuda.profiler.stop()
    print("bitblas:")
    print(bitblas_outputs[0][0])
    print(bitblas_outputs[0][1])

# GPTQ-quantized model with the default gptq backend, as a reference.
with VllmRunner(
    "hxbgsyxh/opt-125m-4bit-128g",
    dtype="half",
    quantization="gptq",
    enforce_eager=True,
) as gptq_model:
    torch.cuda.profiler.start()
    gptq_outputs = gptq_model.generate_greedy(
        ["Hi, tell me about Microsoft?"], max_tokens=128
    )
    torch.cuda.profiler.stop()
    print("gptq:")
    print(gptq_outputs[0][0])
    print(gptq_outputs[0][1])

torch.compiler.reset()

# The same GPTQ checkpoint converted to the BitBLAS format.
with VllmRunner(
    "hxbgsyxh/opt-125m-4bit-128g-bitblas",
    dtype="half",
    quantization="bitblas",
    enforce_eager=True,
) as gptq_bitblas_model:
    torch.cuda.profiler.start()
    gptq_bitblas_outputs = gptq_bitblas_model.generate_greedy(
        ["Hi, tell me about Microsoft?"], max_tokens=128
    )
    torch.cuda.profiler.stop()
    print("bitblas:")
    print(gptq_bitblas_outputs[0][0])
    print(gptq_bitblas_outputs[0][1])
```

@LeiWang1999 (Author) commented:
Any questions are welcome; please take a look when you have a moment :) @mgoin @robertgshaw2-neuralmagic

@mgoin (Collaborator) commented Aug 20, 2024

Thanks for all the work @LeiWang1999! I have a few high-level thoughts first on how to make landing this more straightforward:

  1. Make bitblas an optional dependency and remove it from requirements-common.txt. The PyPI package is ~90 MB, is possibly built for a specific version of PyTorch/CUDA, and seems to pull in a lot of dependencies (TVM?). I think it is hard to require, especially for non-CUDA devices. See bitsandbytes or deepspeedfp for an example of how we usually implement a lazy import with an exception message telling the user to install it (a sketch of that pattern follows this list).
  2. Add support for bitnet in another PR. I think it is worth looking at separately and understanding the pros/cons. I find it a bit surprising that it requires implementing a whole new model and tokenizer class.
  3. gptq_bitblas seems a bit redundant without further benchmarking separating it from gptq_marlin. I agree it is certainly useful where marlin doesn't have support. Also we want to eventually perform a refactor to separate checkpoint formats from kernel implementations, so we will need to revisit this soon.
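
For reference, here is a minimal sketch of the lazy-import-with-install-hint pattern being referred to, modeled on how vLLM guards other optional backends such as bitsandbytes; the function name and error message are illustrative, not the final code:

```python
def _check_bitblas_available():
    """Import bitblas lazily and fail with an actionable message if it is missing."""
    try:
        import bitblas  # heavy optional dependency, only needed for this backend
    except ImportError as err:
        raise ImportError(
            "The BitBLAS quantization backend requires the `bitblas` package. "
            "Install it with `pip install bitblas`."
        ) from err
    return bitblas
```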

@LeiWang1999 (Author) commented:
Thanks for your suggestions, @mgoin!

  1. For the first suggestion, we’ve restructured the bitblas import to be a lazy import. Additionally, while bitblas still requires a CUDA environment, it is now linked to libcuda.so instead of being tied to a specific CUDA version :).
  2. Regarding the second suggestion, we've removed the BitNet-related items from this PR. Let's continue that discussion in a separate PR.
  3. For the third suggestion, I believe keeping gptq_bitblas is still valuable for formats that Marlin doesn't support, such as those with dynamic zero points or lower precision. It allows us to directly repack a GPTQ checkpoint into the BitBLAS format, for example:
```python
with VllmRunner(
    "hxbgsyxh/llama-13b-4bit-g-1",  # model stored in GPTQ format
    dtype="half",
    quantization="bitblas",
    enforce_eager=True,
) as bitblas_model:
    torch.cuda.profiler.start()
    outputs = bitblas_model.generate_greedy(
        ["Hi, tell me about Microsoft?"], max_tokens=128
    )
    torch.cuda.profiler.stop()
    print("bitblas:")
    print(outputs[0][0])
    print(outputs[0][1])
```

@mgoin (Collaborator) left a review:
Thanks for splitting it up! I left a first round of clear nits/issues and will do a more in-depth pass later. There seem to be a lot of unrelated formatting changes, for some reason.

Review comments (outdated, resolved) on: docs/source/quantization/bitblas.rst, vllm/config.py, vllm/model_executor/layers/quantization/bitblas.py
@LeiWang1999 (Author) commented:
Hi @mgoin , are there any further updates or actions we should take?

@mgoin (Collaborator) left a review:

Hi @LeiWang1999 I'm very sorry for the delay, I lost track of this PR and didn't catch your ping.

There has been an ongoing refactor for quantization methods to use a new set of vLLMParameters (see gptq_marlin PR #7281) to simplify weight loading, but we could delay this for bitblas to make it easier to land this initial PR.

Also as mentioned in #7725 (comment), there will be a few merge conflicts with main.

If/when you have bandwidth to finish this out, I promise to get this over the line asap. Please let me know!

Comment on lines +498 to +520
```python
if layer.bitblas_state == GPTQBitBLASState.REPACK:
    layer.bitblas_state = GPTQBitBLASState.READY

    # Newly generated tensors need to replace existing tensors that are
    # already registered as parameters by vLLM (and won't be freed)
    def replace_tensor(name, new_t):
        # It is important to use copy_() here since it ensures
        # the same buffer is reused
        getattr(layer, name).copy_(
            new_t.view(getattr(layer, name).dtype).view(
                getattr(layer, name).shape))
        del new_t

    # Repack weights
    bitblas_qweight, bitblas_scales, bitblas_qzeros = (
        self.repack_bitblas_from_gptq(
            layer.qweight,
            layer.scales,
            layer.qzeros,
        ))
    replace_tensor("qweight", bitblas_qweight)
    replace_tensor("scales", bitblas_scales)
    replace_tensor("qzeros", bitblas_qzeros)
```
Collaborator:
It would be best to move this into the process_weights_after_loading function we have specifically for this purpose; see the example in gptq_marlin.py:

```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
```
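
For concreteness, here is a sketch of what that relocation might look like; illustrative only, reusing the names from the snippet quoted above (GPTQBitBLASState, repack_bitblas_from_gptq), so the final implementation may differ:

```python
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
    # Called by vLLM once all checkpoint weights for this layer are loaded.
    if layer.bitblas_state != GPTQBitBLASState.REPACK:
        return
    layer.bitblas_state = GPTQBitBLASState.READY

    # Repack the GPTQ tensors into the BitBLAS layout.
    qweight, scales, qzeros = self.repack_bitblas_from_gptq(
        layer.qweight, layer.scales, layer.qzeros)

    # Reuse the parameter buffers already registered with vLLM via copy_().
    for name, new_t in (("qweight", qweight), ("scales", scales),
                        ("qzeros", qzeros)):
        old = getattr(layer, name)
        old.copy_(new_t.view(old.dtype).view(old.shape))
```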

@LeiWang1999 (Author):
Thanks, I'll take a look. I'm currently working on the stream-k template in bitblas :)
