Skip to content

Commit

Permalink
review comments
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson committed Nov 6, 2024
1 parent 17bebb1 commit 630c540
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 28 deletions.
11 changes: 7 additions & 4 deletions csrc/cutlass_extensions/vllm_numeric_conversion.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,11 +31,14 @@ struct InterleavedNumericArrayConverter {
static result_type convert(source_type const& source) {
if (cute::elect_one_sync()) {
if constexpr (std::is_same_v<IlvBlkLayout, void>) {
printf(" %s <= %s (N = %d, IlvBlkLayout = void)\n", nameof_v<T>,
nameof_v<S>, N);
printf(
"Convert %s <= %s (N = %d, IlvBlkLayout = void), not implemented\n",
nameof_v<T>, nameof_v<S>, N);
} else {
printf(" %s <= %s (N = %d, size(IlvBlkLayout{}) = %d)\n", nameof_v<T>,
nameof_v<S>, N, size(IlvBlkLayout{}));
printf(
"Convert %s <= %s (N = %d, size(IlvBlkLayout{}) = %d), not "
"implemented\n",
nameof_v<T>, nameof_v<S>, N, size(IlvBlkLayout{}));
}
__brkpt();
}
Expand Down
50 changes: 26 additions & 24 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -525,40 +525,42 @@ def get_unique_schedules(heuristic: Dict[str, ScheduleConfig]):

impl_configs = []

GPTQ_kernel_types = list((TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16)))
GPTQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=DataType.void,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (VLLMDataType.u4b8, VLLMDataType.u8b128)
for a in (DataType.f16, DataType.bf16))

impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(GPTQ_kernel_types,
for x in zip(GPTQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
]

AWQ_kernel_types = list((TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=a,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16)))
AWQ_kernel_type_configs = list(
TypeConfig(
a=a,
b=b,
b_group_scale=a,
b_group_zeropoint=a,
b_channel_scale=DataType.void,
b_token_scale=DataType.void,
out=a,
accumulator=DataType.f32,
) for b in (DataType.u4, DataType.u8)
for a in (DataType.f16, DataType.bf16))

impl_configs += [
ImplConfig(x[0], x[1], x[2])
for x in zip(AWQ_kernel_types,
for x in zip(AWQ_kernel_type_configs,
itertools.repeat(get_unique_schedules(default_heuristic)),
itertools.repeat(default_heuristic))
]
Expand Down

0 comments on commit 630c540

Please sign in to comment.