Skip to content

Commit

Permalink
minor cleanup
Browse files Browse the repository at this point in the history
Signed-off-by: Lucas Wilkinson <[email protected]>
  • Loading branch information
LucasWilkinson committed Nov 6, 2024
1 parent a9783f4 commit 17bebb1
Show file tree
Hide file tree
Showing 5 changed files with 14 additions and 14 deletions.
4 changes: 2 additions & 2 deletions csrc/cutlass_extensions/epilogue/scaled_mm_epilogues_c2x.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,9 @@
/*
This file defines custom epilogues for fusing channel scales, token scales,
bias, and activation zero-points onto a GEMM operation using the
CUTLASS 3.x API, for pre sm90 (Hopper) NVIDIA GPUs.
CUTLASS 2.x API, for sm80 (Ampere) NVIDIA GPUs.
Epilogues must contain a public type named EVTCompute of type Sm90EVT,
Epilogues must contain a public type named EVTCompute of type Sm80EVT,
as well as a static prepare_args function that constructs an
EVTCompute::Arguments struct.
*/
Expand Down
6 changes: 3 additions & 3 deletions csrc/quantization/machete/generate.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@
{% for impl_config in impl_configs %}
{% set t = impl_config.types -%}
{% set type_sig = gen_type_sig(t) -%}
if (args.btype == {{VLLMScalarTypeTag[t.b]}}
if (args.b_type == {{VLLMScalarTypeTag[t.b]}}
&& a_type == {{TorchTypeTag[t.a]}}
&& out_type == {{TorchTypeTag[t.out]}}
&& {%if t.b_group_scale != void -%}
Expand All @@ -105,7 +105,7 @@
TORCH_CHECK_NOT_IMPLEMENTED(
false, "machete_mm(..) is not implemented for "
"a_type=", args.A.scalar_type(),
", b_type=", args.btype.str(),
", b_type=", args.b_type.str(),
", out_type=", out_type,
", with_group_scale_type=", maybe_g_scales_type
? toString(*maybe_g_scales_type) : "None",
Expand Down Expand Up @@ -231,7 +231,7 @@
TORCH_CHECK_NOT_IMPLEMENTED(false,
"prepack_B_dispatch(..) is not implemented for "
"atype = ", args.a_type,
", btype = ", args.b_type.str(),
", b_type = ", args.b_type.str(),
", with_group_scales_type= ", args.maybe_group_scales_type ?
toString(*args.maybe_group_scales_type) : "None");
}
Expand Down
2 changes: 1 addition & 1 deletion csrc/quantization/machete/machete_mm_launcher.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ namespace machete {
struct MMArgs {
torch::Tensor const& A;
torch::Tensor const& B;
vllm::ScalarType const& btype;
vllm::ScalarType const& b_type;
c10::optional<at::ScalarType> const& maybe_out_type;
c10::optional<torch::Tensor> const& maybe_group_scales;
c10::optional<torch::Tensor> const& maybe_group_zeros;
Expand Down
14 changes: 7 additions & 7 deletions csrc/quantization/machete/machete_pytorch.cu
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@ namespace machete {
using namespace vllm;

std::vector<std::string> supported_schedules(
at::ScalarType a_type, int64_t btype_id,
at::ScalarType a_type, int64_t b_type_id,
c10::optional<at::ScalarType> maybe_group_scales_type,
c10::optional<at::ScalarType> maybe_group_zeros_type,
c10::optional<at::ScalarType> maybe_channel_scales_type,
c10::optional<at::ScalarType> maybe_token_scales_type,
c10::optional<at::ScalarType> maybe_out_type) {
ScalarType const b_type = ScalarType::from_id(btype_id);
ScalarType const b_type = ScalarType::from_id(b_type_id);
return supported_schedules_dispatch({
.a_type = a_type,
.b_type = b_type,
Expand All @@ -28,18 +28,18 @@ std::vector<std::string> supported_schedules(
}

torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
int64_t btype_id,
int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_out_type,
c10::optional<torch::Tensor> const& maybe_group_scales,
c10::optional<torch::Tensor> const& maybe_group_zeros,
c10::optional<int64_t> maybe_group_size,
c10::optional<torch::Tensor> const& maybe_channel_scales,
c10::optional<torch::Tensor> const& maybe_token_scales,
c10::optional<std::string> maybe_schedule) {
ScalarType const b_type = ScalarType::from_id(btype_id);
ScalarType const b_type = ScalarType::from_id(b_type_id);
return mm_dispatch({.A = A,
.B = B,
.btype = b_type,
.b_type = b_type,
.maybe_out_type = maybe_out_type,
.maybe_group_scales = maybe_group_scales,
.maybe_group_zeros = maybe_group_zeros,
Expand All @@ -50,9 +50,9 @@ torch::Tensor mm(torch::Tensor const& A, torch::Tensor const& B,
}

torch::Tensor prepack_B(
torch::Tensor const& B, at::ScalarType const& a_type, int64_t btype_id,
torch::Tensor const& B, at::ScalarType const& a_type, int64_t b_type_id,
c10::optional<at::ScalarType> const& maybe_group_scales_type) {
ScalarType const b_type = ScalarType::from_id(btype_id);
ScalarType const b_type = ScalarType::from_id(b_type_id);
return prepack_B_dispatch(
{.B = B,
.a_type = a_type,
Expand Down
2 changes: 1 addition & 1 deletion csrc/torch_bindings.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
"machete_mm("
" Tensor A,"
" Tensor B,"
" int btype,"
" int b_type,"
" ScalarType? out_type,"
" Tensor? group_scales,"
" Tensor? group_zeros,"
Expand Down

0 comments on commit 17bebb1

Please sign in to comment.