feat: re-add GGUF (#600)
* refactor gguf kernels

* fix: incorrect filename for vecdotq header

* finish up the re-impl

* add requirements
AlpinDale authored Sep 2, 2024
1 parent 6c1eab6 commit 0e6c400
Showing 26 changed files with 4,821 additions and 5,070 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -209,6 +209,7 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
"kernels/quantization/marlin/qqq/marlin_qqq_gemm_kernel.cu"
"kernels/quantization/gptq_marlin/gptq_marlin.cu"
"kernels/quantization/gptq_marlin/gptq_marlin_repack.cu"
"kernels/quantization/gguf/gguf_kernel.cu"
"kernels/quantization/gptq_marlin/awq_marlin_repack.cu"
"kernels/quantization/fp8/fp8_marlin.cu"
"kernels/all_reduce/custom_all_reduce.cu"
32 changes: 32 additions & 0 deletions aphrodite/_custom_ops.py
@@ -416,6 +416,38 @@ def marlin_qqq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
workspace, size_m, size_n, size_k)


# gguf
def ggml_dequantize(W: torch.Tensor, quant_type: int, m: int, n: int):
return torch.ops._C.ggml_dequantize(W, quant_type, m, n)


def ggml_mul_mat_vec(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_vec(W, X, quant_type, row)


def ggml_mul_mat_vec_a8(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_vec_a8(W, X, quant_type, row)


def ggml_mul_mat_a8(
W: torch.Tensor,
X: torch.Tensor,
quant_type: int,
row: int,
):
return torch.ops._C.ggml_mul_mat_a8(W, X, quant_type, row)


# mamba
def causal_conv1d_fwd(x: torch.Tensor, weight: torch.Tensor,
bias_: Optional[torch.Tensor],
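Aside (not part of the diff): a minimal usage sketch of the new gguf wrappers above, assuming the extension is built and that the weight is stored as raw GGUF bytes in a uint8 tensor; the shapes and the use of gguf's GGMLQuantizationType enum as the quant_type argument are illustrative assumptions.

import torch
from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

from aphrodite import _custom_ops as ops

# Assumed: a Q4_K-quantized weight with logical shape (m, n).
m, n = 4096, 4096
qtype = GGMLQuantizationType.Q4_K
block_size, type_size = GGML_QUANT_SIZES[qtype]   # (256, 144) for Q4_K
qweight = torch.empty(m, n // block_size * type_size,
                      dtype=torch.uint8, device="cuda")

# Dequantize the whole tensor back to floating point.
weight = ops.ggml_dequantize(qweight, int(qtype), m, n)

# Quantized matrix-vector product: X is one fp16 activation row and
# `row` is the number of output rows, matching the wrapper signatures above.
x = torch.randn(1, n, dtype=torch.half, device="cuda")
y = ops.ggml_mul_mat_vec_a8(qweight, x, int(qtype), m)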
1 change: 0 additions & 1 deletion aphrodite/attention/backends/utils.py
@@ -214,7 +214,6 @@ def build(self, seq_lens: List[int], query_lens: List[int],
dtype=query_start_loc.dtype,
out=query_start_loc[1:])


return self._metadata_cls( # type: ignore
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
1 change: 1 addition & 0 deletions aphrodite/common/config.py
@@ -695,6 +695,7 @@ class LoadFormat(str, enum.Enum):
DUMMY = "dummy"
TENSORIZER = "tensorizer"
SHARDED_STATE = "sharded_state"
GGUF = "gguf"
BITSANDBYTES = "bitsandbytes"


3 changes: 3 additions & 0 deletions aphrodite/engine/args_tools.py
@@ -818,6 +818,9 @@ def from_cli_args(cls, args: argparse.Namespace) -> "EngineArgs":
return engine_args

def create_engine_config(self, ) -> EngineConfig:
# gguf file needs a specific model loader and doesn't use hf_repo
if self.model.endswith(".gguf"):
self.quantization = self.load_format = "gguf"

# bitsandbytes quantization needs a specific model loader
# so we make sure the quant method and the load format are consistent
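Aside (not part of the diff): a standalone restatement of the dispatch rule above; the helper name is hypothetical and nothing here is imported from aphrodite.

from typing import Optional, Tuple

def resolve_gguf_settings(model: str,
                          quantization: Optional[str] = None,
                          load_format: str = "auto") -> Tuple[Optional[str], str]:
    # A local .gguf checkpoint implies both the gguf quant method and the
    # gguf model loader, mirroring the check in create_engine_config.
    if model.endswith(".gguf"):
        quantization = load_format = "gguf"
    return quantization, load_format

assert resolve_gguf_settings("llama-3-8b.Q4_K_M.gguf") == ("gguf", "gguf")
assert resolve_gguf_settings("meta-llama/Meta-Llama-3-8B") == (None, "auto")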
95 changes: 94 additions & 1 deletion aphrodite/modeling/layers/linear.py
@@ -4,7 +4,7 @@
import torch
import torch.nn.functional as F
from loguru import logger
from torch.nn.parameter import Parameter
from torch.nn.parameter import Parameter, UninitializedParameter

# yapf: disable
from aphrodite.distributed import (divide,
@@ -322,6 +322,17 @@ def __init__(self,
def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_rank = get_tensor_model_parallel_rank()
output_dim = getattr(param, "output_dim", None)

# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

param_data = param.data
if output_dim is not None:
shard_size = param_data.shape[output_dim]
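Aside (not part of the diff): GGUF layers are created with torch's lazy UninitializedParameter because the quantized byte size is only known once the checkpoint tensor arrives; a minimal sketch of the materialize-then-copy pattern used above, with a made-up shape.

import torch
from torch.nn.parameter import UninitializedParameter

# The layer starts with a shapeless placeholder...
param = UninitializedParameter(requires_grad=False)

# ...and only gets real storage once the loaded tensor's (quantized) shape
# and dtype are known, instead of guessing a floating-point shape up front.
loaded_weight = torch.empty(4096, 2304, dtype=torch.uint8)  # assumed shape
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)
param.data.copy_(loaded_weight)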
@@ -412,6 +423,27 @@ def weight_loader(self,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[int] = None):

# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.data[loaded_shard_id].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)

param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
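Aside (not part of the diff): worked numbers for the row-size computation above. The block and type sizes come from gguf's GGML_QUANT_SIZES table; the hidden size is made up.

from gguf.constants import GGML_QUANT_SIZES, GGMLQuantizationType

hidden_size = 4096                                   # ori_shape[1], assumed
block_size, type_size = GGML_QUANT_SIZES[GGMLQuantizationType.Q4_K]
# Q4_K packs 256 weights into 144 bytes, so one quantized row takes
# 4096 // 256 * 144 = 2304 bytes; the fused parameter is materialized with
# the widest row among its shards (max(row_size) above).
row_bytes = hidden_size // block_size * type_size
assert row_bytes == 2304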
@@ -483,6 +515,18 @@ def weight_loader(self,
shard_offset = loaded_weight.shape[output_dim] * \
loaded_shard_id

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if self.quant_config is None:
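Aside (not part of the diff): the GGUF branch above first trims the destination view to this shard's own quantized row width (different shards of a fused weight can use different GGUF types, so their byte rows differ), then slices the usual output range. A hedged sketch with made-up sizes:

import torch

merged = torch.empty(22016, 2304, dtype=torch.uint8)    # widest row wins
loaded = torch.empty(11008, 1760, dtype=torch.uint8)    # e.g. a Q3_K shard
input_dim, output_dim, shard_offset = 1, 0, 0           # first fused shard
dst = merged.narrow(input_dim, 0, loaded.shape[input_dim])
dst = dst.narrow(output_dim, shard_offset, loaded.shape[output_dim])
dst.copy_(loaded)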
@@ -600,6 +644,29 @@ def weight_loader(self,
param: Parameter,
loaded_weight: torch.Tensor,
loaded_shard_id: Optional[str] = None):

# Special case for GGUF
# initialize GGUF param after we know the quantize type
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type and loaded_shard_id is not None:
idx_map = {"q": 0, "k": 1, "v": 2}
param.data[idx_map[loaded_shard_id]].copy_(loaded_weight)
param.shard_weight_type[loaded_shard_id] = loaded_weight.item()
return

if is_gguf_weight and isinstance(param, UninitializedParameter):
from gguf.constants import GGML_QUANT_SIZES

ori_shape = param.tensor_shape
weight_types = self.qweight_type.shard_weight_type.values()
row_size = []
for weight_type in weight_types:
block_size, type_size = GGML_QUANT_SIZES[weight_type]
row_size.append(ori_shape[1] // block_size * type_size)
q_shape = (ori_shape[0], max(row_size))
param.materialize(q_shape, dtype=loaded_weight.dtype)

param_data = param.data
output_dim = getattr(param, "output_dim", None)
# Special case for AQLM codebooks.
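Aside (not part of the diff): for fused QKV, the per-shard quant-type scalars arrive keyed by "q"/"k"/"v" and are recorded both in a 3-element tensor and in a dict, so the GGUF quant method can later dequantize each slice with its own type. A small sketch; the type values are examples only.

import torch
from gguf.constants import GGMLQuantizationType as T

idx_map = {"q": 0, "k": 1, "v": 2}
qweight_type = torch.empty(3, dtype=torch.int64)
shard_weight_type = {}
for shard_id, loaded in (("q", torch.tensor(int(T.Q4_K))),
                         ("k", torch.tensor(int(T.Q4_K))),
                         ("v", torch.tensor(int(T.Q6_K)))):
    qweight_type[idx_map[shard_id]] = loaded
    shard_weight_type[shard_id] = loaded.item()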
@@ -695,6 +762,18 @@ def weight_loader(self,
shard_size, shard_offset = adjust_bitsandbytes_shard(
param, orig_qkv_offsets, loaded_shard_id)

if is_gguf_weight:
tp_size = get_tensor_model_parallel_world_size()
output_dim = getattr(param, "output_dim", None)
shard_shape = list(loaded_weight.shape)
shard_shape[output_dim] = shard_shape[output_dim] // tp_size
param.shard_id.append(loaded_shard_id)
param.shard_size[loaded_shard_id] = shard_shape

input_dim = getattr(param, "input_dim", None)
input_size = loaded_weight.shape[input_dim]
param_data = param_data.narrow(input_dim, 0, input_size)

param_data = param_data.narrow(output_dim, shard_offset,
shard_size)
if self.quant_config is None:
@@ -814,7 +893,21 @@ def __init__(self,
self.register_parameter("bias", None)

def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
tp_size = get_tensor_model_parallel_world_size()
input_dim = getattr(param, "input_dim", None)
# Special case for GGUF
is_gguf_weight = getattr(param, "is_gguf_weight", False)
is_gguf_weight_type = getattr(param, "is_gguf_weight_type", False)
if is_gguf_weight_type:
param.weight_type = loaded_weight.item()

# Materialize GGUF UninitializedParameter
if is_gguf_weight and isinstance(param, UninitializedParameter):
weight_shape = list(loaded_weight.shape)
if input_dim:
weight_shape[input_dim] = weight_shape[input_dim] // tp_size
param.materialize(tuple(weight_shape), dtype=loaded_weight.dtype)

param_data = param.data
if input_dim is not None:
shard_size = param_data.shape[input_dim]
59 changes: 53 additions & 6 deletions aphrodite/modeling/layers/vocab_parallel_embedding.py
@@ -3,19 +3,46 @@

import torch
import torch.nn.functional as F
from torch.nn.parameter import Parameter
from torch.nn.parameter import Parameter, UninitializedParameter

from aphrodite.distributed import (divide, get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from aphrodite.modeling.layers.linear import UnquantizedLinearMethod
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.quantization.base_config import (QuantizationConfig,
QuantizeMethodBase)
from aphrodite.quantization.base_config import (
QuantizationConfig, QuantizeMethodBase, method_has_implemented_embedding)

DEFAULT_VOCAB_PADDING_SIZE = 64


class UnquantizedEmbeddingMethod(QuantizeMethodBase):
"""Unquantized method for embeddings."""

def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int], input_size: int,
output_size: int, params_dtype: torch.dtype,
**extra_weight_attrs):
"""Create weights for embedding layer."""
weight = Parameter(torch.empty(sum(output_partition_sizes),
input_size_per_partition,
dtype=params_dtype),
requires_grad=False)
set_weight_attrs(weight, {"input_dim": 1, "output_dim": 0})
layer.register_parameter("weight", weight)
set_weight_attrs(weight, extra_weight_attrs)

def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
return F.linear(x, layer.weight, bias)

def embedding(self, layer: torch.nn.Module,
input_: torch.Tensor) -> torch.Tensor:
return F.embedding(input_, layer.weight)


def pad_vocab_size(vocab_size: int,
pad_to: int = DEFAULT_VOCAB_PADDING_SIZE) -> int:
"""Pad the vocab size to the given value."""
@@ -199,7 +226,18 @@ def __init__(self,
if quant_config is not None:
linear_method = quant_config.get_quant_method(self, prefix=prefix)
if linear_method is None:
linear_method = UnquantizedLinearMethod()
linear_method = UnquantizedEmbeddingMethod()

# If we are making an embedding layer, then our quantization linear
# method must implement the embedding operation. If we are another
# layer type like ParallelLMHead, this is not important.
is_embedding_layer = type(self.__class__) is VocabParallelEmbedding
linear_method_implements_embedding = method_has_implemented_embedding(
type(linear_method))
if is_embedding_layer and not linear_method_implements_embedding:
raise NotImplementedError(
f"The class {type(linear_method).__name__} must implement "
"the 'embedding' method, see UnquantizedEmbeddingMethod.")
self.linear_method: QuantizeMethodBase = linear_method

if params_dtype is None:
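Aside (not part of the diff): the guard above relies on method_has_implemented_embedding from quantization/base_config.py, which this commit excerpt does not show. A plausible sketch of such a check (not the actual implementation) is:

from aphrodite.quantization.base_config import QuantizeMethodBase

def has_implemented_embedding(method_cls: type) -> bool:
    # True only when the class actually overrides `embedding` rather than
    # inheriting whatever the base QuantizeMethodBase defines (if anything).
    base_impl = getattr(QuantizeMethodBase, "embedding", None)
    impl = getattr(method_cls, "embedding", None)
    return impl is not None and impl is not base_impl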
@@ -305,6 +343,14 @@ def weight_loader(self, param: Parameter, loaded_weight: torch.Tensor):
output_dim = getattr(param, "output_dim", None)
packed_dim = getattr(param, "packed_dim", None)

# If the parameter is a gguf weight, then load it directly.
if getattr(param, "is_gguf_weight_type", None):
param.data.copy_(loaded_weight)
param.weight_type = loaded_weight.item()
return
elif isinstance(param, UninitializedParameter):
param.materialize(loaded_weight.shape, dtype=loaded_weight.dtype)

# If parameter does not have output dim, then it should
# be copied onto all gpus (e.g. g_idx for act_order gptq).
if output_dim is None:
@@ -343,7 +389,8 @@ def forward(self, input_):
else:
masked_input = input_
# Get the embeddings.
output_parallel = F.embedding(masked_input.long(), self.weight)
output_parallel = self.linear_method.embedding(self,
masked_input.long())
# Mask the output embedding.
if self.tp_size > 1:
output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)