quantization: update marlin to use AphroditeParameters
AlpinDale committed Dec 18, 2024
1 parent 16e5b2b · commit 41e67bc
Showing 2 changed files with 102 additions and 71 deletions.
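
The core of this change is replacing raw torch.nn.Parameter objects annotated through set_weight_attrs with the AphroditeParameter family, whose instances carry their sharding and packing metadata plus the weight_loader callback directly. A minimal sketch of the new construction pattern follows, mirroring the constructor calls visible in the diff below; the shapes, tile size, group size, and the loader function are illustrative placeholders, and it assumes aphrodite.modeling.parameter is importable and behaves as shown in this commit.

import torch
from aphrodite.modeling.parameter import (GroupQuantScaleParameter,
                                          PackedAphroditeParameter)


def placeholder_loader(param, loaded_weight, *args, **kwargs):
    # Stand-in for the loader that the linear layer normally supplies via
    # extra_weight_attrs["weight_loader"]; shown only so the sketch runs.
    param.data.copy_(loaded_weight)


input_size, output_size = 2048, 4096     # per-partition sizes (example values)
tile_size, pack_factor, group_size = 16, 32 // 4, 128  # example Marlin settings

# Packed 4-bit weights: metadata that used to be attached with
# set_weight_attrs is now passed directly to the parameter class.
qweight = PackedAphroditeParameter(
    data=torch.empty(input_size // tile_size,
                     output_size * tile_size // pack_factor,
                     dtype=torch.int32),
    input_dim=0,
    output_dim=1,
    packed_dim=1,
    packed_factor=pack_factor,
    marlin_tile_size=tile_size,
    weight_loader=placeholder_loader,
)

# Grouped scales (group_size != -1); a channelwise checkpoint would use
# ChannelQuantScaleParameter with only output_dim instead.
scales = GroupQuantScaleParameter(
    data=torch.empty(input_size // group_size, output_size,
                     dtype=torch.float16),
    output_dim=1,
    input_dim=0,
    weight_loader=placeholder_loader,
)

The apparent motivation is that the metadata travels with each parameter, so a generic loading path can shard and unpack weights without consulting per-layer attribute dictionaries.
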
aphrodite/modeling/layers/linear.py (2 changes: 1 addition & 1 deletion)
@@ -26,7 +26,7 @@
WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
"AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
"Fp8LinearMethod",
"Fp8LinearMethod", "MarlinLinearMethod"
]
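
Adding "MarlinLinearMethod" to WEIGHT_LOADER_V2_SUPPORTED is what opts the Marlin path into the newer weight-loading flow; presumably the linear layers consult this list by quant-method class name when choosing a loader. A small self-contained sketch of that kind of check is shown below; the helper name is hypothetical, the actual dispatch lives inside aphrodite/modeling/layers/linear.py and may differ, and only the list contents come from the diff above.

# Hypothetical helper illustrating a class-name allowlist check; it mirrors
# the updated WEIGHT_LOADER_V2_SUPPORTED list from this commit.
WEIGHT_LOADER_V2_SUPPORTED = [
    "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
    "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
    "Fp8LinearMethod", "MarlinLinearMethod",
]


def uses_weight_loader_v2(quant_method: object) -> bool:
    # Quant methods are matched by class name, e.g. "MarlinLinearMethod".
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED
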


aphrodite/quantization/marlin.py (171 changes: 101 additions & 70 deletions)
@@ -7,7 +7,10 @@
from aphrodite import _custom_ops as ops
from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
from aphrodite.modeling.layers.vocab_parallel_embedding import ParallelLMHead
from aphrodite.modeling.utils import set_weight_attrs
from aphrodite.modeling.parameter import (BaseAphroditeParameter,
ChannelQuantScaleParameter,
GroupQuantScaleParameter,
PackedAphroditeParameter)
from aphrodite.quantization.base_config import QuantizationConfig


@@ -29,7 +32,8 @@ def __init__(
raise ValueError(
"Currently, only group size 128 and -1 (channelwise) "
"is supported for Marlin, but got group_size of "
f"{self.group_size}")
f"{self.group_size}"
)

# 4 Bits packed into 32 bit datatype.
self.pack_factor = 32 // 4
@@ -51,8 +55,10 @@ def __init__(
self.perm_len = 1024

def __repr__(self) -> str:
return (f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})")
return (
f"MarlinConfig(group_size={self.group_size}, "
f"lm_head_quantized={self.lm_head_quantized})"
)

@classmethod
def get_name(cls) -> str:
@@ -74,33 +80,42 @@ def get_config_filenames(cls) -> List[str]:
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "MarlinConfig":
group_size = cls.get_from_keys(config, ["group_size"])
lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"],
default=False)
lm_head_quantized = cls.get_from_keys_or(
config, ["lm_head"], default=False
)
return cls(group_size, lm_head_quantized)

@classmethod
def override_quantization_method(cls, hf_quant_cfg,
user_quant) -> Optional[str]:
def override_quantization_method(
cls, hf_quant_cfg, user_quant
) -> Optional[str]:
# compat: autogptq >=0.8.0 use checkpoint_format: str
# compat: autogptq <=0.7.1 is_marlin_format: bool
is_marlin_format = (hf_quant_cfg.get("checkpoint_format") == "marlin"
or hf_quant_cfg.get("is_marlin_format", False))
is_marlin_format = hf_quant_cfg.get(
"checkpoint_format"
) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)

is_valid_user_quant = (user_quant is None or user_quant == "gptq"
or user_quant == "marlin")
is_valid_user_quant = (
user_quant is None or user_quant == "gptq" or user_quant == "marlin"
)

if is_marlin_format and is_valid_user_quant:
msg = ("The model is serialized in {} format. Using {} kernel.".
format(cls.get_name(), cls.get_name()))
msg = (
"The model is serialized in {} format. Using {} kernel.".format(
cls.get_name(), cls.get_name()
)
)
logger.info(msg)
return cls.get_name()

return None

def get_quant_method(self, layer: torch.nn.Module,
prefix: str) -> Optional["MarlinLinearMethod"]:
if (isinstance(layer, LinearBase) or
(isinstance(layer, ParallelLMHead) and self.lm_head_quantized)):
def get_quant_method(
self, layer: torch.nn.Module, prefix: str
) -> Optional["MarlinLinearMethod"]:
if isinstance(layer, LinearBase) or (
isinstance(layer, ParallelLMHead) and self.lm_head_quantized
):
return MarlinLinearMethod(self)
return None

@@ -129,102 +144,117 @@ def create_weights(
**extra_weight_attrs,
):
del output_size # Unused.
weight_loader = extra_weight_attrs["weight_loader"]

if params_dtype != torch.float16:
raise ValueError(
f"The params dtype must be float16, but got {params_dtype}")
f"The params dtype must be float16, but got {params_dtype}"
)

# Validate output_size_per_partition
output_size_per_partition = sum(output_partition_sizes)
if output_size_per_partition % self.quant_config.min_n_threads != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"min_n_threads = {self.quant_config.min_n_threads}.")
f"min_n_threads = {self.quant_config.min_n_threads}."
)
if output_size_per_partition % self.quant_config.pack_factor != 0:
raise ValueError(
f"Weight output_size_per_partition = "
f"{output_size_per_partition} is not divisible by "
f"pack_factor = {self.quant_config.pack_factor}.")
f"pack_factor = {self.quant_config.pack_factor}."
)

# Validate input_size_per_partition
if input_size_per_partition % self.quant_config.min_k_threads != 0:
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"min_k_threads = {self.quant_config.min_k_threads}.")
if (self.quant_config.group_size != -1 and
input_size_per_partition % self.quant_config.group_size != 0):
raise ValueError(f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}.")
f"min_k_threads = {self.quant_config.min_k_threads}."
)
if (
self.quant_config.group_size != -1
and input_size_per_partition % self.quant_config.group_size != 0
):
raise ValueError(
f"Weight input_size_per_partition = "
f"{input_size_per_partition} is not divisible by "
f"group_size = {self.quant_config.group_size}."
)

# Check that we have at least 4 tiles horizontally in the shard
num_tiles_per_perm = self.quant_config.perm_len // (
self.quant_config.tile_size**2)
self.quant_config.tile_size**2
)
if output_size_per_partition % num_tiles_per_perm != 0:
raise ValueError(
"Each permutation group must reside on the same gpu")
"Each permutation group must reside on the same gpu"
)

# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
qweight = PackedAphroditeParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
output_size_per_partition
* self.quant_config.tile_size
// self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader,
)

# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)
input_groups = (
1
if self.quant_config.group_size == -1
else input_size_per_partition // self.quant_config.group_size
)

scales = Parameter(
torch.empty(
weight_scale_args = {
"data": torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
"weight_loader": weight_loader,
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(
output_dim=1, **weight_scale_args
)
else:
scales = GroupQuantScaleParameter(
output_dim=1, input_dim=0, **weight_scale_args
)

# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
output_size_per_partition // self.quant_config.min_n_threads
) * self.quant_config.max_parallel
workspace = BaseAphroditeParameter(
data=torch.zeros(
max_workspace_size, device="cuda", dtype=torch.int
),
weight_loader=weight_loader,
)

layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B = Parameter(layer.B.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

def apply(
self,
@@ -242,10 +272,11 @@ def apply(
size_k = x_2d.shape[1]
size_n = scales.shape[1]

output_2d = ops.marlin_gemm(x_2d, qweight, scales, workspace, size_m,
size_n, size_k)
output_2d = ops.marlin_gemm(
x_2d, qweight, scales, workspace, size_m, size_n, size_k
)

output = output_2d.view(x.shape[:-1] + (output_2d.shape[1], ))
output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))

if bias is not None:
output.add_(bias) # In-place add
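
One non-obvious step in the new path is the added process_weights_after_loading hook: per the comment in the diff, the custom parameter subclasses are only needed while loading, and the tensors are re-registered as vanilla torch.nn.Parameter objects so the module plays nicely with torch.compile. A standalone sketch of that rewrap follows; the attribute names "B", "s", and "workspace" come from the diff, while the helper name is hypothetical.

import torch
from torch.nn import Parameter


def rewrap_as_plain_parameters(layer: torch.nn.Module) -> None:
    # Replace the AphroditeParameter subclasses registered during
    # create_weights with plain Parameters holding the same data.
    for name in ("B", "s", "workspace"):
        tensor = getattr(layer, name).data
        setattr(layer, name, Parameter(tensor, requires_grad=False))
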
