Eetq (#195)
flozi00 authored Jan 23, 2024
1 parent 5e68edd commit 1be43cf
Showing 7 changed files with 117 additions and 1 deletion.
2 changes: 1 addition & 1 deletion .github/workflows/build.yaml
@@ -151,4 +151,4 @@ jobs:
# Delete the SHA image(s) from containerd store
sudo ctr i rm $(sudo ctr i ls -q)
10 changes: 10 additions & 0 deletions Dockerfile
@@ -146,6 +146,13 @@ COPY server/punica_kernels/ .
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX"
RUN python setup.py build

# Build eetq kernels
FROM kernel-builder as eetq-kernels-builder
WORKDIR /usr/src
COPY server/Makefile-eetq Makefile
# Build specific version of EETQ
RUN TORCH_CUDA_ARCH_LIST="8.0;8.6+PTX" make build-eetq

# LoRAX base image
FROM nvidia/cuda:11.8.0-base-ubuntu20.04 as base

@@ -194,6 +201,9 @@ COPY --from=punica-builder /usr/src/build/lib.linux-x86_64-cpython-310 /opt/cond
# Copy build artifacts from megablocks builder
COPY --from=megablocks-kernels-builder /usr/src/megablocks/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Copy build artifacts from eetq builder
COPY --from=eetq-kernels-builder /usr/src/eetq/build/lib.linux-x86_64-cpython-310 /opt/conda/lib/python3.10/site-packages

# Install flash-attention dependencies
RUN pip install einops --no-cache-dir

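A quick sanity check for the resulting image, not part of this commit: the snippet below, run inside the final container, only confirms that the eetq build artifacts copied into the conda site-packages resolve to an importable module. The module name EETQ is taken from the import added in server/lorax_server/utils/layers.py later in this diff.

# Hypothetical sanity check to run inside the built image (not part of the commit).
import EETQ

# The Dockerfile copies the built extension into the conda site-packages,
# so the resolved path should live under /opt/conda/lib/python3.10/site-packages.
print(EETQ.__file__)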
4 changes: 4 additions & 0 deletions launcher/src/main.rs
@@ -26,6 +26,7 @@ enum Quantization {
    BitsandbytesFP4,
    Gptq,
    Awq,
    Eetq,
    Hqq_4bit,
    Hqq_3bit,
    Hqq_2bit,
@@ -50,6 +51,9 @@ impl std::fmt::Display for Quantization {
            Quantization::Awq => {
                write!(f, "awq")
            }
            Quantization::Eetq => {
                write!(f, "eetq")
            }
            Quantization::Hqq_4bit => {
                write!(f, "hqq-4bit")
            }
1 change: 1 addition & 0 deletions server/Makefile
@@ -2,6 +2,7 @@ include Makefile-flash-att
include Makefile-flash-att-v2
include Makefile-vllm
include Makefile-megablocks
include Makefile-eetq

unit-tests:
	pytest -s -vv -m "not private" tests
13 changes: 13 additions & 0 deletions server/Makefile-eetq
@@ -0,0 +1,13 @@
eetq_commit := cc2fdb4637e03652ac264eaef44dd8492472de01 # 323827dd471458a84e9c840f614e4592b157a4b1

eetq:
	# Clone eetq
	pip install packaging
	git clone https://github.com/NetEase-FuXi/EETQ.git eetq

build-eetq: eetq
	cd eetq && git fetch && git checkout $(eetq_commit) && git submodule update --init --recursive
	cd eetq && python setup.py build

install-eetq: build-eetq
	cd eetq && python setup.py install
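After make install-eetq, a short round trip through the two kernel entry points used by the new EETQLinear layer (see layers.py below) can serve as a smoke test. This is a hedged sketch rather than part of the commit: it assumes a CUDA device and an fp16 weight laid out like nn.Linear (out_features x in_features), mirroring how EETQLinear.__init__ calls quant_weights.

# Hypothetical smoke test (not part of the commit); assumes CUDA and fp16 weights.
import torch
from EETQ import quant_weights, w8_a16_gemm

# Quantize a random fp16 weight the same way EETQLinear does:
# transpose, make contiguous, move to CPU, quantize to int8 (not in place).
weight = torch.randn(128, 256, dtype=torch.float16, device="cuda")  # (out, in)
w_q, scale = quant_weights(weight.t().contiguous().cpu(), torch.int8, False)

# Run the int8-weight / fp16-activation GEMM on the GPU.
x = torch.randn(4, 256, dtype=torch.float16, device="cuda")
y = w8_a16_gemm(x, w_q.cuda(), scale.cuda())
print(y.shape)  # expected (4, 128) given the layout above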
1 change: 1 addition & 0 deletions server/lorax_server/cli.py
@@ -17,6 +17,7 @@ class Quantization(str, Enum):
    bitsandbytes_fp4 = "bitsandbytes-fp4"
    gptq = "gptq"
    awq = "awq"
    eetq = "eetq"
    hqq_4bit = "hqq-4bit"
    hqq_3bit = "hqq-3bit"
    hqq_2bit = "hqq-2bit"
87 changes: 87 additions & 0 deletions server/lorax_server/utils/layers.py
@@ -21,6 +21,14 @@
except ImportError:
    HAS_AWQ = False

HAS_EETQ = False
try:
    from EETQ import quant_weights, w8_a16_gemm

    HAS_EETQ = True
except ImportError:
    pass

HAS_HQQ = True
try:
    from hqq.core.quantize import BaseQuantizeConfig, HQQLinear
@@ -127,6 +135,78 @@ def forward(self, input: torch.Tensor) -> torch.Tensor:
        x = torch.addmm(self.bias, input.view(-1, input.size(-1)), self.weight)
        x = x.view(size_out)
        return x


class EETQLinear(nn.Module):
"""
EETQLinear module applies quantized linear transformation to the input tensor.
Args:
weight (torch.Tensor): The weight tensor for the linear transformation.
bias (torch.Tensor): The bias tensor for the linear transformation.
Attributes:
weight (torch.Tensor): The weight tensor for the linear transformation.
scale (torch.Tensor): The scale tensor used for quantization.
bias (torch.Tensor): The bias tensor for the linear transformation.
"""

def __init__(
self,
weight,
bias,
) -> None:
super().__init__()
# Get the device where the weight tensor is currently stored.
device = weight.device

# Transpose the weight tensor and make a contiguous copy of it on the CPU.
# The contiguous() function is used to ensure that the tensor is stored in a contiguous block of memory,
# which can improve performance in some cases.
weight_transposed = torch.t(weight)
weight_contiguous = weight_transposed.contiguous()
weight_cpu = weight_contiguous.cpu()

# Quantize the weights. The quant_weights function is assumed to perform the quantization.
# The weights are quantized to int8 format, and the quantization is not performed in place (False).
weight_quantized, scale = quant_weights(weight_cpu, torch.int8, False)

# Move the quantized weights and the scale back to the original device (GPU if available).
# The cuda() function is used to move the tensors to the GPU.
self.weight = weight_quantized.cuda(device)
self.scale = scale.cuda(device)

# If a bias is present, move it to the GPU as well. If not, set the bias to None.
if bias is not None:
self.bias = bias.cuda(device)
else:
self.bias = None

def forward(self, input: torch.Tensor) -> torch.Tensor:
"""
Performs the forward pass of the layer.
Args:
input (torch.Tensor): The input tensor.
Returns:
torch.Tensor: The output tensor.
"""
# The function w8_a16_gemm performs a matrix multiplication operation between the input and the weight of the layer.
# The result is then scaled by a factor (self.scale).
gemm_output = w8_a16_gemm(input, self.weight, self.scale)

# If a bias is present (i.e., self.bias is not None), it is added to the output of the matrix multiplication.
# If a bias is not present (i.e., self.bias is None), the output of the matrix multiplication is returned as is.
if self.bias is not None:
final_output = gemm_output + self.bias
else:
final_output = gemm_output

# The final output is returned.
return final_output



class Linear8bitLt(nn.Module):
@@ -254,6 +334,13 @@ def get_linear(weight, bias, quantize, fan_in_fan_out=False):
            bias,
            quant_type="fp4",
        )
    elif quantize == "eetq":
        if HAS_EETQ:
            linear = EETQLinear(weight, bias)
        else:
            raise ImportError(
                "Please install EETQ from https://github.com/NetEase-FuXi/EETQ"
            )
    elif quantize == "gptq":
        try:
            qweight, qzeros, scales, g_idx, bits, groupsize, use_exllama = weight
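For context, here is a hedged usage sketch (not part of the commit) of the new "eetq" branch in get_linear: it wraps an fp16 nn.Linear's parameters in an EETQLinear and runs a forward pass. It assumes the EETQ kernels are installed and a CUDA device is available.

# Hypothetical usage sketch (not part of the commit).
import torch
from lorax_server.utils.layers import HAS_EETQ, get_linear

dense = torch.nn.Linear(256, 128, dtype=torch.float16, device="cuda")

if HAS_EETQ:
    # Routes through EETQLinear: the weight is quantized to int8 once at construction.
    layer = get_linear(dense.weight.data, dense.bias.data, quantize="eetq")
    x = torch.randn(4, 256, dtype=torch.float16, device="cuda")
    y = layer(x)  # w8_a16_gemm plus the optional bias
    print(y.shape)  # expected (4, 128)
else:
    print("EETQ kernels not installed; see server/Makefile-eetq")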
