kernel: do not compile machete for cuda 11 and below #901

Merged
1 commit merged on Dec 16, 2024
CMakeLists.txt (35 additions & 28 deletions)
@@ -10,6 +10,9 @@ message(STATUS "Target device: ${APHRODITE_TARGET_DEVICE}")

include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)

# Suppress potential warnings about unused manually-specified variables
set(ignoreMe "${APHRODITE_PYTHON_PATH}")

#
# Supported python versions. These versions will be searched in order, the
# first match will be selected. These should be kept in sync with setup.py.
@@ -250,43 +253,47 @@ if(APHRODITE_GPU_LANG STREQUAL "CUDA")
endif()

#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/kernels/cutlass_extensions/:${CUTLASS_DIR}/python/:${APHRODITE_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)

if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
message(STATUS "Machete generation completed successfully.")
endif()
# Machete kernels

# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "kernels/quantization/machete/generated/*.cu")
list(APPEND APHRODITE_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")

# See comment above for scaled_mm_c3x (same if condition)
# The machete kernels only work on hopper and require CUDA 12.0 or later.
if(${CMAKE_CUDA_COMPILER_VERSION} VERSION_GREATER 12.0)
#
# For the Machete kernels we automatically generate sources for various
# preselected input type pairs and schedules.
# Generate sources:
execute_process(
COMMAND ${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_CURRENT_SOURCE_DIR}/kernels/cutlass_extensions/:${CUTLASS_DIR}/python/:${APHRODITE_PYTHON_PATH}:$PYTHONPATH
${Python_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/kernels/quantization/machete/generate.py
RESULT_VARIABLE machete_generation_result
OUTPUT_VARIABLE machete_generation_output
OUTPUT_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
ERROR_FILE ${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log
)
if (NOT machete_generation_result EQUAL 0)
message(FATAL_ERROR "Machete generation failed."
" Result: \"${machete_generation_result}\""
"\nCheck the log for details: "
"${CMAKE_CURRENT_BINARY_DIR}/machete_generation.log")
else()
message(STATUS "Machete generation completed successfully.")
endif()

# Add machete generated sources
file(GLOB MACHETE_GEN_SOURCES "kernels/quantization/machete/generated/*.cu")
list(APPEND APHRODITE_EXT_SRC ${MACHETE_GEN_SOURCES})
message(STATUS "Machete generated sources: ${MACHETE_GEN_SOURCES}")

set_source_files_properties(
${MACHETE_GEN_SOURCES}
PROPERTIES
COMPILE_FLAGS
"-gencode arch=compute_90a,code=sm_90a")
endif()

# Add pytorch binding
# Add pytorch binding for machete (added even for CUDA < 12.0 so that we can
# raise an error to the user if this was built with an incompatible
# CUDA version)
list(APPEND APHRODITE_EXT_SRC
kernels/quantization/machete/machete_pytorch.cu)
endif()
kernels/quantization/machete/machete_pytorch.cu (12 additions & 0 deletions)
@@ -37,9 +37,13 @@ static auto scalar_type_dispatch(ScalarType const& type, Fn fn) {
//

std::vector<std::string> supported_schedules(ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return GemmDispatcher<half_t, decltype(BType)>::supported_schedules();
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
@@ -50,6 +54,7 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
c10::optional<torch::Tensor> const& C,
c10::optional<double> alpha, c10::optional<double> beta,
c10::optional<std::string> schedule) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
auto args = PyTorchArguments{.A = A,
.B = B,
.scales = scales,
@@ -67,13 +72,20 @@ torch::Tensor gemm(torch::Tensor const& A, torch::Tensor const& B,
return GemmDispatcher<ComputeType, decltype(BType)>::dispatch(args);
});
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

torch::Tensor prepack_B(torch::Tensor const& B,
ScalarTypeTorchPtr const& btype) {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
return scalar_type_dispatch(*btype, [&](auto BType) {
return PrepackBDispatcher<half_t, decltype(BType), half_t>::dispatch(B);
});
#else
TORCH_CHECK(false, "Machete requires CUDA 12.0 or later");
#endif
}

}; // namespace machete
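
The guards added above follow a common pattern: the binding translation unit is always compiled, but its bodies are stubbed out when the CUDA toolkit is too old, so a user on CUDA 11 gets a clear error at call time instead of a missing symbol. Below is a minimal, self-contained sketch of that pattern; the function name and the returned schedule string are illustrative placeholders rather than the real machete API, and a plain exception stands in for TORCH_CHECK.

// Sketch of the compile-time gating pattern used in machete_pytorch.cu.
// nvcc defines __CUDACC_VER_MAJOR__ with the CUDA toolkit's major version,
// so this file still compiles under CUDA 11, but the entry point refuses to run.
#include <stdexcept>
#include <string>
#include <vector>

std::vector<std::string> supported_schedules_sketch() {
#if defined(__CUDACC_VER_MAJOR__) && __CUDACC_VER_MAJOR__ >= 12
  // With CUDA 12+, the real code dispatches into the generated machete kernels.
  return {"example_schedule"};
#else
  // With CUDA < 12 (or a non-CUDA compiler), keep the symbol but fail loudly.
  throw std::runtime_error("Machete requires CUDA 12.0 or later");
#endif
}

The actual bindings use TORCH_CHECK(false, ...) instead of a plain exception so the failure is reported through PyTorch's normal error machinery.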