Skip to content

Commit

Permalink
add rocm support:
Browse files Browse the repository at this point in the history
- resolve nvbench problem

- add hip cuda defs and port test_norm

- add test_norm & bench_norm
  • Loading branch information
yiakwy-xpu-ml-framework-team committed Sep 9, 2024
1 parent f2ca781 commit 553037f
Show file tree
Hide file tree
Showing 18 changed files with 709 additions and 20 deletions.
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
[submodule "3rdparty/nvbench"]
path = 3rdparty/nvbench
url = https://github.com/NVIDIA/nvbench.git
[submodule "3rdparty/hipbench"]
path = 3rdparty/hipbench
url = https://github.com/ROCm/hipBench.git
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
Expand Down
171 changes: 159 additions & 12 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
@@ -1,10 +1,72 @@
# Minimum CMake version for this build. The project() call is issued further
# down, after the ROCm search paths have been configured, so no project() (and
# in particular no mandatory CUDA language) is declared at this point.
cmake_minimum_required(VERSION 3.26.4)

# Location of the ROCm SDK; override with -DROCM_HOME=<path>.
# Verified for ROCM >= 6.2, alias to $hip_LIB_INSTALL_DIR defined in
# ${ROCM_HOME}/lib/cmake/hip/hip-config-amd.cmake
set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCM SDK INSTALLATION HOME")
# Quote the path so an empty value or a path containing spaces is still
# evaluated as a single argument.
if(NOT IS_DIRECTORY "${ROCM_HOME}")
  message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory")
endif()

if(LINUX)
  # SDK root in the CMake config files: Linux systems default to
  # ENV{ROCM_PATH}; WIN32 systems default to ENV{HIP_PATH}. Export the chosen
  # location so those config files pick it up.
  set(ENV{ROCM_PATH} "${ROCM_HOME}")
endif()

# Resolve HIP_CMAKE_PATH. Precedence: a value already defined (e.g. passed with
# -DHIP_CMAKE_PATH=...), then the HIP_CMAKE_PATH environment variable, then
# defaults derived from ROCM_HOME.
if(NOT DEFINED HIP_CMAKE_PATH)
  if(NOT DEFINED ENV{HIP_CMAKE_PATH})
    # NOTE(yiakwy) : find_package(HIP) will first search for
    #   cmake/Modules/FindAMDDeviceLibs.cmake
    # , then
    #   /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake
    # this will add hip::host, hip::device dependencies to be linked by any hip
    # targets (ROCM >= 6.x).
    # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip has conflicts with
    # 3rdparty/mscclpp
    set(HIP_CMAKE_PATH
        "${ROCM_HOME}/lib/cmake/AMDDeviceLibs"
        "${ROCM_HOME}/lib/cmake/amd_comgr"
        "${ROCM_HOME}/lib/cmake/hsa-runtime64"
        CACHE PATH "Path to which HIP has been installed")
    # Falling back to the defaults is the expected path, so report it as a
    # status line rather than a warning.
    message(STATUS "HIP_CMAKE_PATH is not set; defaulting to ${HIP_CMAKE_PATH}")

    # Let find_package() locate the ROCm package config files; list(PREPEND)
    # avoids a trailing empty element when CMAKE_PREFIX_PATH was empty.
    # NOTE(review): this runs only while HIP_CMAKE_PATH is undefined, i.e. not
    # on re-configures once the cache entry exists - confirm that is intended.
    list(PREPEND CMAKE_PREFIX_PATH "${ROCM_HOME}")
  else()
    set(HIP_CMAKE_PATH "$ENV{HIP_CMAKE_PATH}" CACHE PATH "Path to which HIP has been installed")
  endif()
endif()

# Make the HIP CMake directories visible to include() / find_package(MODULE).
list(PREPEND CMAKE_MODULE_PATH ${HIP_CMAKE_PATH})

##### Flash infer project
project(flashinfer C CXX)

# Detect the ROCm/HIP toolchain. When found, enable the HIP language and define
# USE_ROCM for everything declared after this point.
# Debugging hint: set(CMAKE_FIND_DEBUG_MODE TRUE) before the call traces the search.
find_package(HIP QUIET)
if(HIP_FOUND)
  message(STATUS "Found HIP: " ${HIP_VERSION})

  # Keep an architecture list supplied by the user; otherwise ask rocminfo
  # (from the configured ROCM_HOME, not a hard-coded /opt/rocm) for the first
  # gfx target it reports.
  if(NOT CMAKE_HIP_ARCHITECTURES)
    execute_process(
      COMMAND bash -c "\"${ROCM_HOME}/bin/rocminfo\" | grep -o -m1 'gfx.*'"
      OUTPUT_VARIABLE FLASHINFER_DETECTED_HIP_ARCH
      OUTPUT_STRIP_TRAILING_WHITESPACE)
    if(FLASHINFER_DETECTED_HIP_ARCH)
      set(CMAKE_HIP_ARCHITECTURES "${FLASHINFER_DETECTED_HIP_ARCH}")
    else()
      # Do not leave an empty architecture list behind when the probe fails.
      message(WARNING "Could not detect a gfx architecture via ${ROCM_HOME}/bin/rocminfo; CMAKE_HIP_ARCHITECTURES is left unset.")
    endif()
  endif()

  enable_language(HIP)

  add_definitions(-DUSE_ROCM)
