diff --git a/benchmarks/kernels/benchmark_machete.py b/benchmarks/kernels/benchmark_machete.py index 21d376b12d876..a0342d08f1db8 100644 --- a/benchmarks/kernels/benchmark_machete.py +++ b/benchmarks/kernels/benchmark_machete.py @@ -549,10 +549,11 @@ def to_torch_dtype(dt): "int": torch.int, "float": torch.float, }[dt] - + class ToTorchDtype(argparse.Action): - def __call__(self, parser, namespace, values, option_string=None): - setattr(namespace, self.dest, to_torch_dtype(values)) + + def __call__(self, parser, namespace, values, option_string=None): + setattr(namespace, self.dest, to_torch_dtype(values)) parser = FlexibleArgumentParser( description=""" diff --git a/csrc/quantization/machete/generate.py b/csrc/quantization/machete/generate.py index 203a70d922ff8..ac63afe79a255 100644 --- a/csrc/quantization/machete/generate.py +++ b/csrc/quantization/machete/generate.py @@ -94,8 +94,8 @@ && {%if t.b_channel_scale != void -%} maybe_ch_scales_type == {{TorchTypeTag[t.b_channel_scale]}} {%- else %}!maybe_ch_scales_type{%endif%} - && {%if t.b_token_scale != void -%} - maybe_tok_scales_type == {{TorchTypeTag[t.b_token_scale]}} + && {%if t.a_token_scale != void -%} + maybe_tok_scales_type == {{TorchTypeTag[t.a_token_scale]}} {%- else %}!maybe_tok_scales_type{%endif%} ) { return mm_dispatch_{{type_sig}}(args); @@ -188,7 +188,7 @@ {{DataTypeTag[t.b_group_scale]}}, // GroupScaleT {{DataTypeTag[t.b_group_zeropoint]}}, // GroupZeroT {{DataTypeTag[t.b_channel_scale]}}, // ChannelScaleT - {{DataTypeTag[t.b_token_scale]}}, // TokenScaleT + {{DataTypeTag[t.a_token_scale]}}, // TokenScaleT cutlass::gemm::KernelTmaWarpSpecializedCooperativeMixedInput, Sch>; @@ -259,7 +259,7 @@ class TypeConfig: b_group_scale: DataType b_group_zeropoint: DataType b_channel_scale: DataType - b_token_scale: DataType + a_token_scale: DataType out: DataType accumulator: DataType @@ -532,7 +532,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): b_group_scale=a, b_group_zeropoint=DataType.void, b_channel_scale=DataType.void, - b_token_scale=DataType.void, + a_token_scale=DataType.void, out=a, accumulator=DataType.f32, ) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128) @@ -552,7 +552,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): b_group_scale=a, b_group_zeropoint=a, b_channel_scale=DataType.void, - b_token_scale=DataType.void, + a_token_scale=DataType.void, out=a, accumulator=DataType.f32, ) for b in (DataType.u4, DataType.u8) @@ -615,7 +615,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): b_group_scale=b_group_scale, b_group_zeropoint=DataType.void, b_channel_scale=DataType.f32, - b_token_scale=DataType.f32, + a_token_scale=DataType.f32, out=DataType.f16, accumulator=DataType.s32, ) for b_group_scale in (DataType.f16, DataType.void)), @@ -625,7 +625,7 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]): b_group_scale=b_group_scale, b_group_zeropoint=DataType.void, b_channel_scale=DataType.f32, - b_token_scale=DataType.f32, + a_token_scale=DataType.f32, out=DataType.f16, accumulator=DataType.f32, ) for b_group_scale in (DataType.f16, DataType.void)), diff --git a/tests/kernels/test_machete_gemm.py b/tests/kernels/test_machete_mm.py similarity index 100% rename from tests/kernels/test_machete_gemm.py rename to tests/kernels/test_machete_mm.py diff --git a/vllm/model_executor/layers/quantization/utils/quant_utils.py b/vllm/model_executor/layers/quantization/utils/quant_utils.py index 88f1759abb881..83055d6000d83 100644 --- a/vllm/model_executor/layers/quantization/utils/quant_utils.py +++ b/vllm/model_executor/layers/quantization/utils/quant_utils.py @@ -408,8 +408,6 @@ def unpack_cols( orig_device = packed_q_w.device - packed_q_w = packed_q_w.t() - packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32) q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)