quants: update qqq and gptq_marlin_24 to use AphroditeParameters (#…
AlpinDale authored Dec 18, 2024
1 parent 9c9b2dd commit 5d90219
Showing 3 changed files with 98 additions and 116 deletions.
3 changes: 2 additions & 1 deletion aphrodite/modeling/layers/linear.py
@@ -26,7 +26,8 @@
 WEIGHT_LOADER_V2_SUPPORTED = [
     "CompressedTensorsLinearMethod", "GPTQMarlinLinearMethod",
     "AWQMarlinLinearMethod", "AWQLinearMethod", "HQQMarlinMethod",
-    "Fp8LinearMethod", "MarlinLinearMethod"
+    "Fp8LinearMethod", "MarlinLinearMethod", "QQQLinearMethod",
+    "GPTQMarlin24LinearMethod",
 ]


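Background for the two entries added above: a linear method listed in WEIGHT_LOADER_V2_SUPPORTED has its parameters loaded through the v2 path, in which each parameter object carries its own sharding metadata and loader callback. A hedged sketch of that gating follows; only the list name and the two new entries come from this diff, while the helper and the stand-in class are invented for illustration.

# Hedged sketch of the dispatch implied by WEIGHT_LOADER_V2_SUPPORTED. Only
# the list name and the two new entries come from this diff; the helper and
# the stand-in class are invented for illustration.
WEIGHT_LOADER_V2_SUPPORTED = [
    "QQQLinearMethod",
    "GPTQMarlin24LinearMethod",
    # ... plus the other entries from the real list in linear.py ...
]


def uses_v2_loader(quant_method: object) -> bool:
    # v2-style parameters carry their own sharding metadata and loader
    # callback, so the framework can skip the legacy set_weight_attrs() path.
    return type(quant_method).__name__ in WEIGHT_LOADER_V2_SUPPORTED


class QQQLinearMethod:  # stand-in that only mimics the class name
    pass


print(uses_v2_loader(QQQLinearMethod()))  # -> True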
100 changes: 48 additions & 52 deletions aphrodite/quantization/gptq_marlin_24.py
@@ -6,7 +6,10 @@

 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (BaseAphroditeParameter,
+                                          ChannelQuantScaleParameter,
+                                          GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 from aphrodite.scalar_type import scalar_types

@@ -146,6 +149,7 @@ def create_weights(
         **extra_weight_attrs,
     ):
         del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
 
         if params_dtype != torch.float16:
             raise ValueError(
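The one added line above is the hinge of the refactor: create_weights now pulls the loader callback out of extra_weight_attrs and binds it to each parameter at construction time, rather than attaching it afterwards with set_weight_attrs. A runnable toy illustration of that handoff; every name in it is invented for the example.

import torch


def toy_weight_loader(param: torch.nn.Parameter,
                      loaded_weight: torch.Tensor) -> None:
    # Copy a checkpoint tensor into the pre-allocated parameter storage.
    param.data.copy_(loaded_weight)


def create_weights(layer: torch.nn.Module, **extra_weight_attrs) -> None:
    # Mirrors the diff: the callback travels inside extra_weight_attrs
    # and is bound to the parameter up front.
    weight_loader = extra_weight_attrs["weight_loader"]
    weight = torch.nn.Parameter(torch.empty(4, 4), requires_grad=False)
    weight.weight_loader = weight_loader  # bound at construction time
    layer.register_parameter("weight", weight)


layer = torch.nn.Module()
create_weights(layer, weight_loader=toy_weight_loader)
layer.weight.weight_loader(layer.weight, torch.ones(4, 4))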
@@ -184,87 +188,79 @@ def create_weights(
"Each permutation group must reside on the same gpu")

# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
qweight = PackedAphroditeParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size // 2,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)

# Meta
meta = Parameter(
torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
requires_grad=False,
)
set_weight_attrs(
meta,
{
"input_dim": 0,
"packed_dim": 1,
"pack_factor": 1,
"output_dim": 1,
"marlin_tile_size": 2,
},
)
meta = PackedAphroditeParameter(data=torch.empty(
input_size_per_partition // 8 // 2 // 2,
output_size_per_partition * 2,
device="cuda",
dtype=torch.int16,
),
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=1,
marlin_tile_size=2,
weight_loader=weight_loader)

# Determine if channelwise or not
input_groups = (1 if self.quant_config.group_size == -1 else
input_size_per_partition //
self.quant_config.group_size)

scales = Parameter(
weight_scale_args = {
"data":
torch.empty(
input_groups,
output_size_per_partition,
device="cuda",
dtype=params_dtype,
),
requires_grad=False,
)
set_weight_attrs(
scales,
{
"input_dim": None if input_groups == 1 else 0,
"output_dim": 1,
},
)
"weight_loader":
weight_loader
}
if input_groups == 1:
scales = ChannelQuantScaleParameter(output_dim=1,
**weight_scale_args)
else:
scales = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**weight_scale_args)

# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
workspace = BaseAphroditeParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
weight_loader=weight_loader)

layer.register_parameter("B_24", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("B_meta", meta)
set_weight_attrs(meta, extra_weight_attrs)
layer.register_parameter("s", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B_24 = Parameter(layer.B_24.data, requires_grad=False)
layer.s = Parameter(layer.s.data, requires_grad=False)
layer.B_meta = Parameter(layer.B_meta.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

def apply(
self,
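Taken together, this hunk swaps untyped torch.nn.Parameter objects plus set_weight_attrs() metadata dicts for typed parameter classes whose constructors take the same metadata as keyword arguments. Below is a minimal sketch of that pattern, assuming only what the diff shows; it is an illustration, not Aphrodite's actual PackedAphroditeParameter.

# Minimal sketch of the typed-parameter pattern this diff adopts: the
# sharding/packing metadata becomes constructor state instead of ad-hoc
# attributes attached via set_weight_attrs(). Illustration only.
import torch
from torch.nn import Parameter


class PackedParameterSketch(Parameter):

    def __new__(cls, data: torch.Tensor, **kwargs) -> "PackedParameterSketch":
        return super().__new__(cls, data, requires_grad=False)

    def __init__(self, data: torch.Tensor, input_dim: int, output_dim: int,
                 packed_dim: int, packed_factor: int, marlin_tile_size: int,
                 weight_loader) -> None:
        self.input_dim = input_dim            # dim sharded over input size
        self.output_dim = output_dim          # dim sharded over output size
        self.packed_dim = packed_dim          # dim holding the packed values
        self.packed_factor = packed_factor    # e.g. 8 int4 values per int32
        self.marlin_tile_size = marlin_tile_size
        self.weight_loader = weight_loader    # callback invoked at load time


qweight = PackedParameterSketch(
    data=torch.empty(16, 64, dtype=torch.int32),
    input_dim=0, output_dim=1, packed_dim=1,
    packed_factor=8, marlin_tile_size=16,
    weight_loader=lambda param, loaded: param.data.copy_(loaded))

The new process_weights_after_loading hook then strips these wrappers once loading is finished; see the note after the qqq.py diff.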
111 changes: 48 additions & 63 deletions aphrodite/quantization/qqq.py
@@ -5,7 +5,10 @@

 from aphrodite import _custom_ops as ops
 from aphrodite.modeling.layers.linear import LinearBase, LinearMethodBase
-from aphrodite.modeling.utils import set_weight_attrs
+from aphrodite.modeling.parameter import (BaseAphroditeParameter,
+                                          ChannelQuantScaleParameter,
+                                          GroupQuantScaleParameter,
+                                          PackedAphroditeParameter)
 from aphrodite.quantization.base_config import QuantizationConfig
 
 MARLIN_QQQ_TILE = 16
@@ -128,6 +131,7 @@ def create_weights(
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        weight_loader = extra_weight_attrs["weight_loader"]
         if params_dtype != torch.float16:
             raise ValueError(
                 f"The params dtype must be float16, but got {params_dtype}")
@@ -165,90 +169,71 @@ def create_weights(
"Each permutation group must reside on the same gpu")

# Quantized 4Bit weights packed into Int32.
qweight = Parameter(
torch.empty(
qweight = PackedAphroditeParameter(
data=torch.empty(
input_size_per_partition // self.quant_config.tile_size,
output_size_per_partition * self.quant_config.tile_size //
self.quant_config.pack_factor,
device="cuda",
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(
qweight,
{
"input_dim": 0,
"output_dim": 1,
"packed_dim": 1,
"pack_factor": self.quant_config.pack_factor,
"marlin_tile_size": self.quant_config.tile_size,
},
)

s_channel = Parameter(
torch.empty(
1,
output_size_per_partition,
device="cuda",
dtype=torch.float,
),
requires_grad=False,
)
set_weight_attrs(
s_channel,
{
"input_dim": None,
"output_dim": 1,
},
)
input_dim=0,
output_dim=1,
packed_dim=1,
packed_factor=self.quant_config.pack_factor,
marlin_tile_size=self.quant_config.tile_size,
weight_loader=weight_loader)
s_channel = ChannelQuantScaleParameter(data=torch.empty(
1,
output_size_per_partition,
device="cuda",
dtype=torch.float,
),
weight_loader=weight_loader,
output_dim=1)

if self.quant_config.group_size == -1:
s_group = Parameter(
torch.tensor(
[],
device="cuda",
dtype=torch.half,
),
requires_grad=False,
s_group_data = torch.tensor(
[],
device="cuda",
dtype=torch.half,
)
else:
s_group = Parameter(
torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=torch.half,
),
requires_grad=False,
s_group_data = torch.empty(
input_size_per_partition // self.quant_config.group_size,
output_size_per_partition,
device="cuda",
dtype=torch.half,
)

set_weight_attrs(
s_group,
{
"input_dim": None if self.quant_config.group_size == -1 else 0,
"output_dim":
None if self.quant_config.group_size == -1 else 1,
},
)
s_group_attr = {"data": s_group_data, "weight_loader": weight_loader}
if self.quant_config.group_size == -1:
s_group = BaseAphroditeParameter(**s_group_attr)
else:
s_group = GroupQuantScaleParameter(output_dim=1,
input_dim=0,
**s_group_attr)

# Allocate workspace (Used for internal locking mechanism)
max_workspace_size = (
output_size_per_partition //
self.quant_config.min_n_threads) * self.quant_config.max_parallel
workspace = Parameter(torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
requires_grad=False)
workspace = BaseAphroditeParameter(data=torch.zeros(max_workspace_size,
device="cuda",
dtype=torch.int),
weight_loader=weight_loader)

layer.register_parameter("B", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
layer.register_parameter("s_channel", s_channel)
set_weight_attrs(s_channel, extra_weight_attrs)
layer.register_parameter("s_group", s_group)
set_weight_attrs(s_group, extra_weight_attrs)
layer.register_parameter("workspace", workspace)
set_weight_attrs(workspace, extra_weight_attrs)

def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
# required by torch.compile
layer.B = Parameter(layer.B.data, requires_grad=False)
layer.s_channel = Parameter(layer.s_channel.data, requires_grad=False)
layer.s_group = Parameter(layer.s_group.data, requires_grad=False)
layer.workspace = Parameter(layer.workspace.data, requires_grad=False)

def apply(
self,
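qqq.py gets the same migration as gptq_marlin_24.py, and both files now add a process_weights_after_loading hook that re-wraps every parameter in a plain torch.nn.Parameter. Per the diff's own comment this is required by torch.compile: the custom subclasses are only needed while weights load, and re-wrapping around .data drops the subclass along with its loader metadata, presumably so compilation sees vanilla Parameters. A sketch under that reading, using the attribute names from the qqq diff; the bare Module here stands in for the real layer.

import torch
from torch.nn import Parameter


def process_weights_after_loading(layer: torch.nn.Module) -> None:
    for name in ("B", "s_channel", "s_group", "workspace"):
        loaded = getattr(layer, name)
        # .data strips the subclass (and its loader metadata), keeping storage
        setattr(layer, name, Parameter(loaded.data, requires_grad=False))


layer = torch.nn.Module()
for name in ("B", "s_channel", "s_group", "workspace"):
    layer.register_parameter(name,
                             Parameter(torch.zeros(4), requires_grad=False))
process_weights_after_loading(layer)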