else()
  message(WARNING "Could not find HIP. Ensure that ROCM SDK is either installed in /opt/rocm or the variable HIP_CMAKE_PATH is set to point to the right location.")
endif()

# Probe for the CUDA toolkit. A missing toolkit only produces a warning here.
find_package(CUDA QUIET)
if(NOT CUDA_FOUND)
  message(WARNING "Could not find CUDA.")
else()
  message(STATUS "FOUND CUDA: " ${CUDA_TOOLKIT_ROOT_DIR})
endif()

# At least one GPU SDK is required. FATAL_ERROR (not FATAL, which message()
# does not recognise as a mode) stops configuration when neither was found.
if(NOT HIP_FOUND AND NOT CUDA_FOUND)
  message(FATAL_ERROR "ROCM/CUDA SDK must be supported")
endif()

# Project helper macros (defines flashinfer_option, used for the options below).
include(cmake/utils/Utils.cmake)

# Default language standard (17) for C++, CUDA and HIP sources.
set(CMAKE_CXX_STANDARD 17)
set(CMAKE_CUDA_STANDARD 17)
set(CMAKE_HIP_STANDARD 17)

if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake)
include(${CMAKE_BINARY_DIR}/config.cmake)
Expand Down Expand Up @@ -45,23 +107,41 @@ flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0
flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true")
flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2)

# GPU architecture selection.
# ROCM ARCH: when CMAKE_HIP_ARCHITECTURES is defined it is only reported.
if(DEFINED CMAKE_HIP_ARCHITECTURES)
  message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}")
else()
  # CUDA ARCH: an explicit FLASHINFER_CUDA_ARCHITECTURES overrides whatever
  # CMAKE_CUDA_ARCHITECTURES currently holds.
  if(DEFINED FLASHINFER_CUDA_ARCHITECTURES)
    message(STATUS "CMAKE_CUDA_ARCHITECTURES set to ${FLASHINFER_CUDA_ARCHITECTURES}.")
    set(CMAKE_CUDA_ARCHITECTURES ${FLASHINFER_CUDA_ARCHITECTURES})
  else()
    message(STATUS "CMAKE_CUDA_ARCHITECTURES is ${CMAKE_CUDA_ARCHITECTURES}")
  endif()
endif()

# Extra module search paths: the project's own find modules and the HIP
# modules shipped with the ROCm SDK.
list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules")
list(APPEND CMAKE_MODULE_PATH "${ROCM_HOME}/lib/cmake/hip")

# Benchmark / test third-party dependencies are only pulled in when at least
# one benchmark or test group is enabled.
if(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM)
  message(STATUS "NVBench and GoogleTest enabled")
  # Exactly one benchmarking framework is added: hipbench on ROCm, nvbench
  # otherwise.
  if(HIP_FOUND)
    add_subdirectory(3rdparty/hipbench)
  else()
    add_subdirectory(3rdparty/nvbench)
  endif()

  if(FLASHINFER_DISTRIBUTED)
    message(STATUS "compiling 3rdparty/mscclpp ...")
    add_subdirectory(3rdparty/mscclpp)
  else()
    # NOTE(review): googletest is added only for non-distributed builds;
    # presumably mscclpp provides it otherwise - confirm.
    add_subdirectory(3rdparty/googletest)
  endif()
endif()

# cmake/modules/FindThrust.cmake also searches the ROCm include directories
find_package(Thrust REQUIRED)

set(
Expand All @@ -77,6 +157,8 @@ endif(FLASHINFER_ENABLE_FP8)
# bf16 support is optional: define FLASHINFER_ENABLE_BF16 when it is requested,
# otherwise warn that tests depending on it are left out.
if(NOT FLASHINFER_ENABLE_BF16)
  message (WARNING "Since bf16 is not enabled, many tests will be disabled.")
else()
  message(STATUS "Compile bf16 kernels.")
  add_definitions(-DFLASHINFER_ENABLE_BF16)
endif()

# generate kernel inst
Expand Down Expand Up @@ -189,6 +271,9 @@ endforeach(head_dim)
# Static library built from the single- and batch-decode kernel source lists.
add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src})
target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
# NOTE(review): -Xcompiler= and --fatbin-options look like nvcc-specific flags;
# confirm they are accepted (or need guarding) when building with HIP.
target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all)
# When ROCm was detected, link this target as a HIP target.
if (HIP_FOUND)
set_target_properties(decode_kernels PROPERTIES LINKER_LANGUAGE HIP)
endif()

# single prefill kernel inst generation
foreach(head_dim IN LISTS HEAD_DIMS)
Expand Down Expand Up @@ -302,6 +387,9 @@ endforeach(head_dim)
# Static library built from the single, batch-paged and batch-ragged prefill
# kernel source lists.
add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src})
target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR})
# NOTE(review): -Xcompiler= and --fatbin-options look like nvcc-specific flags;
# confirm they are accepted (or need guarding) when building with HIP.
target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all)
# When ROCm was detected, link this target as a HIP target.
if (HIP_FOUND)
set_target_properties(prefill_kernels PROPERTIES LINKER_LANGUAGE HIP)
endif()

if (FLASHINFER_DECODE)
message(STATUS "Compile single decode kernel benchmarks.")
Expand All @@ -315,6 +403,7 @@ if (FLASHINFER_DECODE)

message(STATUS "Compile single decode kernel tests.")
file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_single_decode.cu)
message(STATUS "test source : ${TEST_DECODE_SRCS}")
add_executable(test_single_decode ${TEST_DECODE_SRCS})
target_include_directories(test_single_decode PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(test_single_decode PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
Expand All @@ -339,6 +428,13 @@ if (FLASHINFER_DECODE)
add_dependencies(test_batch_decode dispatch_inc)
target_link_libraries(test_batch_decode PRIVATE gtest gtest_main decode_kernels)
target_compile_options(test_batch_decode PRIVATE -Wno-switch-bool)

# On ROCm, link all four decode benchmark/test executables as HIP targets.
if(HIP_FOUND)
  set_target_properties(
    bench_single_decode
    test_single_decode
    bench_batch_decode
    test_batch_decode
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_DECODE)

if (FLASHINFER_PREFILL)
Expand Down Expand Up @@ -377,6 +473,13 @@ if (FLASHINFER_PREFILL)
add_dependencies(test_batch_prefill dispatch_inc)
target_link_libraries(test_batch_prefill PRIVATE gtest gtest_main prefill_kernels)
target_compile_options(test_batch_prefill PRIVATE -Wno-switch-bool)

# On ROCm, link all four prefill benchmark/test executables as HIP targets.
if(HIP_FOUND)
  set_target_properties(
    bench_single_prefill
    test_single_prefill
    bench_batch_prefill
    test_batch_prefill
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_PREFILL)

if (FLASHINFER_PAGE)
Expand All @@ -387,6 +490,10 @@ if (FLASHINFER_PAGE)
target_include_directories(test_page PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
target_link_libraries(test_page PRIVATE gtest gtest_main)
target_compile_options(test_page PRIVATE -Wno-switch-bool)

# On ROCm, link test_page as a HIP target.
if(HIP_FOUND)
  set_property(TARGET test_page PROPERTY LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_PAGE)

if (FLASHINFER_CASCADE)
Expand All @@ -407,6 +514,10 @@ if (FLASHINFER_CASCADE)
add_dependencies(test_cascade dispatch_inc)
target_link_libraries(test_cascade PRIVATE gtest gtest_main decode_kernels prefill_kernels)
target_compile_options(test_cascade PRIVATE -Wno-switch-bool)

# On ROCm, link test_cascade as a HIP target.
if(HIP_FOUND)
  set_property(TARGET test_cascade PROPERTY LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_CASCADE)

if (FLASHINFER_SAMPLING)
Expand All @@ -425,27 +536,52 @@ if (FLASHINFER_SAMPLING)
target_include_directories(test_sampling PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
target_link_libraries(test_sampling PRIVATE gtest gtest_main)
target_compile_options(test_sampling PRIVATE -Wno-switch-bool)

# On ROCm, link the sampling benchmark and test executables as HIP targets.
if(HIP_FOUND)
  set_target_properties(
    bench_sampling
    test_sampling
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_SAMPLING)

# Normalization kernels: benchmark (bench_norm) and unit test (test_norm).
# Gated on the FLASHINFER_NORM option, the same option that controls whether
# the benchmark/test third-party subdirectories are added.
if(FLASHINFER_NORM)
  message(STATUS "Compile normalization kernel benchmarks.")
  file(GLOB_RECURSE BENCH_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/bench_norm.cu)

  # Each executable is declared exactly once: via hip_add_executable on ROCm,
  # via add_executable otherwise. Only the matching benchmark framework's
  # include directory is added.
  if(HIP_FOUND)
    # NOTE(review): HIP_SOURCE_PROPERTY_FORMAT presumably marks the .cu file as
    # HIP source for hip_add_executable - confirm against the FindHIP module.
    set_source_files_properties(${BENCH_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
    hip_add_executable(bench_norm ${BENCH_NORM_SRCS})
    target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench)
  else()
    add_executable(bench_norm ${BENCH_NORM_SRCS})
    target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench)
  endif()

  target_include_directories(bench_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
  # NOTE(review): nvbench::main is linked in both configurations; confirm that
  # 3rdparty/hipbench provides a target with this name.
  target_link_libraries(bench_norm PRIVATE nvbench::main)
  target_compile_options(bench_norm PRIVATE -Wno-switch-bool)

  message(STATUS "Compile normalization kernel tests.")
  file(GLOB_RECURSE TEST_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/test_norm.cu)

  if(HIP_FOUND)
    set_source_files_properties(${TEST_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1)
    hip_add_executable(test_norm ${TEST_NORM_SRCS})
  else()
    add_executable(test_norm ${TEST_NORM_SRCS})
  endif()

  target_include_directories(test_norm PRIVATE ${FLASHINFER_INCLUDE_DIR})
  target_include_directories(test_norm PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
  target_link_libraries(test_norm PRIVATE gtest gtest_main)
  target_compile_options(test_norm PRIVATE -Wno-switch-bool)

  # On ROCm, link both executables as HIP targets.
  if(HIP_FOUND)
    set_target_properties(bench_norm test_norm PROPERTIES LINKER_LANGUAGE HIP)
  endif()
endif()

if(FLASHINFER_TVM_BINDING)
if (FLASHINFER_TVM_BINDING)
message(STATUS "Compile tvm binding.")
if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "")
set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR})
Expand Down Expand Up @@ -477,6 +613,10 @@ if(FLASHINFER_FASTDIV_TEST)
target_include_directories(test_fastdiv PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(test_fastdiv PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
target_link_libraries(test_fastdiv PRIVATE gtest gtest_main)

# On ROCm, link test_fastdiv as a HIP target.
if(HIP_FOUND)
  set_property(TARGET test_fastdiv PROPERTY LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_FASTDIV_TEST)

if(FLASHINFER_FASTDEQUANT_TEST)
Expand All @@ -486,9 +626,11 @@ if(FLASHINFER_FASTDEQUANT_TEST)
target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR})
target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR})
target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main)
endif(FLASHINFER_FASTDEQUANT_TEST)


# On ROCm, link test_fast_dequant as a HIP target.
if(HIP_FOUND)
  set_property(TARGET test_fast_dequant PROPERTY LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_FASTDEQUANT_TEST)

if (FLASHINFER_DISTRIBUTED)
find_package(MPI REQUIRED)
Expand All @@ -506,4 +648,9 @@ if (FLASHINFER_DISTRIBUTED)
target_include_directories(test_attn_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include)
target_link_libraries(test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp)
target_compile_definitions(test_attn_all_reduce PRIVATE -DENABLE_MPI)

# On ROCm, link both all-reduce test executables as HIP targets.
if(HIP_FOUND)
  set_target_properties(
    test_sum_all_reduce
    test_attn_all_reduce
    PROPERTIES LINKER_LANGUAGE HIP)
endif()
endif(FLASHINFER_DISTRIBUTED)
2 changes: 1 addition & 1 deletion cmake/config.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -40,4 +40,4 @@ set(FLASHINFER_GEN_MASK_MODES 0 1 2)
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES native)
set(FLASHINFER_CUDA_ARCHITECTURES native)
2 changes: 2 additions & 0 deletions cmake/modules/FindThrust.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,9 @@ find_path( THRUST_INCLUDE_DIR
/usr/include/cuda
/usr/local/include
/usr/local/cuda/include
/opt/rocm/include
${CUDA_INCLUDE_DIRS}
${HIP_INCLUDE_DIRS}
NAMES thrust/version.h
DOC "Thrust headers"
)
Expand Down
4 changes: 4 additions & 0 deletions cmake/utils/Utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -36,14 +36,18 @@ macro(flashinfer_option variable description value)
if("${__value}" MATCHES ";")
# list values directly pass through
__flashinfer_option(${variable} "${description}" "${__value}")
message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}")
elseif(DEFINED ${__value})
if(${__value})
__flashinfer_option(${variable} "${description}" ON)
message(STATUS "2 : creating ${variable} option, description : ${description}, value : ON")
else()
__flashinfer_option(${variable} "${description}" OFF)
message(STATUS "3 : creating ${variable} option, description : ${description}, value : OFF")
endif()
else()
__flashinfer_option(${variable} "${description}" "${__value}")
message(STATUS "4 : creating ${variable} option, description : ${description}, value : ${__value}")
endif()
else()
unset(${variable} CACHE)
Expand Down
Loading

0 comments on commit 553037f

Please sign in to comment.