From 553037f4d3204b848d9622ca4351c3faa0c224ee Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Wed, 4 Sep 2024 08:10:20 +0000 Subject: [PATCH 01/15] add rocm support: - resovle nvbench problem - add hip cuda defs and port test_norm - add test_norm & bench_norm --- .gitmodules | 3 + CMakeLists.txt | 171 +++++++++++++++++-- cmake/config.cmake | 2 +- cmake/modules/FindThrust.cmake | 2 + cmake/utils/Utils.cmake | 4 + include/flashinfer/hip_cuda_type_utils.h | 72 ++++++++ include/flashinfer/hip_defs.h | 107 ++++++++++++ include/flashinfer/hip_warp_sync_functions.h | 72 ++++++++ include/flashinfer/math.cuh | 152 +++++++++++++++++ include/flashinfer/norm.cuh | 16 ++ include/flashinfer/page.cuh | 17 ++ include/flashinfer/pos_enc.cuh | 24 +++ include/flashinfer/utils.cuh | 26 +++ include/flashinfer/vec_dtypes.cuh | 35 +++- src/test_batch_decode.cu | 5 + src/test_norm.cu | 13 +- src/utils.h | 8 + src/vec_dtypes.h | 0 18 files changed, 709 insertions(+), 20 deletions(-) create mode 100644 include/flashinfer/hip_cuda_type_utils.h create mode 100644 include/flashinfer/hip_defs.h create mode 100644 include/flashinfer/hip_warp_sync_functions.h create mode 100644 src/vec_dtypes.h diff --git a/.gitmodules b/.gitmodules index 60cb77ed..b9576b70 100644 --- a/.gitmodules +++ b/.gitmodules @@ -1,6 +1,9 @@ [submodule "3rdparty/nvbench"] path = 3rdparty/nvbench url = https://github.com/NVIDIA/nvbench.git +[submodule "3rdparty/hipbench"] + path = 3rdparty/hipbench + url = https://github.com/ROCm/hipBench.git [submodule "3rdparty/googletest"] path = 3rdparty/googletest url = https://github.com/google/googletest.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 68c2b6cb..4cc71ad0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,10 +1,72 @@ -cmake_minimum_required(VERSION 3.23.1) -project(flashinfer CUDA CXX) +cmake_minimum_required(VERSION 3.26.4) + +# set compiler conditional +# Verified for ROCM >= 6.2, alias to $hip_LIB_INSTALL_DIR defined in ${ROCM_HOME}/lib/cmake/hip/hip-config-amd.cmake +set(ROCM_HOME "/opt/rocm" CACHE PATH "ROCM SDK INSTALLATION HOME") +if (NOT IS_DIRECTORY ${ROCM_HOME}) + message(WARNING "ROCM_HOME ${ROCM_HOME} is not a directory") +endif() + +if (LINUX) + # SDK Root in CMAKE config file; LINUX system defaults to ENV{ROCM_PATH}; WIN32 system defaults to ENV{HIP_PATH} + set(ENV{ROCM_PATH} ${ROCM_HOME}) +endif() + +if(NOT DEFINED HIP_CMAKE_PATH) + if(NOT DEFINED ENV{HIP_CMAKE_PATH}) + # NOTE(yiakwy) : find_package(HIP) will first search for + # cmake/Modules/FindAMDDeviceLibs.cmake + # , then + # /opt/rocm/lib/cmake/AMDDeviceLibs/AMDDeviceLibsConfig.cmake + # this will add hip::host, hip::device dependencies to be linked by any hip targets (ROCM >= 6.x). + # Add hip-config.cmake to CMake module search path. + # set(HIP_CMAKE_PATH "${ROCM_HOME}/share/rocm/cmake" "${ROCM_HOME}/share/rocmcmakebuildtools/cmake/" CACHE PATH "Path to which HIP has been installed") + # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip has conflicts with 3rdparty/mscclpp + set(HIP_CMAKE_PATH "${ROCM_HOME}/lib/cmake/AMDDeviceLibs" "${ROCM_HOME}/lib/cmake/amd_comgr" "${ROCM_HOME}/lib/cmake/hsa-runtime64" CACHE PATH "Path to which HIP has been installed") + message(WARNING "System variable HIP_CMAKE_PATH is nonexist, defaults to ${HIP_CMAKE_PATH}") + + set(CMAKE_PREFIX_PATH "${ROCM_HOME};${CMAKE_PREFIX_PATH}") + else() + set(HIP_CMAKE_PATH $ENV{HIP_CMAKE_PATH} CACHE PATH "Path to which HIP has been installed") + endif() +endif() + +set(CMAKE_MODULE_PATH "${HIP_CMAKE_PATH}" ${CMAKE_MODULE_PATH}) + +##### Flash infer project +project(flashinfer C CXX) + +# set CMAKE_CXX_COMPILER to hipcc +# set(CMAKE_FIND_DEBUG_MODE TRUE) +find_package(HIP QUIET) +if(HIP_FOUND) + message(STATUS "Found HIP: " ${HIP_VERSION}) + execute_process(COMMAND bash -c "/opt/rocm/bin/rocminfo | grep -o -m1 'gfx.*'" + OUTPUT_VARIABLE CMAKE_HIP_ARCHITECTURES OUTPUT_STRIP_TRAILING_WHITESPACE) + + enable_language(HIP) + + add_definitions(-DUSE_ROCM) +else() + message(WARNING "Could not find HIP. Ensure that ROCM SDK is either installed in /opt/rocm or the variable HIP_CMAKE_PATH is set to point to the right location.") +endif() + +find_package(CUDA QUIET) +if (CUDA_FOUND) + message(STATUS "FOUND CUDA: " ${CUDA_TOOLKIT_ROOT_DIR}) +else() + message(WARNING "Could not find CUDA.") +endif() + +if (NOT (HIP_FOUND) AND NOT (CUDA_FOUND)) + message(FATAL "ROCM/CUDA SDK must be supported") +endif() include(cmake/utils/Utils.cmake) set(CMAKE_CXX_STANDARD 17) set(CMAKE_CUDA_STANDARD 17) +set(CMAKE_HIP_STANDARD 17) if(EXISTS ${CMAKE_BINARY_DIR}/config.cmake) include(${CMAKE_BINARY_DIR}/config.cmake) @@ -45,23 +107,41 @@ flashinfer_option(FLASHINFER_GEN_POS_ENCODING_MODES "Pos encodings to enable" 0 flashinfer_option(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "QK reductions to enable" "false" "true") flashinfer_option(FLASHINFER_GEN_MASK_MODES "Mask modes to enable" 0 1 2) +# ROCM ARCH +if(DEFINED CMAKE_HIP_ARCHITECTURES) + message(STATUS "CMAKE_HIP_ARCHITECTURES : ${CMAKE_HIP_ARCHITECTURES}") + +else(CMAKE_HIP_ARCHITECTURES) + +# CUDA ARCH if(DEFINED FLASHINFER_CUDA_ARCHITECTURES) - message(STATUS "CMAKE_CUDA_ARCHITECTURES set to ${FLASHINFER_CUDA_ARCHITECTURES}.") + message(STATUS "CMAKE_CUDA_ARCHITECTURES set to +${FLASHINFER_CUDA_ARCHITECTURES}.") set(CMAKE_CUDA_ARCHITECTURES ${FLASHINFER_CUDA_ARCHITECTURES}) else(DEFINED FLASHINFER_CUDA_ARCHITECTURES) message(STATUS "CMAKE_CUDA_ARCHITECTURES is ${CMAKE_CUDA_ARCHITECTURES}") endif(DEFINED FLASHINFER_CUDA_ARCHITECTURES) +endif() + list(APPEND CMAKE_MODULE_PATH "${CMAKE_CURRENT_SOURCE_DIR}/cmake/modules") +list(APPEND CMAKE_MODULE_PATH "${ROCM_HOME}/lib/cmake/hip") if(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM) message(STATUS "NVBench and GoogleTest enabled") - add_subdirectory(3rdparty/nvbench) - if(FLASHINFER_DISTRIBUTED) + if (HIP_FOUND) + add_subdirectory(3rdparty/hipbench) + else() + add_subdirectory(3rdparty/nvbench) + endif() + if (FLASHINFER_DISTRIBUTED) + message(STATUS "compiling 3rdparty/mscclpp ...") add_subdirectory(3rdparty/mscclpp) else(FLASHINFER_DISTRIBUTED) add_subdirectory(3rdparty/googletest) endif(FLASHINFER_DISTRIBUTED) endif(FLASHINFER_PREFILL OR FLASHINFER_DECODE OR FLASHINFER_PAGE OR FLASHINFER_CASCADE OR FLASHINFER_SAMPLING OR FLASHINFER_NORM) + +# fixed with rocm path find_package(Thrust REQUIRED) set( @@ -77,6 +157,8 @@ endif(FLASHINFER_ENABLE_FP8) if(FLASHINFER_ENABLE_BF16) message(STATUS "Compile bf16 kernels.") add_definitions(-DFLASHINFER_ENABLE_BF16) +else() + message (WARNING "Since bf16 is not enabled, many tests will be disabled.") endif(FLASHINFER_ENABLE_BF16) # generate kernel inst @@ -189,6 +271,9 @@ endforeach(head_dim) add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) +if (HIP_FOUND) + set_target_properties(decode_kernels PROPERTIES LINKER_LANGUAGE HIP) +endif() # single prefill kernel inst generation foreach(head_dim IN LISTS HEAD_DIMS) @@ -302,6 +387,9 @@ endforeach(head_dim) add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) +if (HIP_FOUND) + set_target_properties(prefill_kernels PROPERTIES LINKER_LANGUAGE HIP) +endif() if (FLASHINFER_DECODE) message(STATUS "Compile single decode kernel benchmarks.") @@ -315,6 +403,7 @@ if (FLASHINFER_DECODE) message(STATUS "Compile single decode kernel tests.") file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_single_decode.cu) + message(STATUS "test source : ${TEST_DECODE_SRCS}") add_executable(test_single_decode ${TEST_DECODE_SRCS}) target_include_directories(test_single_decode PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_single_decode PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) @@ -339,6 +428,13 @@ if (FLASHINFER_DECODE) add_dependencies(test_batch_decode dispatch_inc) target_link_libraries(test_batch_decode PRIVATE gtest gtest_main decode_kernels) target_compile_options(test_batch_decode PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(bench_single_decode PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_single_decode PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(bench_batch_decode PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_batch_decode PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_DECODE) if (FLASHINFER_PREFILL) @@ -377,6 +473,13 @@ if (FLASHINFER_PREFILL) add_dependencies(test_batch_prefill dispatch_inc) target_link_libraries(test_batch_prefill PRIVATE gtest gtest_main prefill_kernels) target_compile_options(test_batch_prefill PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(bench_single_prefill PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_single_prefill PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(bench_batch_prefill PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_batch_prefill PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_PREFILL) if (FLASHINFER_PAGE) @@ -387,6 +490,10 @@ if (FLASHINFER_PAGE) target_include_directories(test_page PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_page PRIVATE gtest gtest_main) target_compile_options(test_page PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(test_page PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_PAGE) if (FLASHINFER_CASCADE) @@ -407,6 +514,10 @@ if (FLASHINFER_CASCADE) add_dependencies(test_cascade dispatch_inc) target_link_libraries(test_cascade PRIVATE gtest gtest_main decode_kernels prefill_kernels) target_compile_options(test_cascade PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(test_cascade PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_CASCADE) if (FLASHINFER_SAMPLING) @@ -425,27 +536,52 @@ if (FLASHINFER_SAMPLING) target_include_directories(test_sampling PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_sampling PRIVATE gtest gtest_main) target_compile_options(test_sampling PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(bench_sampling PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_sampling PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_SAMPLING) -if (FLASHINFER_NORM) +if (TRUE)#(FLASHINFER_NORM) TODO(yiakwy) : fix options message(STATUS "Compile normalization kernel benchmarks.") file(GLOB_RECURSE BENCH_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/bench_norm.cu) - add_executable(bench_norm ${BENCH_NORM_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${BENCH_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(bench_norm ${BENCH_NORM_SRCS}) + target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench) + else(HIP_FOUND) + add_executable(bench_norm ${BENCH_NORM_SRCS}) + target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) + endif() + target_include_directories(bench_norm PRIVATE ${FLASHINFER_INCLUDE_DIR}) - target_include_directories(bench_norm PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) target_link_libraries(bench_norm PRIVATE nvbench::main) target_compile_options(bench_norm PRIVATE -Wno-switch-bool) message(STATUS "Compile normalization kernel tests.") file(GLOB_RECURSE TEST_NORM_SRCS ${PROJECT_SOURCE_DIR}/src/test_norm.cu) - add_executable(test_norm ${TEST_NORM_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_NORM_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_norm ${TEST_NORM_SRCS}) + else(HIP_FOUND) + add_executable(test_norm ${TEST_NORM_SRCS}) + endif() + target_include_directories(test_norm PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_norm PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_norm PRIVATE gtest gtest_main) target_compile_options(test_norm PRIVATE -Wno-switch-bool) + + if (HIP_FOUND) + set_target_properties(bench_norm PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_norm PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_NORM) -if(FLASHINFER_TVM_BINDING) +if (FLASHINFER_TVM_BINDING) message(STATUS "Compile tvm binding.") if(NOT FLASHINFER_TVM_SOURCE_DIR STREQUAL "") set(TVM_SOURCE_DIR_SET ${FLASHINFER_TVM_SOURCE_DIR}) @@ -477,6 +613,10 @@ if(FLASHINFER_FASTDIV_TEST) target_include_directories(test_fastdiv PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_fastdiv PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_fastdiv PRIVATE gtest gtest_main) + + if (HIP_FOUND) + set_target_properties(test_fastdiv PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_FASTDIV_TEST) if(FLASHINFER_FASTDEQUANT_TEST) @@ -486,9 +626,11 @@ if(FLASHINFER_FASTDEQUANT_TEST) target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main) -endif(FLASHINFER_FASTDEQUANT_TEST) - + if (HIP_FOUND) + set_target_properties(test_fast_dequant PROPERTIES LINKER_LANGUAGE HIP) + endif() +endif(FLASHINFER_FASTDEQUANT_TEST) if (FLASHINFER_DISTRIBUTED) find_package(MPI REQUIRED) @@ -506,4 +648,9 @@ if (FLASHINFER_DISTRIBUTED) target_include_directories(test_attn_all_reduce PRIVATE ${FLASHINFER_INCLUDE_DIR} 3rdparty/mscclpp/include 3rdparty/spdlog/include) target_link_libraries(test_attn_all_reduce PRIVATE MPI::MPI_CXX mscclpp) target_compile_definitions(test_attn_all_reduce PRIVATE -DENABLE_MPI) + + if (HIP_FOUND) + set_target_properties(test_sum_all_reduce PROPERTIES LINKER_LANGUAGE HIP) + set_target_properties(test_attn_all_reduce PROPERTIES LINKER_LANGUAGE HIP) + endif() endif(FLASHINFER_DISTRIBUTED) diff --git a/cmake/config.cmake b/cmake/config.cmake index 0d51e491..6ec5e043 100644 --- a/cmake/config.cmake +++ b/cmake/config.cmake @@ -40,4 +40,4 @@ set(FLASHINFER_GEN_MASK_MODES 0 1 2) # So it's recommended to set it to a specific value if you know the architecture of the target GPU. # Example: # set(FLASHINFER_CUDA_ARCHITECTURES 80) -set(FLASHINFER_CUDA_ARCHITECTURES native) +set(FLASHINFER_CUDA_ARCHITECTURES native) \ No newline at end of file diff --git a/cmake/modules/FindThrust.cmake b/cmake/modules/FindThrust.cmake index a0f8008f..19eeeb8d 100644 --- a/cmake/modules/FindThrust.cmake +++ b/cmake/modules/FindThrust.cmake @@ -33,7 +33,9 @@ find_path( THRUST_INCLUDE_DIR /usr/include/cuda /usr/local/include /usr/local/cuda/include + /opt/rocm/include ${CUDA_INCLUDE_DIRS} + ${HIP_INCLUDE_DIRS} NAMES thrust/version.h DOC "Thrust headers" ) diff --git a/cmake/utils/Utils.cmake b/cmake/utils/Utils.cmake index 8d277bb4..17f5e185 100644 --- a/cmake/utils/Utils.cmake +++ b/cmake/utils/Utils.cmake @@ -36,14 +36,18 @@ macro(flashinfer_option variable description value) if("${__value}" MATCHES ";") # list values directly pass through __flashinfer_option(${variable} "${description}" "${__value}") + message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}") elseif(DEFINED ${__value}) if(${__value}) __flashinfer_option(${variable} "${description}" ON) + message(STATUS "2 : creating ${variable} option, description : ${description}, value : ON") else() __flashinfer_option(${variable} "${description}" OFF) + message(STATUS "3 : creating ${variable} option, description : ${description}, value : OFF") endif() else() __flashinfer_option(${variable} "${description}" "${__value}") + message(STATUS "4 : creating ${variable} option, description : ${description}, value : ${__value}") endif() else() unset(${variable} CACHE) diff --git a/include/flashinfer/hip_cuda_type_utils.h b/include/flashinfer/hip_cuda_type_utils.h new file mode 100644 index 00000000..081d00f6 --- /dev/null +++ b/include/flashinfer/hip_cuda_type_utils.h @@ -0,0 +1,72 @@ +/* +Copyright (c) 2024 by LEI WANG +Copyright (c) 2015 - 2023 Advanced Micro Devices, Inc. All rights reserved. + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +*/ +#ifndef FLASHINFER_HIP_CUDA_TYPE_UTILS_H_ +#define FLASHINFER_HIP_CUDA_TYPE_UTILS_H_ + +// namespace flashinfer { + +#if defined(__HIP_PLATFORM_AMD__) + +#include +#include +#include +#include + +// CUDA DEVICE API Supported : https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html + +/*! \brief Struct to packet two 16 bit brain floating point numbers. */ +using nv_bfloat162 = __hip_bfloat162; +using __nv_bfloat162 = __hip_bfloat162; + +/*! \brief Struct to represent a 16 bit brain floating point number. */ +using nv_bfloat16 = __hip_bfloat16; +using __nv_bfloat16 = __hip_bfloat16; + +// ROCM FP8 is different from nv FP8 : https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 + +// TODO (yiakwy) : FP8 datatype support + + +// TODO (yiakwy) : FP8 cast, generic cast, vector cast support + + +// bf16 utils +__BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 x, const __hip_bfloat16 y) +{ + __hip_bfloat162 t; t.x = x; t.y = y; return t; +} + +// Following math functions included in ROCM6.2 SDK : +// __hmul: bfloat16 -> bfloat16, +// __hmul2: bfloat16 -> bfloat16, +// __floats2bfloat162_rn: (float,float) -> __hip_bfloat162, +// __float22bfloat162_rn: float2 -> __hip_bfloat162, +// __float2bfloat162_rn: float -> __hip_bfloat162, +// __bfloat1622float2: __hip_bfloat162 -> float2 + +#endif + +// } // flashinfer + +#endif // FLASHINFER_HIP_CUDA_TYPE_UTILS_H_ + diff --git a/include/flashinfer/hip_defs.h b/include/flashinfer/hip_defs.h new file mode 100644 index 00000000..9b090ab7 --- /dev/null +++ b/include/flashinfer/hip_defs.h @@ -0,0 +1,107 @@ +// adpated from MSC mscclpp project, also see examples from cholla (https://github.com/cholla-hydro/cholla/blob/main/src/utils/gpu.hpp) + +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +#ifndef FLASHINFER_HIP_DEFS_H_ +#define FLASHINFER_HIP_DEFS_H_ + +#if defined(__HIP_PLATFORM_AMD__) + +#include + +using cudaError_t = hipError_t; +using cudaGraph_t = hipGraph_t; +using cudaGraphExec_t = hipGraphExec_t; +using cudaDeviceProp = hipDeviceProp_t; +using cudaStream_t = hipStream_t; +using cudaStreamCaptureMode = hipStreamCaptureMode; +using cudaMemcpyKind = hipMemcpyKind; +using cudaIpcMemHandle_t = hipIpcMemHandle_t; + +using CUresult = hipError_t; +using CUdeviceptr = hipDeviceptr_t; +using CUmemGenericAllocationHandle = hipMemGenericAllocationHandle_t; +using CUmemAllocationProp = hipMemAllocationProp; +using CUmemAccessDesc = hipMemAccessDesc; + +constexpr auto cudaSuccess = hipSuccess; +constexpr auto cudaStreamNonBlocking = hipStreamNonBlocking; +constexpr auto cudaStreamCaptureModeGlobal = hipStreamCaptureModeGlobal; +constexpr auto cudaStreamCaptureModeRelaxed = hipStreamCaptureModeRelaxed; +constexpr auto cudaHostAllocMapped = hipHostMallocMapped; +constexpr auto cudaHostAllocWriteCombined = hipHostMallocWriteCombined; +constexpr auto cudaMemcpyDefault = hipMemcpyDefault; +constexpr auto cudaMemcpyDeviceToDevice = hipMemcpyDeviceToDevice; +constexpr auto cudaMemcpyHostToDevice = hipMemcpyHostToDevice; +constexpr auto cudaMemcpyDeviceToHost = hipMemcpyDeviceToHost; +constexpr auto cudaIpcMemLazyEnablePeerAccess = hipIpcMemLazyEnablePeerAccess; + +constexpr auto CU_MEM_ALLOCATION_TYPE_PINNED = hipMemAllocationTypePinned; +constexpr auto CU_MEM_LOCATION_TYPE_DEVICE = hipMemLocationTypeDevice; +constexpr auto CU_MEM_HANDLE_TYPE_POSIX_FILE_DESCRIPTOR = hipMemHandleTypePosixFileDescriptor; +constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWrite; + +#ifndef CUDA_SUCCESS +#define CUDA_SUCCESS hipSuccess +#endif // CUDA_SUCCESS + +#define cudaGetErrorString(...) hipGetErrorString(__VA_ARGS__) +#define cudaGetDevice(...) hipGetDevice(__VA_ARGS__) +#define cudaGetDeviceCount(...) hipGetDeviceCount(__VA_ARGS__) +#define cudaGetDeviceProperties(...) hipGetDeviceProperties(__VA_ARGS__) +#define cudaGetLastError(...) hipGetLastError(__VA_ARGS__) +#define cudaSetDevice(...) hipSetDevice(__VA_ARGS__) +#define cudaDeviceSynchronize(...) hipDeviceSynchronize(__VA_ARGS__) +#define cudaDeviceGetPCIBusId(...) hipDeviceGetPCIBusId(__VA_ARGS__) +#define cudaHostAlloc(...) hipHostMalloc(__VA_ARGS__) +#define cudaMalloc(...) hipMalloc(__VA_ARGS__) +#define cudaFree(...) hipFree(__VA_ARGS__) +#define cudaFreeHost(...) hipHostFree(__VA_ARGS__) +#define cudaMemset(...) hipMemset(__VA_ARGS__) +#define cudaMemsetAsync(...) hipMemsetAsync(__VA_ARGS__) +#define cudaMemcpy(...) hipMemcpy(__VA_ARGS__) +#define cudaMemcpyAsync(...) hipMemcpyAsync(__VA_ARGS__) +#define cudaMemcpyToSymbol(...) hipMemcpyToSymbol(__VA_ARGS__) +#define cudaMemcpyToSymbolAsync(...) hipMemcpyToSymbolAsync(__VA_ARGS__) +#define cudaStreamCreate(...) hipStreamCreate(__VA_ARGS__) +#define cudaStreamCreateWithFlags(...) hipStreamCreateWithFlags(__VA_ARGS__) +#define cudaStreamSynchronize(...) hipStreamSynchronize(__VA_ARGS__) +#define cudaStreamBeginCapture(...) hipStreamBeginCapture(__VA_ARGS__) +#define cudaStreamEndCapture(...) hipStreamEndCapture(__VA_ARGS__) +#define cudaStreamDestroy(...) hipStreamDestroy(__VA_ARGS__) +#define cudaGraphInstantiate(...) hipGraphInstantiate(__VA_ARGS__) +#define cudaGraphLaunch(...) hipGraphLaunch(__VA_ARGS__) +#define cudaGraphDestroy(...) hipGraphDestroy(__VA_ARGS__) +#define cudaGraphExecDestroy(...) hipGraphExecDestroy(__VA_ARGS__) +#define cudaThreadExchangeStreamCaptureMode(...) hipThreadExchangeStreamCaptureMode(__VA_ARGS__) +#define cudaIpcGetMemHandle(...) hipIpcGetMemHandle(__VA_ARGS__) +#define cudaIpcOpenMemHandle(...) hipIpcOpenMemHandle(__VA_ARGS__) +#define cudaIpcCloseMemHandle(...) hipIpcCloseMemHandle(__VA_ARGS__) + +#define cuGetErrorString(...) hipDrvGetErrorString(__VA_ARGS__) +#define cuMemAddressReserve(...) hipMemAddressReserve(__VA_ARGS__) +#define cuMemAddressFree(...) hipMemAddressFree(__VA_ARGS__) +#define cuMemGetAddressRange(...) hipMemGetAddressRange(__VA_ARGS__) +#define cuMemCreate(...) hipMemCreate(__VA_ARGS__) +#define cuMemRelease(...) hipMemRelease(__VA_ARGS__) +#define cuMemSetAccess(...) hipMemSetAccess(__VA_ARGS__) +#define cuMemMap(...) hipMemMap(__VA_ARGS__) +#define cuMemUnmap(...) hipMemUnmap(__VA_ARGS__) + +#else + +#include +#include + +#endif + +// NVLS +#if !defined(__HIP_PLATFORM_AMD__) +#include +#define USE_NVLS ((CUDART_VERSION >= 12010) && (LINUX_VERSION_CODE >= KERNEL_VERSION(5, 6, 0))) +#else // !defined(__HIP_PLATFORM_AMD__) +#define USE_NVLS 0 +#endif // !defined(__HIP_PLATFORM_AMD__) + +#endif // FLASHINFER_HIP_DEFS_H_ \ No newline at end of file diff --git a/include/flashinfer/hip_warp_sync_functions.h b/include/flashinfer/hip_warp_sync_functions.h new file mode 100644 index 00000000..df699e5b --- /dev/null +++ b/include/flashinfer/hip_warp_sync_functions.h @@ -0,0 +1,72 @@ +// ported from #include +#ifndef FLASHINFER_HIP_WARP_SYNC_FUNCTIONS_PORTED_H_ +#define FLASHINFER_HIP_WARP_SYNC_FUNCTIONS_PORTED_H_ + +#include + +#define __hip_adjust_mask_for_wave32(MASK) \ + do { \ + if (warpSize == 32) MASK &= 0xFFFFFFFF; \ + } while (0) + +#if defined(NDEBUG) +#define __hip_assert(COND) +#else +#define __hip_assert(COND) \ + do { \ + if (!(COND)) \ + __builtin_trap(); \ + } while (0) +#endif + +template +__device__ inline +T __hip_readfirstlane(T val) { + // In theory, behaviour is undefined when reading from a union member other + // than the member that was last assigned to, but it works in practice because + // we rely on the compiler to do the reasonable thing. + union { + unsigned long long l; + T d; + } u; + u.d = val; + // NOTE: The builtin returns int, so we first cast it to unsigned int and only + // then extend it to 64 bits. + unsigned long long lower = (unsigned)__builtin_amdgcn_readfirstlane(u.l); + unsigned long long upper = + (unsigned)__builtin_amdgcn_readfirstlane(u.l >> 32); + u.l = (upper << 32) | lower; + return u.d; +} + +#define __hip_check_mask(MASK) \ + do { \ + __hip_assert(MASK && "mask must be non-zero"); \ + bool done = false; \ + while (__any(!done)) { \ + if (!done) { \ + auto chosen_mask = __hip_readfirstlane(MASK); \ + if (MASK == chosen_mask) { \ + __hip_assert(MASK == __ballot(true) && \ + "all threads specified in the mask" \ + " must execute the same operation with the same mask"); \ + done = true; \ + } \ + } \ + } \ + } while(0) + +template +__device__ inline +T __shfl_xor_sync(MaskT mask, T var, int laneMask, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_xor(var, laneMask, width); +} + +#endif \ No newline at end of file diff --git a/include/flashinfer/math.cuh b/include/flashinfer/math.cuh index c2401c7e..55267848 100644 --- a/include/flashinfer/math.cuh +++ b/include/flashinfer/math.cuh @@ -16,9 +16,24 @@ #ifndef FLASHINFER_MATH_CUH_ #define FLASHINFER_MATH_CUH_ +#ifdef USE_ROCM + +#include +// TODO (yiakwy) : functions not included +#include +#include "flashinfer/hip_warp_sync_functions.h" +#include "flashinfer/hip_cuda_type_utils.h" + +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include #include +#endif // USE_ROCM-1 + namespace flashinfer { namespace math { @@ -29,6 +44,141 @@ __forceinline__ __device__ half2 uint32_as_half2(uint32_t x) { return *(half2*)& __forceinline__ __device__ uint32_t half2_as_uint32(half2 x) { return *(uint32_t*)&x; } + +#ifdef USE_ROCM + +#include + +namespace amdgpu { + +// ROCM exp c primitive, which computes 2^x in fp8/fp16/bf16/fp32 +template +__forceinline__ __device__ T exp2(T); + +template +__forceinline__ __device__ T log2(T); + +template +__forceinline__ __device__ T rcp(T); + +template +__forceinline__ __device__ T shfl_xor_sync(T, int); + +template +__forceinline__ __device__ T rsqrt(T); + +// sepicalization + +// TODO (yiakwy) : add equivalent asm version for fast exp computation (polynomial approx) +template<> +inline __device__ float exp2(float x) { + return exp2f(x); +} + +template<> +inline __device__ half exp2(half x) { + return hexp2(x); +} + +template<> +__forceinline__ __device__ float log2(float x) { + return log2f(x); +} + +template<> +inline __device__ half log2(half x) { + return hlog2(x); +} + +template<> +__forceinline__ __device__ float rcp(float x) { + // TODO (yiakwy) : __frcp_rn is not supported in ROCM 6.2 + return 1.f / x; +} + +// TODO (yiakwy) : verify; see details from here : https://rocm.docs.amd.com/projects/HIP/en/develop/reference/kernel_language.html +template<> +__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { + // note AMD uses 8 byte mask (i.e. long datatype) to allow all 64 threads participate in + // TODO (yiakwy) : this does not work + // return __shfl_xor_sync(0xffffffffffffffff, x, lane_mask); + // TODO (yiakwy) : workaround + return __shfl_xor(x, lane_mask); +} + +template<> +__forceinline__ __device__ half shfl_xor_sync(half x, int lane_mask) { + // note AMD uses 8 byte mask (i.e. long datatype) + // TODO (yiakwy) : this does not work + // return __shfl_xor_sync(0xffffffffffffffff, x, lane_mask); + // TODO (yiakwy) : workaround + return __shfl_xor(x, lane_mask); +} + +template<> +__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) { + // note AMD uses 8 byte mask (i.e. long datatype) + // TODO (yiakwy) : this does not work + // return __shfl_xor_sync(0xffffffffffffffff, x, lane_mask); + // TODO (yiakwy) : workaround + return __shfl_xor(x, lane_mask); +} + +template<> +__forceinline__ __device__ float rsqrt(float x) { + return rsqrtf(x); +} + +} // amdgpu + +/*! + * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x + * \param x input + */ +__forceinline__ __device__ float ptx_exp2(float x) { + return amdgpu::exp2(x); +} + + +/*! + * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x) + * \param x input + */ +__forceinline__ __device__ float ptx_log2(float x) { + return amdgpu::log2(x); +} + + +/*! + * \brief Wrapper of PTX rcp.approx instruction, which computes 1/x + * \param x input + */ +__forceinline__ __device__ float ptx_rcp(float x) { + return amdgpu::rcp(x); +} + +/*! + * \brief Wrapper of PTX shfl.sync.bfly instruction, which performs a butterfly shuffle + * between threads in a warp. + * \param x The value in the source lane + * \param lane_mask The mask to perform thread index xor with: y[i] <- x[i ^ delta] + */ +__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { + return amdgpu::shfl_xor_sync(x, lane_mask); +} + +/*! + * \brief Wrapper of PTX rsqrt approximation instruction, which computes 1/sqrt(x) + * \param x input + */ +__forceinline__ __device__ float rsqrt(float x) { + return amdgpu::rsqrt(x); +} + +#else + +// NVIDIA PTX exlusive codes + /*! * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x * \param x input @@ -145,6 +295,8 @@ __forceinline__ __device__ half tanh(half x) { return __ushort_as_half(y_u16); } +#endif // USE_ROCM-2 + } // namespace math } // namespace flashinfer #endif // FLASHINFER_MATH_CUH_ diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 82d2513d..26682101 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -109,7 +109,11 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_ DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = RMSNormKernel; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #endif }); return cudaSuccess; } @@ -206,7 +210,11 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = FusedAddRMSNormKernel; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #endif }); return cudaSuccess; @@ -293,7 +301,11 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = GemmaRMSNormKernel; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #endif }); return cudaSuccess; } @@ -390,7 +402,11 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = GemmaFusedAddRMSNormKernel; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); + #endif }); return cudaSuccess; diff --git a/include/flashinfer/page.cuh b/include/flashinfer/page.cuh index d79a5ff0..9adbdea8 100644 --- a/include/flashinfer/page.cuh +++ b/include/flashinfer/page.cuh @@ -16,6 +16,15 @@ #ifndef FLASHINFER_PAGE_CUH_ #define FLASHINFER_PAGE_CUH_ +#ifdef USE_ROCM + +#include + +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#endif // USE_ROCM + #include #include "fastdiv.cuh" @@ -451,7 +460,11 @@ cudaError_t AppendPagedKVCacheDecode(paged_kv_t pag dim3 nthrs(bdx, bdy); auto kernel = AppendPagedKVCacheDecodeKernel; void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value}; + #ifdef USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); return cudaSuccess; } @@ -484,7 +497,11 @@ cudaError_t AppendPagedKVCache(paged_kv_t paged_kv, dim3 nthrs(bdx, bdy); auto kernel = AppendPagedKVCachePrefillKernel; void* args[] = {(void*)&paged_kv, (void*)&key, (void*)&value, (void*)&append_indptr}; + #ifdef USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); return cudaSuccess; } diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 15b4a8d9..3cfc7732 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -19,6 +19,14 @@ #include #include +#ifdef USE_ROCM + +#include + +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" +#endif // USE_ROCM + #include "layout.cuh" #include "math.cuh" #include "utils.cuh" @@ -318,7 +326,11 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); }); @@ -362,7 +374,11 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); }); @@ -408,7 +424,11 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); }); @@ -456,7 +476,11 @@ cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; + #if USE_ROCM + FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); + #endif }); }); diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 856b5325..f01f90a8 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -15,12 +15,25 @@ */ #ifndef FLASHINFER_UTILS_CUH_ #define FLASHINFER_UTILS_CUH_ + +#ifdef USE_ROCM + +#include + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include #include #include #include #include +#endif + #include #include #include @@ -249,6 +262,7 @@ __forceinline__ __device__ __host__ T1 ceil_div(const T1 x, const T2 y) { return (x + y - 1) / y; } +#ifdef ROCM inline std::pair GetCudaComputeCapability() { int device_id = 0; cudaGetDevice(&device_id); @@ -257,6 +271,18 @@ inline std::pair GetCudaComputeCapability() { cudaDeviceGetAttribute(&minor, cudaDevAttrComputeCapabilityMinor, device_id); return std::make_pair(major, minor); } +#else + +// see hip device initialization and version +inline std::pair GetCudaComputeCapability() { + int device_id = 0; + hipGetDevice(&device_id); + int major = 0, minor = 0; + hipDeviceComputeCapability(&major, &minor, device_id); + return std::make_pair(major, minor); +} + +#endif template inline void DebugPrintCUDAArray(T* device_ptr, size_t size, std::string prefix = "") { diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 3932c0d3..69b19ec7 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -16,21 +16,44 @@ #ifndef VEC_DTYPES_CUH_ #define VEC_DTYPES_CUH_ +#ifdef USE_ROCM + +#include + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include #include #include #include +#endif // USE_ROCM + #include namespace flashinfer { -#if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) +// TODO (yiakwy) : remove +// #if (!defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 900)) +#if __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH__) #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED #endif +#if USE_ROCM +// TODO(yiakwy) : since roc fp8 is different from NV fp8, more efforts need to port functionalities +#ifdef FLASHINFER_FP8_ENABLED +#undef FLASHINFER_FP8_ENABLED +#endif + +#endif + #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ +// TODO (yiakwy) : add support in HIP #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) // CUDA version < 12.4 and GPU architecture < 80 @@ -119,6 +142,8 @@ struct vec_cast { } }; +#ifdef FLASHINFER_FP8_ENABLED + template constexpr FLASHINFER_INLINE int get_exponent_bits() { if constexpr (std::is_same::value) { @@ -187,7 +212,7 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); } else { constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); + const nv_bfloat162 bias_reg = __float22bfloat162_rn(*reinterpret_cast(&BIAS)); // Convert to bfloat162 and apply bias *(nv_bfloat162*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); @@ -353,6 +378,8 @@ struct vec_cast { } }; +#endif // FLASHINFER_FP8_ENABLED + template <> struct vec_cast { template @@ -433,6 +460,8 @@ FLASHINFER_INLINE void cast_store_impl(tgt_float_t* dst_ptr, /******************* vec_t<__nv_fp8_e4m3> *******************/ +#ifdef FLASHINFER_FP8_ENABLED + // __nv_fp8_e4m3 x 1 template <> struct vec_t<__nv_fp8_e4m3, 1> { @@ -925,6 +954,8 @@ struct vec_t<__nv_fp8_e5m2, vec_size> { } }; +#endif // FLASHINFER_FP8_ENABLED + /******************* vec_t *******************/ // half x 1 diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index 8098661f..d04b972a 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -209,3 +209,8 @@ TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessE5M2) { TEST(FlashInferCorrectnessTest, TestCooperativeBatchDecodeKernelCorrectnessTestFP16) { TestCooperativeBatchDecodeKernelCorrectness(); } + +int main(int argc, char **argv) { + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} \ No newline at end of file diff --git a/src/test_norm.cu b/src/test_norm.cu index 082c8827..be8b2f91 100644 --- a/src/test_norm.cu +++ b/src/test_norm.cu @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include // isnan used + #include #include @@ -24,6 +26,7 @@ using namespace flashinfer; template void _TestRMSNormCorrectness(uint32_t batch_size, uint32_t d) { + std::vector x_host(batch_size * d); std::vector w_host(d); @@ -36,7 +39,7 @@ void _TestRMSNormCorrectness(uint32_t batch_size, uint32_t d) { thrust::device_vector x_device(x_host); thrust::device_vector w_device(w_host); thrust::device_vector y_device(batch_size * d); - + cudaError_t status = norm::RMSNorm( thrust::raw_pointer_cast(x_device.data()), thrust::raw_pointer_cast(w_device.data()), thrust::raw_pointer_cast(y_device.data()), batch_size, d, 1e-6); @@ -47,7 +50,7 @@ void _TestRMSNormCorrectness(uint32_t batch_size, uint32_t d) { bool nan_detected = false; size_t num_result_errors_atol_1e_3_rtol_1e_3 = 0; for (uint i = 0; i < batch_size * d; i++) { - if (isnan(float(y_host[i]))) { + if (std::isnan(float(y_host[i]))) { nan_detected = true; } num_result_errors_atol_1e_3_rtol_1e_3 += @@ -66,11 +69,11 @@ void _TestRMSNormCorrectness(uint32_t batch_size, uint32_t d) { template void TestRMSNormCorrectness() { - for (size_t batch_size : {1, 3, 7, 19, 733}) { - for (size_t d : {37, 128, 512, 1002, 3072, 4096, 8192, 16384}) { + for (size_t batch_size : {1}) { // {1, 3, 7, 19, 733} + for (size_t d : {3}) { // {37, 128, 512, 1002, 3072, 4096, 8192, 16384} _TestRMSNormCorrectness(batch_size, d); } } } -TEST(FlashInferCorrectnessTests, TestRMSNormFP16) { TestRMSNormCorrectness(); } +TEST(FlashInferCorrectnessTests, TestRMSNormFP16) { TestRMSNormCorrectness(); } \ No newline at end of file diff --git a/src/utils.h b/src/utils.h index 6785180e..015808b7 100644 --- a/src/utils.h +++ b/src/utils.h @@ -15,10 +15,18 @@ */ #pragma once +#ifdef USE_ROCM +#include +#include +#else + #include #include #include #include + +#endif + #include #include #include diff --git a/src/vec_dtypes.h b/src/vec_dtypes.h new file mode 100644 index 00000000..e69de29b From 5dfa16aa3446cbc22a96f50f1adaa82097315273 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Sun, 22 Sep 2024 11:50:42 +0000 Subject: [PATCH 02/15] add support decode --- CMakeLists.txt | 33 ++- include/flashinfer/attention/cascade.cuh | 8 + include/flashinfer/attention/decode.cuh | 16 +- include/flashinfer/attention/handler.cuh | 14 + include/flashinfer/attention/prefill.cuh | 13 + include/flashinfer/cp_async.cuh | 8 + include/flashinfer/decode_attention_decl.cuh | 10 + include/flashinfer/frag_layout_swizzle.cuh | 38 +++ include/flashinfer/hip_cuda_type_utils.h | 10 + include/flashinfer/hip_defs.h | 52 +++- include/flashinfer/hip_warp_sync_functions.h | 24 +- include/flashinfer/math.cuh | 63 ++++ include/flashinfer/mma.cuh | 10 + include/flashinfer/permuted_smem.cuh | 12 + include/flashinfer/prefill_attention_decl.cuh | 10 + include/flashinfer/utils.cuh | 2 +- include/flashinfer/vec_dtypes.cuh | 120 +++++++- include/hip/barrier.h | 80 +++++ include/hip/pipeline.h | 277 ++++++++++++++++++ src/test_single_decode.cu | 3 + 20 files changed, 792 insertions(+), 11 deletions(-) create mode 100644 include/hip/barrier.h create mode 100644 include/hip/pipeline.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 4cc71ad0..440cfda0 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -38,6 +38,7 @@ project(flashinfer C CXX) # set CMAKE_CXX_COMPILER to hipcc # set(CMAKE_FIND_DEBUG_MODE TRUE) +add_definitions(-Wall) find_package(HIP QUIET) if(HIP_FOUND) message(STATUS "Found HIP: " ${HIP_VERSION}) @@ -268,7 +269,15 @@ foreach(head_dim IN LISTS HEAD_DIMS) endforeach(logits_post_hook) endforeach(head_dim) -add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) +# TODO (yiakwy) : override add_libraries, rename sources +if (HIP_FOUND) + set_source_files_properties(${single_decode_kernels_sr} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set_source_files_properties(${batch_decode_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) +elseif(HIP_FOUND) + add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) +endif() + target_include_directories(decode_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_compile_options(decode_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) if (HIP_FOUND) @@ -404,7 +413,14 @@ if (FLASHINFER_DECODE) message(STATUS "Compile single decode kernel tests.") file(GLOB_RECURSE TEST_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/test_single_decode.cu) message(STATUS "test source : ${TEST_DECODE_SRCS}") - add_executable(test_single_decode ${TEST_DECODE_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_DECODE_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_single_decode ${TEST_DECODE_SRCS}) + else(HIP_FOUND) + add_executable(test_single_decode ${TEST_DECODE_SRCS}) + endif() + target_include_directories(test_single_decode PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_single_decode PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) add_dependencies(test_single_decode dispatch_inc) @@ -413,9 +429,18 @@ if (FLASHINFER_DECODE) message(STATUS "Compile batch decode kernel benchmarks.") file(GLOB_RECURSE BENCH_DECODE_SRCS ${PROJECT_SOURCE_DIR}/src/bench_batch_decode.cu) - add_executable(bench_batch_decode ${BENCH_DECODE_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${BENCH_DECODE_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(bench_batch_decode ${BENCH_DECODE_SRCS}) + target_include_directories(bench_batch_decode PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench) + else(HIP_FOUND) + add_executable(bench_batch_decode ${BENCH_DECODE_SRCS}) + target_include_directories(bench_batch_decode PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) + endif() + target_include_directories(bench_batch_decode PRIVATE ${FLASHINFER_INCLUDE_DIR}) - target_include_directories(bench_batch_decode PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) + add_dependencies(bench_batch_decode dispatch_inc) target_link_libraries(bench_batch_decode PRIVATE nvbench::main decode_kernels prefill_kernels) target_compile_options(bench_batch_decode PRIVATE -Wno-switch-bool) diff --git a/include/flashinfer/attention/cascade.cuh b/include/flashinfer/attention/cascade.cuh index 9d71e7bf..bbb280ac 100644 --- a/include/flashinfer/attention/cascade.cuh +++ b/include/flashinfer/attention/cascade.cuh @@ -16,7 +16,15 @@ #ifndef FLASHINFER_CASCADE_CUH_ #define FLASHINFER_CASCADE_CUH_ +#ifdef USE_ROCM + +#include +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +# else #include +#endif // USE_ROCM #include "../cp_async.cuh" #include "../math.cuh" diff --git a/include/flashinfer/attention/decode.cuh b/include/flashinfer/attention/decode.cuh index c1bf4cc7..6e13a1a6 100644 --- a/include/flashinfer/attention/decode.cuh +++ b/include/flashinfer/attention/decode.cuh @@ -15,14 +15,28 @@ */ #ifndef FLASHINFER_DECODE_CUH_ #define FLASHINFER_DECODE_CUH_ + +#ifdef USE_ROCM + +#include +#include + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +# else #include #include #include #include #include +// this is used +#include +#endif // USE_ROCM #include -#include + #include #include #include diff --git a/include/flashinfer/attention/handler.cuh b/include/flashinfer/attention/handler.cuh index e29b99c4..ce05ef87 100644 --- a/include/flashinfer/attention/handler.cuh +++ b/include/flashinfer/attention/handler.cuh @@ -16,9 +16,23 @@ #ifndef FLASHINFER_ATTENTION_HANDLER_CUH_ #define FLASHINFER_ATTENTION_HANDLER_CUH_ +#ifdef USE_ROCM + +#include +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#include + +#else + #include + +// Note this is part of NV SDK #include +#endif // USE_ROCM + #include #include #include diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 5ad6988c..cf059efd 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -15,12 +15,25 @@ */ #ifndef FLASHINFER_PREFILL_CUH_ #define FLASHINFER_PREFILL_CUH_ + +#ifdef USE_ROCM + +#include + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include #include #include #include #include +#endif // USE_ROCM + #include "../cp_async.cuh" #include "../fastdiv.cuh" #include "../frag_layout_swizzle.cuh" diff --git a/include/flashinfer/cp_async.cuh b/include/flashinfer/cp_async.cuh index 9ca851fb..883a448f 100644 --- a/include/flashinfer/cp_async.cuh +++ b/include/flashinfer/cp_async.cuh @@ -16,7 +16,15 @@ #ifndef FLASHINFER_CP_ASYNC_CUH_ #define FLASHINFER_CP_ASYNC_CUH_ +#ifdef USE_ROCM + +#include +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else #include +#endif // USE_ROCM namespace flashinfer { diff --git a/include/flashinfer/decode_attention_decl.cuh b/include/flashinfer/decode_attention_decl.cuh index 6f9ccf6f..56774d1e 100644 --- a/include/flashinfer/decode_attention_decl.cuh +++ b/include/flashinfer/decode_attention_decl.cuh @@ -16,8 +16,18 @@ #ifndef FLASHINFER_DECODE_ATTENTION_DECL_CUH_ #define FLASHINFER_DECODE_ATTENTION_DECL_CUH_ +#ifdef USE_ROCM + +#include +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include +#endif // USE_ROCM + #include "attention/handler.cuh" #include "attention/logits_post_hook.cuh" #include "layout.cuh" diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index 39cf92bc..39bebd7a 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -16,24 +16,62 @@ #ifndef FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ #define FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ +#if USE_ROCM + +#include + +#else + #include +#endif + #include __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { + // TODO (yiakwy) : override __shfl_xor_sync for 32 bit mask + #ifdef USE_ROCM + uint32_t tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x1); + #else uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); + #endif + x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); + + #ifdef USE_ROCM + tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x2); + #else tmp = __shfl_xor_sync(0xffffffff, x, 0x2); + #endif + x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276); return x; } __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { + // TODO (yiakwy) : override __shfl_xor_sync for 32 bit mask + #ifdef USE_ROCM + uint32_t tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x4); + #else uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); + #endif + x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); + + #ifdef USE_ROCM + tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x8); + #else tmp = __shfl_xor_sync(0xffffffff, x, 0x8); + #endif + x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); + + #ifdef USE_ROCM + tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x10); + #else tmp = __shfl_xor_sync(0xffffffff, x, 0x10); + #endifgi + x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276); return x; } diff --git a/include/flashinfer/hip_cuda_type_utils.h b/include/flashinfer/hip_cuda_type_utils.h index 081d00f6..1bb26023 100644 --- a/include/flashinfer/hip_cuda_type_utils.h +++ b/include/flashinfer/hip_cuda_type_utils.h @@ -42,6 +42,8 @@ using __nv_bfloat162 = __hip_bfloat162; using nv_bfloat16 = __hip_bfloat16; using __nv_bfloat16 = __hip_bfloat16; +using half2 = __half2; + // ROCM FP8 is different from nv FP8 : https://github.com/ROCm/rocBLAS/blob/9b7f692abe3c54b88d1e77e045a7db7f1f188b69/library/include/internal/rocblas_hip_f8_impl.h#L39 // TODO (yiakwy) : FP8 datatype support @@ -64,6 +66,14 @@ __BF16_HOST_DEVICE_STATIC__ __hip_bfloat162 make_bfloat162(const __hip_bfloat16 // __float2bfloat162_rn: float -> __hip_bfloat162, // __bfloat1622float2: __hip_bfloat162 -> float2 +// half utils +// TODO (yiakwy) : add native half2 support implementation +__device__ half2 __hmax2(const half2 a, const half2 b) { + return half2{ + __float2half(__ocml_fmax_f32(__half2float(a.x), __half2float(b.x))), + __float2half(__ocml_fmax_f32(__half2float(a.y), __half2float(b.y)))}; +} + #endif // } // flashinfer diff --git a/include/flashinfer/hip_defs.h b/include/flashinfer/hip_defs.h index 9b090ab7..09475bc3 100644 --- a/include/flashinfer/hip_defs.h +++ b/include/flashinfer/hip_defs.h @@ -1,15 +1,62 @@ // adpated from MSC mscclpp project, also see examples from cholla (https://github.com/cholla-hydro/cholla/blob/main/src/utils/gpu.hpp) - +// Copyright LEI WANG (yiak.wy@gmail.com) // Copyright (c) Microsoft Corporation. // Licensed under the MIT license. #ifndef FLASHINFER_HIP_DEFS_H_ #define FLASHINFER_HIP_DEFS_H_ +#ifndef __HIP_PLATFORM_AMD__ +#define __HIP_PLATFORM_AMD__ +#endif + +#ifdef __HIP_PLATFORM_NVIDIA__ +#undef __HIP_PLATFORM_NVIDIA__ +#endif + #if defined(__HIP_PLATFORM_AMD__) #include - +#include + +// enum alias +using cudaFuncAttribute = hipFuncAttribute; +const cudaFuncAttribute cudaFuncAttributeMaxDynamicSharedMemorySize = hipFuncAttribute::hipFuncAttributeMaxDynamicSharedMemorySize; +const cudaFuncAttribute cudaFuncAttributePreferredSharedMemoryCarveout = hipFuncAttribute::hipFuncAttributePreferredSharedMemoryCarveout; +const cudaFuncAttribute cudaFuncAttributeMax = hipFuncAttribute::hipFuncAttributeMax; + +using cudaDeviceAttr = hipDeviceAttribute_t; +// Number of multiprocessors on the device +const cudaDeviceAttr cudaDevAttrMultiProcessorCount = hipDeviceAttribute_t::hipDeviceAttributeMultiprocessorCount; +const cudaDeviceAttr cudaDevAttrMaxSharedMemoryPerMultiprocessor = hipDeviceAttribute_t::hipDeviceAttributeMaxSharedMemoryPerMultiprocessor; + +// function alas +template +inline static hipError_t cudaFuncSetAttribute(Func&& func, const hipFuncAttribute& attr, int value) { + return hipFuncSetAttribute((void*)func, attr, value); +} + +template +static __inline__ __host__ __device__ +auto cudaLaunchKernel(Args&&... args) -> decltype(hipLaunchKernel(std::forward(args)...)) { + return hipLaunchKernel(std::forward(args)...); +} + +static __inline__ __host__ __device__ +hipError_t cudaDeviceGetAttribute(int *value, cudaDeviceAttr attr, int device) { + return hipDeviceGetAttribute(value, attr, device); +} + +template +inline static hipError_t cudaOccupancyMaxActiveBlocksPerMultiprocessor(int* numBlocks, + Func func, + int blockSize, + size_t dynamicSMemSize) { + return hipOccupancyMaxActiveBlocksPerMultiprocessor(numBlocks, (void*)func, + blockSize, dynamicSMemSize); +} + +// Type alias using cudaError_t = hipError_t; using cudaGraph_t = hipGraph_t; using cudaGraphExec_t = hipGraphExec_t; @@ -56,6 +103,7 @@ constexpr auto CU_MEM_ACCESS_FLAGS_PROT_READWRITE = hipMemAccessFlagsProtReadWri #define cudaDeviceGetPCIBusId(...) hipDeviceGetPCIBusId(__VA_ARGS__) #define cudaHostAlloc(...) hipHostMalloc(__VA_ARGS__) #define cudaMalloc(...) hipMalloc(__VA_ARGS__) +#define cudaMallocHost(...) hipMallocHost(__VA_ARGS__) #define cudaFree(...) hipFree(__VA_ARGS__) #define cudaFreeHost(...) hipHostFree(__VA_ARGS__) #define cudaMemset(...) hipMemset(__VA_ARGS__) diff --git a/include/flashinfer/hip_warp_sync_functions.h b/include/flashinfer/hip_warp_sync_functions.h index df699e5b..beffc12d 100644 --- a/include/flashinfer/hip_warp_sync_functions.h +++ b/include/flashinfer/hip_warp_sync_functions.h @@ -1,12 +1,18 @@ -// ported from #include +// ported from in SDK 6.2 #ifndef FLASHINFER_HIP_WARP_SYNC_FUNCTIONS_PORTED_H_ #define FLASHINFER_HIP_WARP_SYNC_FUNCTIONS_PORTED_H_ #include +// note in SDK we have this statement device_prop.warpSize +#ifndef __warpSize +#define __warpSize 64 +#endif + +// compiling for 64 bit, ignoring upper 32 bit #define __hip_adjust_mask_for_wave32(MASK) \ do { \ - if (warpSize == 32) MASK &= 0xFFFFFFFF; \ + if (__warpSize == 32) MASK &= 0xFFFFFFFF; \ } while (0) #if defined(NDEBUG) @@ -69,4 +75,18 @@ T __shfl_xor_sync(MaskT mask, T var, int laneMask, return __shfl_xor(var, laneMask, width); } +// used by libhipcxx +template +__device__ inline +T __shfl_sync(MaskT mask, T var, int srcLane, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl(var, srcLane, width); +} + #endif \ No newline at end of file diff --git a/include/flashinfer/math.cuh b/include/flashinfer/math.cuh index 55267848..76a344bc 100644 --- a/include/flashinfer/math.cuh +++ b/include/flashinfer/math.cuh @@ -67,6 +67,9 @@ __forceinline__ __device__ T shfl_xor_sync(T, int); template __forceinline__ __device__ T rsqrt(T); +template +__forceinline__ __device__ T tanh(T); + // sepicalization // TODO (yiakwy) : add equivalent asm version for fast exp computation (polynomial approx) @@ -80,6 +83,11 @@ inline __device__ half exp2(half x) { return hexp2(x); } +template<> +inline __device__ half2 exp2(half2 x) { + return h2exp2(x); +} + template<> __forceinline__ __device__ float log2(float x) { return log2f(x); @@ -93,6 +101,8 @@ inline __device__ half log2(half x) { template<> __forceinline__ __device__ float rcp(float x) { // TODO (yiakwy) : __frcp_rn is not supported in ROCM 6.2 + // TODO (yiakwy) : accelerate __frcp_rn for float input with fast rcp algorithm + // return __frcp_rn(x); return 1.f / x; } @@ -129,6 +139,32 @@ __forceinline__ __device__ float rsqrt(float x) { return rsqrtf(x); } +template<> +__forceinline__ __device__ float tanh(float x) { + return tanhf(x); +} + +template<> +__forceinline__ __device__ half tanh(half x) { + // TODO (yiakwy) : SDK 6.2 does not define htanh + /* + return htanh(x); + */ + // TODO (yiakwy) : optimize this with fast polynomial fitting + half a = hexp(x); + half b = hexp(-x); + return (a - b) / (a + b); +} + +template<> +__forceinline__ __device__ half2 tanh(half2 x) { + // TODO (yiakwy) : SDK 6.2 does not define h2tanh + /* + return h2tanh(x); + */ + return half2{tanh(x.x), tanh(x.y)}; +} + } // amdgpu /*! @@ -139,6 +175,13 @@ __forceinline__ __device__ float ptx_exp2(float x) { return amdgpu::exp2(x); } +__forceinline__ __device__ half ptx_exp2(half x) { + return amdgpu::exp2(x); +} + +__forceinline__ __device__ half2 ptx_exp2(half2 x) { + return amdgpu::exp2(x); +} /*! * \brief Wrapper of PTX lg2.approx instruction, which computes log2(x) @@ -167,6 +210,14 @@ __forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) { return amdgpu::shfl_xor_sync(x, lane_mask); } +__forceinline__ __device__ half shfl_xor_sync(half x, int lane_mask) { + return amdgpu::shfl_xor_sync(x, lane_mask); +} + +__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) { + return amdgpu::shfl_xor_sync(x, lane_mask); +} + /*! * \brief Wrapper of PTX rsqrt approximation instruction, which computes 1/sqrt(x) * \param x input @@ -175,6 +226,18 @@ __forceinline__ __device__ float rsqrt(float x) { return amdgpu::rsqrt(x); } +__forceinline__ __device__ float tanh(float x) { + return amdgpu::tanh(x); +} + +__forceinline__ __device__ half tanh(half x) { + return amdgpu::tanh(x); +} + +__forceinline__ __device__ half2 tanh(half2 x) { + return amdgpu::tanh(x); +} + #else // NVIDIA PTX exlusive codes diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index 3c54a3f1..70d8c1f0 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -16,11 +16,21 @@ #ifndef FLASHINFER_MMA_CUH_ #define FLASHINFER_MMA_CUH_ +#if USE_ROCM + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include #include #include #include +#endif // USE_ROCM + #include namespace flashinfer { diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index 0b0800d0..0958547c 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -16,12 +16,24 @@ #ifndef FLASHINFER_PERMUTED_SMEM_CUH_ #define FLASHINFER_PERMUTED_SMEM_CUH_ +#if USE_ROCM + +#include "flashinfer/hip_cuda_type_utils.h" +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#include + +#else + #include #include #include #include +#endif + #include "cp_async.cuh" #include "mma.cuh" diff --git a/include/flashinfer/prefill_attention_decl.cuh b/include/flashinfer/prefill_attention_decl.cuh index 46b15209..5158e4c8 100644 --- a/include/flashinfer/prefill_attention_decl.cuh +++ b/include/flashinfer/prefill_attention_decl.cuh @@ -16,8 +16,18 @@ #ifndef FLASHINFER_PREFILL_ATTENTION_DECL_CUH_ #define FLASHINFER_PREFILL_ATTENTION_DECL_CUH_ +#ifdef USE_ROCM + +#include +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + #include +#endif + #include "attention/handler.cuh" #include "attention/logits_post_hook.cuh" #include "attention/mask.cuh" diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index f01f90a8..1900f38d 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -32,7 +32,7 @@ #include #include -#endif +#endif // USE_ROCM #include #include diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 69b19ec7..59e6c486 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -49,11 +49,14 @@ namespace flashinfer { #undef FLASHINFER_FP8_ENABLED #endif +// TODO (yiakwy) : add support bf16 +// TODO (yiakwy) : add support fp16 + #endif #define FLASHINFER_INLINE inline __attribute__((always_inline)) __device__ -// TODO (yiakwy) : add support in HIP +// TODO (yiakwy) : add support in HIP, hip_cuda_type_utils.h for details #if (__CUDACC_VER_MAJOR__ * 10000 + __CUDACC_VER_MINOR__ * 100 < 120400) && \ (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ < 800)) // CUDA version < 12.4 and GPU architecture < 80 @@ -121,6 +124,7 @@ struct vec_cast { } else { #pragma unroll for (size_t i = 0; i < vec_size / 2; ++i) { + // TODO (yiakwy) : NVIDIA/AMD does not implement real 32 bits half2 to 2xfloat in hardware, this does not accelerate ((float2*)dst)[i] = __half22float2(((half2*)src)[i]); } } @@ -427,11 +431,15 @@ struct vec_t { FLASHINFER_INLINE float_t* ptr(); }; +// src (float) -> dst (half) : float, __half, 8UL template FLASHINFER_INLINE void cast_from_impl(vec_t& dst, const vec_t& src) { + // src (float) -> dst (half) + /* vec_cast::cast( dst.ptr(), const_cast*>(&src)->ptr()); + */ } template @@ -686,6 +694,7 @@ struct vec_t<__nv_fp8_e4m3, vec_size> { ((uint4*)ptr)[i] = data[i]; } } + template FLASHINFER_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); @@ -762,6 +771,7 @@ struct vec_t<__nv_fp8_e5m2, 2> { FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template FLASHINFER_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); @@ -809,6 +819,7 @@ struct vec_t<__nv_fp8_e5m2, 4> { FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template FLASHINFER_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); @@ -857,6 +868,7 @@ struct vec_t<__nv_fp8_e5m2, 8> { FLASHINFER_INLINE void fill(__nv_fp8_e5m2 val); FLASHINFER_INLINE void load(const __nv_fp8_e5m2* ptr); FLASHINFER_INLINE void store(__nv_fp8_e5m2* ptr) const; + template FLASHINFER_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); @@ -869,6 +881,7 @@ struct vec_t<__nv_fp8_e5m2, 8> { FLASHINFER_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src); }; @@ -934,6 +947,7 @@ struct vec_t<__nv_fp8_e5m2, vec_size> { ((uint4*)ptr)[i] = data[i]; } } + template FLASHINFER_INLINE void cast_from(const vec_t& src) { cast_from_impl(*this, src); @@ -946,6 +960,7 @@ struct vec_t<__nv_fp8_e5m2, vec_size> { FLASHINFER_INLINE void cast_store(T* ptr) const { cast_store_impl(ptr, *this); } + FLASHINFER_INLINE static void memcpy(__nv_fp8_e5m2* dst, const __nv_fp8_e5m2* src) { #pragma unroll for (size_t i = 0; i < vec_size / 16; ++i) { @@ -1070,6 +1085,53 @@ FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { *((uint2*)dst) = *((uint2*)src); } +//**** test +// half x 8 +template <> +struct vec_t { + uint4 data; + + FLASHINFER_INLINE half& operator[](size_t i) { return ((half*)(&data))[i]; } + FLASHINFER_INLINE const half& operator[](size_t i) const { return ((const half*)(&data))[i]; } + FLASHINFER_INLINE half* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(half val); + FLASHINFER_INLINE void load(const half* ptr); + FLASHINFER_INLINE void store(half* ptr) const; + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(half* dst, const half* src); +}; + +FLASHINFER_INLINE void vec_t::fill(half val) { + *(half2*)(&data.x) = make_half2(val, val); + *(half2*)(&data.y) = make_half2(val, val); + *(half2*)(&data.z) = make_half2(val, val); + *(half2*)(&data.w) = make_half2(val, val); +} + +FLASHINFER_INLINE void vec_t::load(const half* ptr) { + data = *((uint4*)ptr); +} + +FLASHINFER_INLINE void vec_t::store(half* ptr) const { + *((uint4*)ptr) = data; +} + +FLASHINFER_INLINE void vec_t::memcpy(half* dst, const half* src) { + *((uint4*)dst) = *((uint4*)src); +} +//**** test end + // half x 8 or more template @@ -1420,6 +1482,62 @@ struct vec_t { } }; +// ***** test + +/* +template <> +struct vec_t; + */ + +template <> +struct vec_t { + unsigned vec_size = 8; + float4 data[2]; + + FLASHINFER_INLINE float& operator[](size_t i) { return ((float*)(data))[i]; } + FLASHINFER_INLINE const float& operator[](size_t i) const { return ((const float*)(data))[i]; } + FLASHINFER_INLINE float* ptr() { return reinterpret_cast(&data); } + FLASHINFER_INLINE void fill(float val) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = make_float4(val, val, val, val); + } + } + FLASHINFER_INLINE void load(const float* ptr) { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + data[i] = ((float4*)ptr)[i]; + } + } + FLASHINFER_INLINE void store(float* ptr) const { +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)ptr)[i] = data[i]; + } + } + template + FLASHINFER_INLINE void cast_from(const vec_t& src) { + cast_from_impl(*this, src); + } + template + FLASHINFER_INLINE void cast_load(const T* ptr) { + cast_load_impl(*this, ptr); + } + template + FLASHINFER_INLINE void cast_store(T* ptr) const { + cast_store_impl(ptr, *this); + } + FLASHINFER_INLINE static void memcpy(float* dst, const float* src) { + const unsigned vec_size = 8; +#pragma unroll + for (size_t i = 0; i < vec_size / 4; ++i) { + ((float4*)dst)[i] = ((float4*)src)[i]; + } + } +}; + +// ****** test end + } // namespace flashinfer #endif // VEC_DTYPES_CUH_ diff --git a/include/hip/barrier.h b/include/hip/barrier.h new file mode 100644 index 00000000..ce213665 --- /dev/null +++ b/include/hip/barrier.h @@ -0,0 +1,80 @@ +#pragma once + +#include +#include + +#include + +#include + +// libcxx/include/barrier.h +#include + + +namespace libhipcxx { + using namespace hip; + + using thread_scope = hip::thread_scope; + +template +class pipeline; + +enum async_contract_fulfillment +{ + none, + async +}; + +template +static inline __device__ constexpr bool __unused(_Ty&&...) {return true;} + +template +class barrier : public hip::std::__barrier_base<_CompletionF, _Scope> { +public: + barrier() = default; + + barrier(const barrier &) = delete; + barrier & operator=(const barrier &) = delete; + + __host__ __device__ constexpr + barrier(ptrdiff_t __expected, _CompletionF __completion = _CompletionF()) + : hip::std::__barrier_base<_CompletionF, _Scope>(__expected, __completion) { + } + + __host__ __device__ constexpr + friend void init(barrier * __b, ptrdiff_t __expected) { + new (__b) barrier(__expected); + } + + __host__ __device__ constexpr + friend void init(barrier * __b, ptrdiff_t __expected, _CompletionF __completion) { + new (__b) barrier(__expected, __completion); + } +}; + +// TODO (yiakwy) : verification, see MI300X ISA +__device__ void __trap(void) { __asm__ __volatile__("s_trap;"); } + +__device__ void __wait_all(void) { __asm__ volatile("s_barrier" ::); } + +// TODO (yiakwy) : __memorycpy_arrive_on_impl interface API for MI300x +struct __memcpy_arrive_on_impl { + template= thread_scope_block) && hip::std::is_same<_CompF, hip::std::__empty_completion>::value> + static inline __host__ __device__ void __arrive_on(barrier<_Scope, _CompF> & __barrier, async_contract_fulfillment __is_async) { + // TODO (yiakwy) : add impl for MI300X + // see details in // see details https://nvidia.github.io/cccl/libcudacxx/extended_api/memory_model.html + if (__is_async == async_contract_fulfillment::async) { + __wait_all(); + } + } + + template + static inline __host__ __device__ void __arrive_on(pipeline<_Scope> & __pipeline, async_contract_fulfillment __is_async) { + // pipeline does not sync on memcpy_async, defeat pipeline purpose otherwise + __unused(__pipeline); + __unused(__is_async); + } +}; + + +} // namespace libhipcxx \ No newline at end of file diff --git a/include/hip/pipeline.h b/include/hip/pipeline.h new file mode 100644 index 00000000..3d8fd3f5 --- /dev/null +++ b/include/hip/pipeline.h @@ -0,0 +1,277 @@ +// TODO (yiakwy) : to be integrated into libhipcxx; POC purpose, will be moved out soon +#pragma once + +// TODO (yiakwy) : only mi300x supported, other archs will be supported soon +#ifndef HIP_ENABLE_WARP_SYNC_BUILTINS +#define HIP_ENABLE_WARP_SYNC_BUILTINS +#endif + +#include + +// helpers +// ported from llvm project + +template +static __device__ inline +unsigned long long __match_any_sync(MaskT mask, T value) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __match_any(value) & mask; +} + +#ifdef HIP_ENABLE_WARP_SYNC_BUILTINS +static __device__ inline +unsigned long long __activemask() { + return __ballot(true); +} +#endif // HIP_ENABLE_WARP_SYNC_BUILTINS + +// ported from in SDK 6.2 +struct __pipeline_asm_helper { + __device__ static inline + uint32_t __lane_id() { + return __builtin_amdgcn_mbcnt_hi( + -1, __builtin_amdgcn_mbcnt_lo(-1, 0)); + } +}; + +__device__ static inline unsigned int __ffs(uint64_t input) { + return ( input == 0 ? -1 : __builtin_ctzll(input) ) + 1; +} + +// TODO (yiakwy) : these headers may not find relevant functions +#ifndef HIP_ENABLE_WARP_SYNC_BUILTINS +#define HIP_ENABLE_WARP_SYNC_BUILTINS +#endif +#include +#include + +#include + +// install from libhipcxx +#include +// #include + +#include "hip/barrier.h" + +#include "flashinfer/hip_warp_sync_functions.h" + + +namespace libhipcxx { + using namespace hip; + + using thread_scope = hip::thread_scope; + + template + class barrier; + + /* + template + using barrier = hip::barrier<_Scope>; + */ + + /* +enum thread_scope { + thread_scope_system = __ATOMIC_SYSTEM, + thread_scope_device = __ATOMIC_DEVICE, + thread_scope_block = __ATOMIC_BLOCK, + thread_scope_thread = __ATOMIC_THREAD +}; + */ + template + struct __pipeline_stage { + barrier<_Scope> __produced; + barrier<_Scope> __consumed; + }; + + template + class pipeline; + + // AMD uses 64 (__AMDGCN_WAVEFRONT_SIZE) threads wave, while NVIDIA uses 32 threads wave + using WAVE_MASK_TYPE=uint64_t; + + // TODO (yiakwy) : implement hip/pipline + // We mimic a pair barriers used by NVIDIA to synchronize device threads accessing to shared memroy or registers. + // + // Consumer threads wait on “consumer barrier” (no need proceed to the barrier) until data is available and arrive to "producer barriers" + // to notify the shared resources can be reuse. + // + // Once data is prepared, producer threads arrive to "consumer barrier" to notify consumer threads and wait on "producer barrier" (no need + // proceed to the barrier) to continue data production loop. + // + // Details can be found here : https://eel.is/c++draft/thread.barrier#class-1.3 + template + class pipeline { + private: + uint8_t __head; + uint8_t __tail; + const uint8_t __stages_count; + bool __consumed_phase_parity; + bool __produced_phase_parity; + bool __active; + const bool __partitioned; + char * const __shared_state; + + public: + // forbidden R-Val copies + pipeline(pipeline &&) = default; + pipeline & operator=(pipeline &&) = delete; + + pipeline(); + + void init() { + + } + + void copy() { + + } + + void clear() { + + } + + + __host__ __device__ ~pipeline() { + if (__active) quit(); + }; + + pipeline& operator=(pipeline const&) = delete; + + __host__ __device__ void producer_acquire(); + + __host__ __device__ void producer_commit(); + + __host__ __device__ void consumer_wait(); + + template + __host__ __device__ bool consumer_wait_for(hip::std::chrono::duration const& duration); + + template + __host__ __device__ + bool consumer_wait_until(hip::std::chrono::time_point const& time_point); + + __host__ __device__ void consumer_release(); + + __host__ __device__ bool quit(); + + private: + atomic * __shared_state_get_refcount() { + ptrdiff_t __refcount_offset = __stages_count * sizeof(__pipeline_stage<_Scope>); + return reinterpret_cast*>(__shared_state + __refcount_offset); + } + + __pipeline_stage<_Scope> * __shared_state_get_stage(uint8_t __stage) + { + ptrdiff_t __stage_offset = __stage * sizeof(__pipeline_stage<_Scope>); + return reinterpret_cast<__pipeline_stage<_Scope>*>(__shared_state + __stage_offset); + } + + }; + +} // namespace libhipcxx + +// TODO (yiakwy) : move implementation specialization to implementation folder (e.g. : impl/pipeline ) +namespace libhipcxx { + +// TODO (yiakwy) +template +pipeline<_Scope>::pipeline() { + +} + +template +__host__ __device__ +bool pipeline<_Scope>::quit() { + bool __elected; + WAVE_MASK_TYPE __sub_count; + const WAVE_MASK_TYPE __match_mask = __match_any_sync(__activemask(), reinterpret_cast(__shared_state_get_refcount())); + const WAVE_MASK_TYPE __elected_id = __ffs(__match_mask) - 1; + __elected = (__pipeline_asm_helper::__lane_id() == __elected_id); + __sub_count = __popc(__match_mask); + + __elected = true; + __sub_count = 1; + + bool __released = false; + if (__elected) { + const WAVE_MASK_TYPE __old = __shared_state_get_refcount()->fetch_sub(__sub_count); + const bool __last = (__old == __sub_count); + if (__last) { + for (uint8_t __stage = 0; __stage < __stages_count; ++__stage) { + __shared_state_get_stage(__stage)->__produced.~barrier(); + __shared_state_get_stage(__stage)->__consumed.~barrier(); + } + __released = true; + } + } + __active = false; + return __released; +} + +template +__host__ __device__ +void pipeline<_Scope>::producer_acquire() { + // wait for producer barrier that used resources can be reused + barrier<_Scope> & __stage_barrier = __shared_state_get_stage(__head)->__consumed; + __stage_barrier.wait_parity(__consumed_phase_parity); +} + +template +__host__ __device__ +void pipeline<_Scope>::producer_commit() { + // arrive to consumer barrier to notfiy the sources are available to use + barrier<_Scope> & __stage_barrier = __shared_state_get_stage(__head)->__produced; + __memcpy_arrive_on_impl::__arrive_on(__stage_barrier, async_contract_fulfillment::async); + (void)__stage_barrier.arrive(); + if (++__head == __stages_count) { + __head = 0; + __consumed_phase_parity = !__consumed_phase_parity; + } +} + +template +__host__ __device__ +void pipeline<_Scope>::consumer_wait() { + // wait for consumer barrier that data is available + barrier<_Scope> & __stage_barrier = __shared_state_get_stage(__tail)->__produced; + __stage_barrier.wait_parity(__produced_phase_parity); +} + +template +__host__ __device__ +void pipeline<_Scope>::consumer_release() { + // arrive producer barrier that the resources can be reused + (void)__shared_state_get_stage(__tail)->__consumed.arrive(); + if (++__tail == __stages_count) { + __tail = 0; + __produced_phase_parity = !__produced_phase_parity; + } +} + +template +template +__host__ __device__ +bool pipeline<_Scope>::consumer_wait_for(const hip::std::chrono::duration<_Rep, _Period> & __duration) { + // wait for at most __duration for producer to arrive consumer barrier + barrier<_Scope> & __stage_barrier = __shared_state_get_stage(__tail)->__produced; + return hip::std::__libcpp_thread_poll_with_backoff( + hip::std::__barrier_poll_tester_parity>( + &__stage_barrier, + __produced_phase_parity), + hip::std::chrono::duration_cast(__duration) + ); +} + +template +template +__host__ __device__ +bool pipeline<_Scope>::consumer_wait_until(const hip::std::chrono::time_point<_Clock, _Duration> & __time_point) { + return consumer_wait_for(__time_point - _Clock::now()); +} + +} // namespace libhipcxx \ No newline at end of file diff --git a/src/test_single_decode.cu b/src/test_single_decode.cu index b316486e..5e48e105 100644 --- a/src/test_single_decode.cu +++ b/src/test_single_decode.cu @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include // isnan used + #include #include @@ -104,6 +106,7 @@ TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestFP16) { #ifdef FLASHINFER_ENABLE_BF16 TEST(FlashInferCorrectnessTest, SingleDecodeKernelCorrectnessTestBF16) { + // TODO (yiakwy) TestSingleDecodeKernelCorrectness(); } #endif From 76ca1ca77aecf1443606b8776037736d1a51befc Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Sun, 22 Sep 2024 15:47:20 +0000 Subject: [PATCH 03/15] reproduce decode test --- include/flashinfer/frag_layout_swizzle.cuh | 6 +++--- src/test_single_decode.cu | 10 +++++----- 2 files changed, 8 insertions(+), 8 deletions(-) diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index 39bebd7a..b7930582 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -16,7 +16,7 @@ #ifndef FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ #define FLASHINFER_FRAG_LAYOUT_SWIZZLE_CUH_ -#if USE_ROCM +#ifdef USE_ROCM #include @@ -24,7 +24,7 @@ #include -#endif +#endif // USE_ROCM #include @@ -70,7 +70,7 @@ __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x10); #else tmp = __shfl_xor_sync(0xffffffff, x, 0x10); - #endifgi + #endif x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276); return x; diff --git a/src/test_single_decode.cu b/src/test_single_decode.cu index 5e48e105..3e51bfd4 100644 --- a/src/test_single_decode.cu +++ b/src/test_single_decode.cu @@ -83,12 +83,12 @@ void _TestDecodingKernelCorrectness(size_t num_qo_heads, size_t num_kv_heads, si template void TestSingleDecodeKernelCorrectness() { for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {4, 8, 32}) { + for (size_t num_kv_heads : {4}) {// for (size_t num_kv_heads : {4, 8, 32}) { for (size_t seq_len : - {1, 3, 9, 27, 81, 129, 257, 512, 1024, 2048, 4096, 8192, 16384, 32768}) { - for (size_t head_dim : {64, 128, 256}) { - for (unsigned int kv_layout : {0U, 1U}) { - for (unsigned int pos_encoding_mode : {0U, 1U}) { + {1}) { // {1, 3, 9, 27, 81, 129, 257, 512, 1024, 2048, 4096, 8192, 16384, 32768}) { + for (size_t head_dim : {64}) {// for (size_t head_dim : {64, 128, 256}) { + for (unsigned int kv_layout : {0U}) {// for (unsigned int kv_layout : {0U, 1U}) { + for (unsigned int pos_encoding_mode : {0U}) { // for (unsigned int pos_encoding_mode : {0U, 1U}) { _TestDecodingKernelCorrectness(num_qo_heads, num_kv_heads, seq_len, head_dim, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode)); From 512a45f22d6f84c10e0a85b252162a7963bbf0f2 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 23 Sep 2024 07:20:37 +0000 Subject: [PATCH 04/15] add support of fast_div --- CMakeLists.txt | 11 ++++- include/flashinfer/fastdiv.cuh | 13 ++++++ include/flashinfer/frag_layout_swizzle.cuh | 48 ++++++---------------- src/vec_dtypes.h | 0 4 files changed, 35 insertions(+), 37 deletions(-) delete mode 100644 src/vec_dtypes.h diff --git a/CMakeLists.txt b/CMakeLists.txt index 440cfda0..58170d93 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -271,7 +271,7 @@ endforeach(head_dim) # TODO (yiakwy) : override add_libraries, rename sources if (HIP_FOUND) - set_source_files_properties(${single_decode_kernels_sr} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set_source_files_properties(${single_decode_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) set_source_files_properties(${batch_decode_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) hip_add_library(decode_kernels STATIC ${single_decode_kernels_src} ${batch_decode_kernels_src}) elseif(HIP_FOUND) @@ -634,7 +634,14 @@ endif(FLASHINFER_TVM_BINDING) if(FLASHINFER_FASTDIV_TEST) message(STATUS "Compile fastdiv test.") file(GLOB_RECURSE TEST_FASTDIV_SRCS ${PROJECT_SOURCE_DIR}/src/test_fastdiv.cu) - add_executable(test_fastdiv ${TEST_FASTDIV_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_FASTDIV_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_fastdiv ${TEST_FASTDIV_SRCS}) + else(HIP_FOUND) + add_executable(test_fastdiv ${TEST_FASTDIV_SRCS}) + endif() + target_include_directories(test_fastdiv PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_fastdiv PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_fastdiv PRIVATE gtest gtest_main) diff --git a/include/flashinfer/fastdiv.cuh b/include/flashinfer/fastdiv.cuh index b605a2c8..53b334f4 100644 --- a/include/flashinfer/fastdiv.cuh +++ b/include/flashinfer/fastdiv.cuh @@ -21,6 +21,19 @@ #define FLASHINFER_FASTDIV_CUH_ #include +#ifdef USE_ROCM + +#include + +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#else + +#include + +#endif // USE_ROCM + namespace flashinfer { struct uint_fastdiv { diff --git a/include/flashinfer/frag_layout_swizzle.cuh b/include/flashinfer/frag_layout_swizzle.cuh index b7930582..f59b5826 100644 --- a/include/flashinfer/frag_layout_swizzle.cuh +++ b/include/flashinfer/frag_layout_swizzle.cuh @@ -20,58 +20,36 @@ #include +#ifndef FULL_MASK +#define FULL_MASK 0xffffffffffffffff +#endif + #else #include +#ifndef FULL_MASK +#define FULL_MASK 0xffffffff +#endif + #endif // USE_ROCM #include __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b(uint32_t x) { - // TODO (yiakwy) : override __shfl_xor_sync for 32 bit mask - #ifdef USE_ROCM - uint32_t tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x1); - #else - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x1); - #endif - + uint32_t tmp = __shfl_xor_sync(FULL_MASK, x, 0x1); x = __byte_perm(x, tmp, ((threadIdx.x & 0x1) == 0) ? 0x5410 : 0x3276); - - #ifdef USE_ROCM - tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x2); - #else - tmp = __shfl_xor_sync(0xffffffff, x, 0x2); - #endif - + tmp = __shfl_xor_sync(FULL_MASK, x, 0x2); x = __byte_perm(x, tmp, ((threadIdx.x & 0x2) == 0) ? 0x5410 : 0x3276); return x; } __device__ __forceinline__ uint32_t frag_layout_swizzle_16b_to_8b_trans(uint32_t x) { - // TODO (yiakwy) : override __shfl_xor_sync for 32 bit mask - #ifdef USE_ROCM - uint32_t tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x4); - #else - uint32_t tmp = __shfl_xor_sync(0xffffffff, x, 0x4); - #endif - + uint32_t tmp = __shfl_xor_sync(FULL_MASK, x, 0x4); x = __byte_perm(x, tmp, ((threadIdx.x & 0x4) == 0) ? 0x6420 : 0x3175); - - #ifdef USE_ROCM - tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x8); - #else - tmp = __shfl_xor_sync(0xffffffff, x, 0x8); - #endif - + tmp = __shfl_xor_sync(FULL_MASK, x, 0x8); x = __byte_perm(x, tmp, ((threadIdx.x & 0x8) == 0) ? 0x5410 : 0x3276); - - #ifdef USE_ROCM - tmp = __shfl_xor_sync(0xffffffffffffffff, x, 0x10); - #else - tmp = __shfl_xor_sync(0xffffffff, x, 0x10); - #endif - + tmp = __shfl_xor_sync(FULL_MASK, x, 0x10); x = __byte_perm(x, tmp, ((threadIdx.x & 0x10) == 0) ? 0x5410 : 0x3276); return x; } diff --git a/src/vec_dtypes.h b/src/vec_dtypes.h deleted file mode 100644 index e69de29b..00000000 From d053266619ca2b5eca85410ee7abb142f2021ba5 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 23 Sep 2024 07:42:13 +0000 Subject: [PATCH 05/15] add support of test_fast_dequant (surpress fp8 tests in this PR) --- CMakeLists.txt | 9 ++++++++- include/flashinfer/mma.cuh | 2 +- include/flashinfer/norm.cuh | 8 ++++---- include/flashinfer/permuted_smem.cuh | 2 +- include/flashinfer/pos_enc.cuh | 8 ++++---- include/flashinfer/vec_dtypes.cuh | 2 +- src/test_fast_dequant.cu | 12 ++++++++++++ 7 files changed, 31 insertions(+), 12 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 58170d93..4726b91d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -654,7 +654,14 @@ endif(FLASHINFER_FASTDIV_TEST) if(FLASHINFER_FASTDEQUANT_TEST) message(STATUS "Compile fast dequant test.") file(GLOB_RECURSE TEST_FAST_DEQUANT_SRCS ${PROJECT_SOURCE_DIR}/src/test_fast_dequant.cu) - add_executable(test_fast_dequant ${TEST_FAST_DEQUANT_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_FAST_DEQUANT_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_fast_dequant ${TEST_FAST_DEQUANT_SRCS}) + else(HIP_FOUND) + add_executable(test_fast_dequant ${TEST_FAST_DEQUANT_SRCS}) + endif() + target_include_directories(test_fast_dequant PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_fast_dequant PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_fast_dequant PRIVATE gtest gtest_main) diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index 70d8c1f0..e3a25894 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -16,7 +16,7 @@ #ifndef FLASHINFER_MMA_CUH_ #define FLASHINFER_MMA_CUH_ -#if USE_ROCM +#ifdef USE_ROCM #include "flashinfer/hip_cuda_type_utils.h" // CUDA API Portable interfaces diff --git a/include/flashinfer/norm.cuh b/include/flashinfer/norm.cuh index 26682101..aa2c1c1a 100644 --- a/include/flashinfer/norm.cuh +++ b/include/flashinfer/norm.cuh @@ -109,7 +109,7 @@ cudaError_t RMSNorm(T* input, T* weight, T* output, uint32_t batch_size, uint32_ DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = RMSNormKernel; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -210,7 +210,7 @@ cudaError_t FusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batch_siz DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = FusedAddRMSNormKernel; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -301,7 +301,7 @@ cudaError_t GemmaRMSNorm(T* input, T* weight, T* output, uint32_t batch_size, ui DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = GemmaRMSNormKernel; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); @@ -402,7 +402,7 @@ cudaError_t GemmaFusedAddRMSNorm(T* input, T* residual, T* weight, uint32_t batc DISPATCH_ALIGNED_VEC_SIZE(vec_size, VEC_SIZE, { auto kernel = GemmaFusedAddRMSNormKernel; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index 0958547c..90aff6f4 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -16,7 +16,7 @@ #ifndef FLASHINFER_PERMUTED_SMEM_CUH_ #define FLASHINFER_PERMUTED_SMEM_CUH_ -#if USE_ROCM +#ifdef USE_ROCM #include "flashinfer/hip_cuda_type_utils.h" // CUDA API Portable interfaces diff --git a/include/flashinfer/pos_enc.cuh b/include/flashinfer/pos_enc.cuh index 3cfc7732..b50f96fb 100644 --- a/include/flashinfer/pos_enc.cuh +++ b/include/flashinfer/pos_enc.cuh @@ -326,7 +326,7 @@ cudaError_t BatchQKApplyRotaryInPlace(DType* __restrict__ q, DType* __restrict__ (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); @@ -374,7 +374,7 @@ cudaError_t BatchQKApplyLlama31RotaryInPlace( (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); @@ -424,7 +424,7 @@ cudaError_t BatchQKApplyRotary(DType* __restrict__ q, DType* __restrict__ k, (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); @@ -476,7 +476,7 @@ cudaError_t BatchQKApplyLlama31Rotary(DType* __restrict__ q, DType* __restrict__ (void*)&smooth_b, (void*)&rope_rcp_scale, (void*)&rope_rcp_theta}; - #if USE_ROCM + #ifdef USE_ROCM FLASHINFER_CUDA_CALL(hipLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); #else FLASHINFER_CUDA_CALL(cudaLaunchKernel((void*)kernel, nblks, nthrs, args, 0, stream)); diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 59e6c486..3f53639c 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -43,7 +43,7 @@ namespace flashinfer { #define FLASHINFER_HARDWARE_FP8_CONVERSION_ENABLED #endif -#if USE_ROCM +#ifdef USE_ROCM // TODO(yiakwy) : since roc fp8 is different from NV fp8, more efforts need to port functionalities #ifdef FLASHINFER_FP8_ENABLED #undef FLASHINFER_FP8_ENABLED diff --git a/src/test_fast_dequant.cu b/src/test_fast_dequant.cu index 2ffbdc1c..40290c21 100644 --- a/src/test_fast_dequant.cu +++ b/src/test_fast_dequant.cu @@ -57,6 +57,16 @@ void TestFastDequant() { } } +#ifdef USE_ROCM +// TODO(yiakwy) : since roc fp8 is different from NV fp8, more efforts need to port functionalities +#ifdef FLASHINFER_FP8_ENABLED +#undef FLASHINFER_FP8_ENABLED +#endif + +#endif + +#ifdef FLASHINFER_FP8_ENABLED + TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE4M3ToFloat16) { TestFastDequant<__nv_fp8_e4m3, half>(); } @@ -69,3 +79,5 @@ TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE4M3ToBFloat16) { TEST(FlashInferCorrectnessTest, TestFastDequantCorrectnessE5M2ToBFloat16) { TestFastDequant<__nv_fp8_e5m2, __nv_bfloat16>(); } + +#endif From 7583155203ed3dbf11afa649c42dfbe6b1f4510d Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 23 Sep 2024 10:49:00 +0000 Subject: [PATCH 06/15] add support of fast sampling --- CMakeLists.txt | 15 +- include/flashinfer/hip_warp_sync_functions.h | 13 ++ include/flashinfer/sampling.cuh | 156 +++++++++++++++++-- src/test_sampling.cu | 1 + 4 files changed, 172 insertions(+), 13 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 4726b91d..139ace23 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -22,7 +22,7 @@ if(NOT DEFINED HIP_CMAKE_PATH) # Add hip-config.cmake to CMake module search path. # set(HIP_CMAKE_PATH "${ROCM_HOME}/share/rocm/cmake" "${ROCM_HOME}/share/rocmcmakebuildtools/cmake/" CACHE PATH "Path to which HIP has been installed") # NOTE(yiakwy) : adding ${ROCM_HOME}/lib/cmake/hip has conflicts with 3rdparty/mscclpp - set(HIP_CMAKE_PATH "${ROCM_HOME}/lib/cmake/AMDDeviceLibs" "${ROCM_HOME}/lib/cmake/amd_comgr" "${ROCM_HOME}/lib/cmake/hsa-runtime64" CACHE PATH "Path to which HIP has been installed") + set(HIP_CMAKE_PATH "${ROCM_HOME}/lib/cmake/AMDDeviceLibs" "${ROCM_HOME}/lib/cmake/amd_comgr" "${ROCM_HOME}/lib/cmake/hsa-runtime64" "${ROCM_HOME}/lib/cmake/hipcub" "${ROCM_HOME}/lib/cmake/composable_kernel" CACHE PATH "Path to which HIP has been installed") message(WARNING "System variable HIP_CMAKE_PATH is nonexist, defaults to ${HIP_CMAKE_PATH}") set(CMAKE_PREFIX_PATH "${ROCM_HOME};${CMAKE_PREFIX_PATH}") @@ -556,10 +556,19 @@ if (FLASHINFER_SAMPLING) message(STATUS "Compile sampling kernel tests.") file(GLOB_RECURSE TEST_SAMPLING_SRCS ${PROJECT_SOURCE_DIR}/src/test_sampling.cu) - add_executable(test_sampling ${TEST_SAMPLING_SRCS}) + + set(THIS_BIANRY_LIB "") + if (HIP_FOUND) + set_source_files_properties(${TEST_SAMPLING_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_sampling ${TEST_SAMPLING_SRCS}) + # set(THIS_BIANRY_LIB "hipcub") + else (HIP_FOUND) + add_executable(test_sampling ${TEST_SAMPLING_SRCS}) + endif() + target_include_directories(test_sampling PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_sampling PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) - target_link_libraries(test_sampling PRIVATE gtest gtest_main) + target_link_libraries(test_sampling PRIVATE gtest gtest_main ${THIS_BIANRY_LIB}) target_compile_options(test_sampling PRIVATE -Wno-switch-bool) if (HIP_FOUND) diff --git a/include/flashinfer/hip_warp_sync_functions.h b/include/flashinfer/hip_warp_sync_functions.h index beffc12d..d7ec9bd4 100644 --- a/include/flashinfer/hip_warp_sync_functions.h +++ b/include/flashinfer/hip_warp_sync_functions.h @@ -89,4 +89,17 @@ T __shfl_sync(MaskT mask, T var, int srcLane, return __shfl(var, srcLane, width); } +template +__device__ inline +T __shfl_up_sync(MaskT mask, T var, unsigned int delta, + int width = __AMDGCN_WAVEFRONT_SIZE) { + static_assert( + __hip_internal::is_integral::value && sizeof(MaskT) == 8, + "The mask must be a 64-bit integer. " + "Implicitly promoting a smaller integer is almost always an error."); + __hip_adjust_mask_for_wave32(mask); + __hip_check_mask(mask); + return __shfl_up(var, delta, width); +} + #endif \ No newline at end of file diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index 4df2a006..ef04810f 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -16,9 +16,39 @@ #ifndef FLASHINFER_SAMPLING_CUH_ #define FLASHINFER_SAMPLING_CUH_ +#ifdef USE_ROCM + +#include + +#include + +#include +#include +#include + +#include + +#include "flashinfer/hip_warp_sync_functions.h" + +// CUDA API Portable interfaces +#include "flashinfer/hip_defs.h" + +#ifndef FULL_MASK +#define FULL_MASK 0xffffffffffffffff +#endif + +#else + #include #include #include + +#ifndef FULL_MASK +#define FULL_MASK 0xffffffff +#endif + +#endif + #include #include "math.cuh" @@ -29,8 +59,19 @@ namespace flashinfer { namespace sampling { +#ifdef USE_ROCM + +using namespace hipcub; + +// do hip namespace alias +namespace cub = hipcub; + +#else + using namespace cub; +#endif + #define DISPATCH_DETERMINISTIC(deterministic, DETERMINISTIC, ...) \ if (deterministic) { \ constexpr bool DETERMINISTIC = true; \ @@ -119,20 +160,20 @@ __device__ __forceinline__ void DeterministicInclusiveSum( #pragma unroll for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + T tmp = __shfl_up_sync(FULL_MASK, thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum += tmp; } } - T warp_sum = __shfl_sync(0xffffffff, thread_exclusive_prefix_sum, threadIdx.x | 0xffffffff); + T warp_sum = __shfl_sync(FULL_MASK, thread_exclusive_prefix_sum, threadIdx.x | FULL_MASK); if (threadIdx.x % 32 == 31) { thread_exclusive_prefix_sum = 0; } #pragma unroll for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, thread_exclusive_prefix_sum, offset); + T tmp = __shfl_xor_sync(FULL_MASK, thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; } @@ -150,7 +191,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum( #pragma unroll for (uint32_t offset = 1; offset < 32; offset *= 2) { - T tmp = __shfl_up_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + T tmp = __shfl_up_sync(FULL_MASK, warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum += tmp; } @@ -162,7 +203,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum( #pragma unroll for (uint32_t offset = 16; offset >= 1; offset /= 2) { - T tmp = __shfl_xor_sync(0xffffffff, warp_exclusive_prefix_sum, offset); + T tmp = __shfl_xor_sync(FULL_MASK, warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; } @@ -196,26 +237,42 @@ __device__ __forceinline__ void DeviceSamplingFromProb( prob_greater_than_threshold[j] = (prob_vec[j] > threshold) ? prob_vec[j] : T(0); valid[j] = prob_vec[j] > threshold && (i * BLOCK_THREADS + tx) * VEC_SIZE < d; } + + #ifdef USE_ROCM + T aggregate_local = + BlockReduce(temp_storage->block_prim.reduce) + .template Sum(prob_greater_than_threshold); + #else T aggregate_local = BlockReduce(temp_storage->block_prim.reduce) .Sum(prob_greater_than_threshold); + #endif + if (tx == 0) { temp_storage->data.block_aggregate.value = aggregate_local; } __syncthreads(); aggregate_local = temp_storage->data.block_aggregate.value; - if (aggregate + aggregate_local > u) { + #ifdef USE_ROCM + if constexpr (false) { + #else if constexpr (DETERMINISTIC) { + #endif + // (TODO) yiakwy : fix this function in ROCM platform DeterministicInclusiveSum( prob_greater_than_threshold, inclusive_cdf, temp_storage); } else { + #ifdef USE_ROCM + BlockScan(temp_storage->block_prim.scan) + .template InclusiveSum(prob_greater_than_threshold, inclusive_cdf); + #else BlockScan(temp_storage->block_prim.scan) .InclusiveSum(prob_greater_than_threshold, inclusive_cdf); - - __syncthreads(); + #endif } - + // NOTE (yiakwy) : sync all threads in a divergent block is dangerous, moved here + __syncthreads(); #pragma unroll for (uint32_t j = 0; j < VEC_SIZE; ++j) { greater_than_u[j] = inclusive_cdf[j] + aggregate > u; @@ -226,8 +283,16 @@ __device__ __forceinline__ void DeviceSamplingFromProb( BlockAdjacentDifference(temp_storage->block_prim.adj_diff) .SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); #else + + #ifdef USE_ROCM + // ROCM has deprecated FlagHeads API + BlockAdjacentDifference(temp_storage->block_prim.adj_diff) + .template SubtractLeft(greater_than_u, greater_than_u_diff, BoolDiffOp()); + #else BlockAdjacentDifference(temp_storage->block_prim.adj_diff) .FlagHeads(greater_than_u_diff, greater_than_u, BoolDiffOp(), 0); + #endif + #endif __syncthreads(); @@ -313,7 +378,7 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, if ((i * BLOCK_THREADS + tx) * VEC_SIZE < d) { probs_vec.load(probs + bx * d + (i * BLOCK_THREADS + tx) * VEC_SIZE); } - + // TODO (yiakwy) : kernel corruption here (2) DeviceSamplingFromProb(i, d, pivot, u, probs_vec, aggregate, &temp_storage); @@ -339,14 +404,22 @@ __global__ void TopKSamplingFromProbKernel(DType* probs, DType* uniform_samples, (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; } + #ifdef USE_ROCM + aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_pair) + .template Sum(probs_gt_pivot); + #else aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_pair) .Sum(probs_gt_pivot); + #endif + if (tx == 0) { temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; } __syncthreads(); } + q = temp_storage.data.block_aggregate.pair.value; if (temp_storage.data.block_aggregate.pair.count < k) { break; @@ -426,8 +499,14 @@ __global__ void TopPSamplingFromProbKernel(DType* probs, DType* uniform_samples, probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0); } + #ifdef USE_ROCM + aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) + .template Sum(probs_gt_pivot); + #else aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) .Sum(probs_gt_pivot); + #endif + if (tx == 0) { temp_storage.data.block_aggregate.value = aggregate_gt_pivot; } @@ -486,8 +565,15 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_[j] = probs_vec[j]; } + + #ifdef USE_ROCM + max_p = max(max_p, BlockReduce(temp_storage.block_prim.reduce) + .template Reduce(probs_, cub::Max())); + #else max_p = max(max_p, BlockReduce(temp_storage.block_prim.reduce) .Reduce(probs_, cub::Max())); + #endif + __syncthreads(); } if (tx == 0) { @@ -535,8 +621,14 @@ __global__ void MinPSamplingFromProbKernel(DType* probs, DType* uniform_samples, probs_gt_pivot[j] = (probs_vec[j] > pivot) ? probs_vec[j] : DType(0); } + #ifdef USE_ROCM + aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) + .template Sum(probs_gt_pivot); + #else aggregate_gt_pivot += BlockReduce(temp_storage.block_prim.reduce) .Sum(probs_gt_pivot); + #endif + if (tx == 0) { temp_storage.data.block_aggregate.value = aggregate_gt_pivot; } @@ -619,9 +711,16 @@ __global__ void TopKTopPSamplingFromProbKernel(DType* probs, DType* uniform_samp (probs_vec[j] > pivot && (i * BLOCK_THREADS + tx) * VEC_SIZE + j < d)}; } + #ifdef USE_ROCM + aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_pair) + .template Sum(probs_gt_pivot); + #else aggregate_gt_pivot += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_pair) .Sum(probs_gt_pivot); + #endif + if (tx == 0) { temp_storage.data.block_aggregate.pair = aggregate_gt_pivot; } @@ -712,6 +811,7 @@ cudaError_t TopKSamplingFromProb(T* probs, T* uniform_samples, IdType* output, b DETERMINISTIC, T, IdType>; FLASHINFER_CUDA_CALL( cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size)); + // TODO (yiakwy) : kernel corruption here FLASHINFER_CUDA_CALL( cudaLaunchKernel((void*)kernel, nblks, nthrs, args, smem_size, stream)); })}); @@ -844,10 +944,18 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* for (uint32_t j = 0; j < VEC_SIZE; ++j) { probs_greater_than_pivot[j] = probs_vec[j]; } + + #ifdef USE_ROCM + threadlocal_max_val = + max(threadlocal_max_val, + BlockReduce(temp_storage.block_prim.reduce) + .template Reduce(probs_greater_than_pivot, cub::Max())); + #else threadlocal_max_val = max(threadlocal_max_val, BlockReduce(temp_storage.block_prim.reduce) .Reduce(probs_greater_than_pivot, cub::Max())); + #endif __syncthreads(); } if (tx == 0) { @@ -886,9 +994,16 @@ __global__ void TopPRenormProbKernel(DType* probs, DType* renormed_prob, DType* max_le_high = max(max_le_high, probs_vec[j]); } } + + #ifdef USE_ROCM + threadlocal_sum += + BlockReduce(temp_storage.block_prim.reduce) + .template Sum(probs_greater_than_pivot); + #else threadlocal_sum += BlockReduce(temp_storage.block_prim.reduce) .Sum(probs_greater_than_pivot); + #endif __syncthreads(); } min_gt_low = BlockReduce(temp_storage.block_prim.reduce) @@ -1009,9 +1124,16 @@ __global__ void TopKMaskLogitsKernel(DType* logits, DType* masked_logits, IdType max_le_high = max(max_le_high, logits_vec[j]); } } + + #ifdef USE_ROCM + threadlocal_count_sum += + BlockReduce(temp_storage.block_prim.reduce_int) + .template Sum(probs_greater_than_pivot_count); + #else threadlocal_count_sum += BlockReduce(temp_storage.block_prim.reduce_int) .Sum(probs_greater_than_pivot_count); + #endif __syncthreads(); } min_gt_low = @@ -1128,9 +1250,16 @@ __global__ void TopKRenormProbKernel(DType* probs, DType* renormed_prob, IdType* max_le_high = max(max_le_high, probs_vec[j]); } } + + #ifdef USE_ROCM + threadlocal_sum += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( + temp_storage.block_prim.reduce_pair) + .template Sum(probs_greater_than_pivot_pair); + #else threadlocal_sum += BlockReduce, BLOCK_THREADS, REDUCE_ALGORITHM>( temp_storage.block_prim.reduce_pair) .Sum(probs_greater_than_pivot_pair); + #endif __syncthreads(); } min_gt_low = @@ -1311,9 +1440,16 @@ __global__ void ChainSpeculativeSampling(DType* draft_probs, IdType* draft_token for (uint32_t j = 0; j < VEC_SIZE; ++j) { relu_q_minus_p[j] = max(q_vec[j] - p_vec[j], DType(0)); } + + #ifdef USE_ROCM + sum_relu_q_minus_p += + BlockReduce(temp_storage.block_prim.reduce) + .template Sum(relu_q_minus_p); + #else sum_relu_q_minus_p += BlockReduce(temp_storage.block_prim.reduce) .Sum(relu_q_minus_p); + #endif __syncthreads(); } if (tx == 0) { diff --git a/src/test_sampling.cu b/src/test_sampling.cu index 8a0a05fe..3a66acf8 100644 --- a/src/test_sampling.cu +++ b/src/test_sampling.cu @@ -1923,6 +1923,7 @@ TEST(FlashInferCorrectnessTests, TestTopPSamplingFromProbFP32) { TestTopPSamplingFromProb(); } + TEST(FlashInferCorrectnessTests, TestSamplingFromProbOneHotFP32) { TestSamplingFromProbOneHot(); } From f172fc076a3f40f887814e19560495138dc0d3ff Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 23 Sep 2024 11:11:57 +0000 Subject: [PATCH 07/15] fix wrong wavesize for DeterministicInclusiveSum --- include/flashinfer/sampling.cuh | 28 ++++++++++++++++------------ src/test_batch_decode.cu | 5 ----- 2 files changed, 16 insertions(+), 17 deletions(-) diff --git a/include/flashinfer/sampling.cuh b/include/flashinfer/sampling.cuh index ef04810f..b3b264a1 100644 --- a/include/flashinfer/sampling.cuh +++ b/include/flashinfer/sampling.cuh @@ -37,6 +37,8 @@ #define FULL_MASK 0xffffffffffffffff #endif +#define WAVE_SIZE 64 + #else #include @@ -47,6 +49,8 @@ #define FULL_MASK 0xffffffff #endif +#define WAVE_SIZE 32 + #endif #include @@ -159,7 +163,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum( T thread_exclusive_prefix_sum = thread_sum; #pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { + for (uint32_t offset = 1; offset < WAVE_SIZE; offset *= 2) { T tmp = __shfl_up_sync(FULL_MASK, thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum += tmp; @@ -167,12 +171,12 @@ __device__ __forceinline__ void DeterministicInclusiveSum( } T warp_sum = __shfl_sync(FULL_MASK, thread_exclusive_prefix_sum, threadIdx.x | FULL_MASK); - if (threadIdx.x % 32 == 31) { + if (threadIdx.x % WAVE_SIZE == WAVE_SIZE - 1) { thread_exclusive_prefix_sum = 0; } #pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { + for (uint32_t offset = WAVE_SIZE / 2; offset >= 1; offset /= 2) { T tmp = __shfl_xor_sync(FULL_MASK, thread_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { thread_exclusive_prefix_sum = tmp + thread_exclusive_prefix_sum; @@ -182,27 +186,27 @@ __device__ __forceinline__ void DeterministicInclusiveSum( } } - smem_prefix_sum[threadIdx.x / 32] = warp_sum; + smem_prefix_sum[threadIdx.x / WAVE_SIZE] = warp_sum; __syncthreads(); - if (threadIdx.x < 32) { + if (threadIdx.x < WAVE_SIZE) { T warp_exclusive_prefix_sum = - (threadIdx.x < BLOCK_THREADS / 32) ? smem_prefix_sum[threadIdx.x] : 0; + (threadIdx.x < BLOCK_THREADS / WAVE_SIZE) ? smem_prefix_sum[threadIdx.x] : 0; #pragma unroll - for (uint32_t offset = 1; offset < 32; offset *= 2) { + for (uint32_t offset = 1; offset < WAVE_SIZE; offset *= 2) { T tmp = __shfl_up_sync(FULL_MASK, warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum += tmp; } } - if (threadIdx.x % 32 == 31) { + if (threadIdx.x % WAVE_SIZE == WAVE_SIZE - 1) { warp_exclusive_prefix_sum = 0; } #pragma unroll - for (uint32_t offset = 16; offset >= 1; offset /= 2) { + for (uint32_t offset = WAVE_SIZE / 2; offset >= 1; offset /= 2) { T tmp = __shfl_xor_sync(FULL_MASK, warp_exclusive_prefix_sum, offset); if ((threadIdx.x + 1) % (offset * 2) == 0) { warp_exclusive_prefix_sum = tmp + warp_exclusive_prefix_sum; @@ -211,7 +215,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum( warp_exclusive_prefix_sum = tmp; } } - if (threadIdx.x < BLOCK_THREADS / 32) { + if (threadIdx.x < BLOCK_THREADS / WAVE_SIZE) { smem_prefix_sum[threadIdx.x] = warp_exclusive_prefix_sum; } } @@ -219,7 +223,7 @@ __device__ __forceinline__ void DeterministicInclusiveSum( #pragma unroll for (uint32_t i = 0; i < VEC_SIZE; ++i) { - out_data[i] = smem_prefix_sum[threadIdx.x / 32] + thread_exclusive_prefix_sum + thread_data[i]; + out_data[i] = smem_prefix_sum[threadIdx.x / WAVE_SIZE] + thread_exclusive_prefix_sum + thread_data[i]; } } @@ -255,7 +259,7 @@ __device__ __forceinline__ void DeviceSamplingFromProb( aggregate_local = temp_storage->data.block_aggregate.value; if (aggregate + aggregate_local > u) { #ifdef USE_ROCM - if constexpr (false) { + if constexpr (true) { #else if constexpr (DETERMINISTIC) { #endif diff --git a/src/test_batch_decode.cu b/src/test_batch_decode.cu index d04b972a..26038038 100644 --- a/src/test_batch_decode.cu +++ b/src/test_batch_decode.cu @@ -208,9 +208,4 @@ TEST(FlashInferCorrectnessTest, TestBatchDecodeKernelCorrectnessE5M2) { TEST(FlashInferCorrectnessTest, TestCooperativeBatchDecodeKernelCorrectnessTestFP16) { TestCooperativeBatchDecodeKernelCorrectness(); -} - -int main(int argc, char **argv) { - ::testing::InitGoogleTest(&argc, argv); - return RUN_ALL_TESTS(); } \ No newline at end of file From c1144e4fa2cc057e9179f6a02548b2b193826977 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 23 Sep 2024 13:09:40 +0000 Subject: [PATCH 08/15] add support of test_page --- CMakeLists.txt | 21 ++++++++++++++++++--- 1 file changed, 18 insertions(+), 3 deletions(-) diff --git a/CMakeLists.txt b/CMakeLists.txt index 139ace23..23913d56 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -510,7 +510,14 @@ endif(FLASHINFER_PREFILL) if (FLASHINFER_PAGE) message(STATUS "Compile page kernel tests.") file(GLOB_RECURSE TEST_PAGE_SRCS ${PROJECT_SOURCE_DIR}/src/test_page.cu) - add_executable(test_page ${TEST_PAGE_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_PAGE_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_page ${TEST_PAGE_SRCS}) + else(HIP_FOUND) + add_executable(test_page ${TEST_PAGE_SRCS}) + endif() + target_include_directories(test_page PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_page PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) target_link_libraries(test_page PRIVATE gtest gtest_main) @@ -548,9 +555,17 @@ endif(FLASHINFER_CASCADE) if (FLASHINFER_SAMPLING) message(STATUS "Compile sampling kernel benchmarks.") file(GLOB_RECURSE BENCH_SAMPLING_SRCS ${PROJECT_SOURCE_DIR}/src/bench_sampling.cu) - add_executable(bench_sampling ${BENCH_SAMPLING_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${BENCH_SAMPLING_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(bench_sampling ${BENCH_SAMPLING_SRCS}) + target_include_directories(bench_sampling PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/hipbench) + else(HIP_FOUND) + add_executable(bench_sampling ${BENCH_SAMPLING_SRCS}) + target_include_directories(bench_sampling PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) + endif() + target_include_directories(bench_sampling PRIVATE ${FLASHINFER_INCLUDE_DIR}) - target_include_directories(bench_sampling PRIVATE ${PROJECT_SOURCE_DIR}/3rdparty/nvbench) target_link_libraries(bench_sampling PRIVATE nvbench::main) target_compile_options(bench_sampling PRIVATE -Wno-switch-bool) From 1c8ed6174d9ad4888bd8235267c43d59b07f5fa3 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Tue, 24 Sep 2024 04:09:15 +0000 Subject: [PATCH 09/15] add support of prefill (part 1). AMD mfma inst --- .gitmodules | 5 +- CMakeLists.txt | 31 +- include/flashinfer/attention/prefill.cuh | 377 +++++++++++++++++-- include/flashinfer/hip_cuda_type_utils.h | 11 + include/flashinfer/hip_defs.h | 2 +- include/flashinfer/hip_warp_sync_functions.h | 2 +- include/flashinfer/mma.cuh | 70 +++- include/flashinfer/permuted_smem.cuh | 30 ++ include/flashinfer/utils.cuh | 6 +- include/flashinfer/vec_dtypes.cuh | 2 +- src/test_single_prefill.cu | 36 +- src/utils.h | 3 +- 12 files changed, 510 insertions(+), 65 deletions(-) diff --git a/.gitmodules b/.gitmodules index b9576b70..52842e14 100644 --- a/.gitmodules +++ b/.gitmodules @@ -2,8 +2,9 @@ path = 3rdparty/nvbench url = https://github.com/NVIDIA/nvbench.git [submodule "3rdparty/hipbench"] - path = 3rdparty/hipbench - url = https://github.com/ROCm/hipBench.git + path = 3rdparty/hipbench + # url = https://github.com/ROCm/hipBench.git + url = https://github.com/yiakwy-xpu-ml-framework-team/hipbench [submodule "3rdparty/googletest"] path = 3rdparty/googletest url = https://github.com/google/googletest.git diff --git a/CMakeLists.txt b/CMakeLists.txt index 23913d56..f3ea1a42 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -36,6 +36,9 @@ set(CMAKE_MODULE_PATH "${HIP_CMAKE_PATH}" ${CMAKE_MODULE_PATH}) ##### Flash infer project project(flashinfer C CXX) +set(CMAKE_CXX_FLAGS_DEBUG "-g -ggdb -O0") # clang++ crashes without -O2 +set( CMAKE_EXPORT_COMPILE_COMMANDS ON CACHE BOOL "" FORCE ) + # set CMAKE_CXX_COMPILER to hipcc # set(CMAKE_FIND_DEBUG_MODE TRUE) add_definitions(-Wall) @@ -393,7 +396,15 @@ foreach(head_dim IN LISTS HEAD_DIMS) endforeach(logits_post_hook) endforeach(head_dim) -add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) +if (HIP_FOUND) + set_source_files_properties(${single_prefill_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set_source_files_properties(${batch_paged_prefill_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + set_source_files_properties(${batch_ragged_prefill_kernels_src} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) +else(HIP_FOUND) + add_library(prefill_kernels STATIC ${single_prefill_kernels_src} ${batch_paged_prefill_kernels_src} ${batch_ragged_prefill_kernels_src}) +endif() + target_include_directories(prefill_kernels PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_compile_options(prefill_kernels PRIVATE -Xcompiler=-fPIC --fatbin-options -compress-all) if (HIP_FOUND) @@ -474,7 +485,14 @@ if (FLASHINFER_PREFILL) message(STATUS "Compile single prefill kernel tests.") file(GLOB_RECURSE TEST_PREFILL_SRCS ${PROJECT_SOURCE_DIR}/src/test_single_prefill.cu) - add_executable(test_single_prefill ${TEST_PREFILL_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_PREFILL_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_single_prefill ${TEST_PREFILL_SRCS}) + else(HIP_FOUND) + add_executable(test_single_prefill ${TEST_PREFILL_SRCS}) + endif() + target_include_directories(test_single_prefill PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_single_prefill PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) add_dependencies(test_single_prefill dispatch_inc) @@ -540,7 +558,14 @@ if (FLASHINFER_CASCADE) message(STATUS "Compile cascade kernel tests.") file(GLOB_RECURSE TEST_CASCADE_SRCS ${PROJECT_SOURCE_DIR}/src/test_cascade.cu) - add_executable(test_cascade ${TEST_CASCADE_SRCS}) + + if (HIP_FOUND) + set_source_files_properties(${TEST_CASCADE_SRCS} PROPERTIES HIP_SOURCE_PROPERTY_FORMAT 1) + hip_add_executable(test_cascade ${TEST_CASCADE_SRCS}) + else(HIP_FOUND) + add_executable(test_cascade ${TEST_CASCADE_SRCS}) + endif() + target_include_directories(test_cascade PRIVATE ${FLASHINFER_INCLUDE_DIR}) target_include_directories(test_cascade PRIVATE ${gtest_SOURCE_DIR}/include ${gtest_SOURCE_DIR}) add_dependencies(test_cascade dispatch_inc) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index cf059efd..c779ef37 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -24,6 +24,13 @@ // CUDA API Portable interfaces #include "flashinfer/hip_defs.h" +#include + +#include + +// device print +#include + #else #include @@ -55,7 +62,12 @@ namespace cg = cooperative_groups; using cp_async::SharedMemFillMode; using mma::MMAMode; +#ifdef USE_ROCM +// TODO (yiakwy) : use AMD constants +constexpr uint32_t warp_size = 64; +#else constexpr uint32_t warp_size = 32; +#endif namespace { @@ -198,7 +210,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { #pragma unroll for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(T)); ++j) { - smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); } @@ -215,7 +227,7 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* static_assert(num_frags_z * 2 % num_warps_x == 0); #pragma unroll for (uint32_t i = 0; i < num_frags_z * 2 / num_warps_x; ++i) { - smem.load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_row(*smem_offset); kv_idx += num_warps * 8; @@ -248,7 +260,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; #pragma unroll for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(DType)); ++j) { - smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, gptr, kv_idx < kv_len); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); gptr += 8 * num_elems_per_128b(); } @@ -265,7 +277,7 @@ __device__ __forceinline__ void page_produce_kv(smem_t smem, uint3 #pragma unroll for (uint32_t i = 0; i < num_frags_z * 2 / num_warps_x; ++i) { DType* gptr = produce_v ? paged_kv.v_data + kv_offset[i] : paged_kv.k_data + kv_offset[i]; - smem.load_128b_async(*smem_offset, gptr, kv_idx < kv_len); + smem.template load_128b_async(*smem_offset, gptr, kv_idx < kv_len); kv_idx += num_warps * 8; *smem_offset = smem.template advance_offset_by_row(*smem_offset); @@ -329,7 +341,22 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t lane_idx = threadIdx.x, warp_idx_x = get_warp_idx_x(); if (get_warp_idx_z() == 0) { - uint32_t q_smem_offset_w = q_smem->get_permuted_offset( + + // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory + // qsmem shape = (_, 128 Byte) + // -- frags x -> (but loaded into SMEM the next 16 rows) + // qsmem row/col 0 1 ... 7 warp_idx {0..3} + // 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 60 61 62 63 0 | + // 1 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 ... 124 125 126 127 0 | + // 2 . . . . . . . . . . . . . . . . ... . . . . 0 frags y + // 3 . . . . . . . . . . . . . . . . ... . . . . 0 | + // ... . . . . . . . . . . . . . . . . ... . . . . 0 | + // 0+4*3 . . . . . . . . . . . . . . . . ... . . . . 0 v + // 1+4*3 . . . . . . . . . . . . . . . . ... . . . . 0 + // 2+4*3 . . . . . . . . . . . . . . . . ... . . . . 0 + // 3+4*3 . . . . . . . . . . . . . . . . ... . . . . 0 + // qsmem is (num_frags_x x 16) x 64 (128 bit) matrix fragment + uint32_t q_smem_offset_w = q_smem->template get_permuted_offset( warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -339,15 +366,32 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, uint32_t q, r; group_size.divmod(packed_offset + lane_idx / 8 + fx * 16 + j * 4, q, r); const uint32_t q_idx = q; + + // NOTE (yiakwy) : q_ptr = q[bz/*head*/, bx{0} * num_rows_per_cta{16} + warp_idx_x * num_frags_x * 16 + lane_idx / 8 + j * 4 /*seqlen*/, 0/*hdim*/] + (lane_idx % 8) * 8 DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; + #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // load q fragment from gmem to smem - q_smem->load_128b_async(q_smem_offset_w, q_ptr, + // NOTE (yiakwy) : qsmem[warp_idx_x * num_frags_x * 16 + lane_idx / 8 + j * 4, lane_idx % 8] = q[bz/*head*/, get_warp_idx_x<1, 4>() * 16 + lane_idx / 8 + j * 4/*seqlen*/, 0/*hdim*/] + (lane_idx % 8) * 8 + q_smem->template load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); + + b128_t* smem_ptr = q_smem->base + (lane_idx / 8 + j * 4) * 8 + lane_idx % 8; + float16_t *s = reinterpret_cast(smem_ptr); + // #ifdef DEBUG + printf("[load q from global] (x=%d,z=%d,j=%d), q_smem[%d, %d](%f..%f) = q[H=%d,N_CTX=%d, %d](%f..%f)\n", threadIdx.x, threadIdx.z, j, lane_idx / 8 + j * 4, lane_idx % 8, (float)(*(s)), (float)(*(s+7)), 0, lane_idx / 8 + j * 4, (lane_idx % 8) * 8, (float)q_ptr[0], (float)q_ptr[7]); + // #endif + q_smem_offset_w = q_smem->template advance_offset_by_column<8>(q_smem_offset_w, fyo); - q_ptr += 8 * num_elems_per_128b(); + + // NOTE(yiakwy) : no need to increment at the last iteration + if (fyo + 1 < num_frags_y / 4) { + q_ptr += 8 * num_elems_per_128b(); + } } + + // TODO (yiakwy) : rewrite q_smem_offset_w = q_smem->template advance_offset_by_row<4, channel_size_128b_q>(q_smem_offset_w) - 2 * num_frags_y; @@ -462,7 +506,6 @@ __device__ __forceinline__ void k_smem_inplace_apply_rotary(const uint32_t kv_id static_assert(num_warps_z == 1); const uint32_t warp_idx = get_warp_idx_x(); // horizontal-axis: y - // horizontal-axis: y // vertical-axis: z // | 1-16 | 16-32 | 32-48 | 48-64 | // | 1-16 | warp_idx=0 | warp_idx=1 | warp_idx=0 | warp_idx=1 | @@ -536,17 +579,148 @@ __device__ __forceinline__ void compute_qk( constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + + #ifdef USE_ROCM + + // TODO (yiakwy) : REMOVE + if (threadIdx.x == 0 && threadIdx.z == 0) { + printf("[compute_qk] channel_size_128b_q=%d, channel_size_128b_kv=%d\n", channel_size_128b_q, channel_size_128b_kv); + printf("[compute_qk] num_frags_x=%d, num_frags_y=%d, num_frags_z=%d\n", num_frags_x, num_frags_y, num_frags_z); + } + + // NOTE(yiakwy) : each thread of 64=16x4 threads block, cooperatively loads 4 x consecutive fp16/bf16 data to cover 16x16 matrix frag + uint32_t a_frag[num_frags_x][num_frags_y][2]; + uint32_t b_frag[num_frags_x][num_frags_z][2]; + + // hence + uint32_t lane_id = threadIdx.x + threadIdx.z * 32; + + uint32_t lane_id_x = lane_id % 16; + uint32_t lane_id_y = lane_id / 16; + + uint32_t warp_idx_x = get_warp_idx_x<1, 4>(); + + using float16_t = rocwmma::float16_t; + + using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; + using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + + #define MTX_FRAG_LDA 64 + + #else + + // NOTE(yiakwy) : each thread of 32=8x4 threads block, cooperatively loads 2 x fp16/bf16 data, and repeat 4 (x4) times in 4 warps to cover 16x16 matrix frag uint32_t a_frag[num_frags_x][4], b_frag[4]; + + #endif + // compute q*k^T + #ifdef USE_ROCM + +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + + // TODO (yiakwy) : check + if (lane_id >= 64) { + continue; + } + + // load q +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + + // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim, then do sum + b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + float16_t *s = reinterpret_cast(smem_ptr); + + float16x4 *a = reinterpret_cast(a_frag[fx][fy]); + + float ref_0 = (float)(*(s+(threadIdx.x / 8 + fy * 4) * 64 + threadIdx.x % 8 )); + float ref_1 = (float)(*(s+(threadIdx.x / 8 + fy * 4) * 64 + threadIdx.x % 8 + 7)); + printf("[compute_qk] s[%d, %d]=%f..%f\n", threadIdx.x / 8 + fy * 4, threadIdx.x % 8, ref_0, ref_1); + + // TODO (yiakwy) : replaced with more efficient load instruction +#pragma unroll + for (uint32_t j=0; j < 4; j++) { + // NOTE (yiakwy) : loads 1 columns of data + uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4 + fy * 16; + s += offset; + + (*a)[j] = *(s); + + // TODO (yiakwy) : REMOVE + if (fx==0 && fy== 0) { + // printf("[compute_qk] (fy=%d, lane_id_x=%d, lane_id_y=%d, j=%d), [compute_qk] a_frag[fx=%d][fy=%d][j=%d]=%f, s[%d]=%f\n", fy, lane_id_x, lane_id_y, j, fx, fy, j, (float)((*a)[j]), offset, (float)(*s)); + } + } + } + + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); + + // load k +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + + if constexpr (sizeof(DTypeKV) == 1) { + assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); + } + + b128_t* smem_ptr = k_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + float16_t *s = reinterpret_cast(smem_ptr); + + float16x4 *b = reinterpret_cast(b_frag[fx][fz]); + + // TODO (yiakwy) : replaced with more efficient load instruction +#pragma unroll + for (int j=0; j < 4; j++) { + // NOTE (yiakwy) : loads 16 consecutive data of 1 row + s += lane_id_x + lane_id_y * MTX_FRAG_LDA * 4 + j * MTX_FRAG_LDA + fz * MTX_FRAG_LDA * 16; + + (*b)[j] = *s; + // TODO (yiakwy) : REMOVE + if (fx==0 && fz== 0) { + // printf("[compute_qk] (%d, %d, %d), [compute_qk] b_frag[%d][%d][%d]=%f\n", lane_id_x, j, lane_id_y, fx, fz, j, (float)((*b)[j])); + } + } + + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); + + } + + // compute + assert( num_frags_y == num_frags_z && "num_frags_y is not equal to num_frags_z"); + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + float16x4 *a = reinterpret_cast(a_frag[fx][fz]); + float16x4 *b = reinterpret_cast(b_frag[fx][fz]); + + floatx4 *d = reinterpret_cast(s_frag[fx][fz]); + + if constexpr (std::is_same::value) { + *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0); + } else { + // TODO (yiakwy) : device cast fp32 to fp16 + assert(0 && "AMD v_mfma instruction does not support fp16 output."); + // *d = __builtin_amdgcn_mfma_f16_16x16x16f16(*a, *b, *d, 0, 0, 0); + } + } + } + + #else + #pragma unroll + // NOTE(yiakwy) each thead read 2 elments and repeat 4 times (num_frags_y), threads cooperatively loads 16x64 for (uint32_t fy = 0; fy < num_frags_y; ++fy) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { q_smem->ldmatrix_m8n8x4(*q_smem_offset_r, a_frag[fx]); + + // NOTE (yiakwy) : move to the next 16 rows *q_smem_offset_r = - q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); - } + q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); + } *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) - num_frags_x * 16 * channel_size_128b_q; @@ -561,7 +735,7 @@ __device__ __forceinline__ void compute_qk( } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b(b_frag_f8[1]); - vec_cast::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + vec_cast::template cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); } else { k_smem->ldmatrix_m8n8x4(*k_smem_offset_r, b_frag); } @@ -598,6 +772,9 @@ __device__ __forceinline__ void compute_qk( num_frags_z * 16 * channel_size_128b_kv; } } + +#endif // USE_ROCM + *q_smem_offset_r -= num_frags_y * 2; *k_smem_offset_r -= num_frags_y * sizeof(DTypeKV); @@ -776,13 +953,18 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); + #ifdef USE_ROCM + using float16x4 = __attribute__((__vector_size__(4 * sizeof(rocwmma::float16_t)))) rocwmma::float16_t; + using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; + #endif + DTypeQ s_frag_f16[num_frags_x][num_frags_z][8]; if constexpr (std::is_same::value) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - vec_cast::cast<8>(s_frag_f16[fx][fz], s_frag[fx][fz]); + vec_cast::template cast<8>(s_frag_f16[fx][fz], s_frag[fx][fz]); } } } @@ -791,11 +973,29 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + + #ifdef USE_ROCM + // TODO (yiakwy) : check feasibility with v_mfma_fp32_m16xn16xk16_fp16 +#pragma unroll + for (int i=0; i < 8; ++i) { + if constexpr (std::is_same::value) { + // device cast from half to float + d[fx][i] += (float)s_frag_f16[fx][fz][i]; + } else { + // device cast from half to float + d[fx][i] += (float)s_frag[fx][fz][i]; + } + } + + #else + if constexpr (std::is_same::value) { mma::rowsum_f16f16f32(d[fx], s_frag_f16[fx][fz]); } else { mma::rowsum_f16f16f32(d[fx], s_frag[fx][fz]); } + + #endif // USE_ROCM } } @@ -805,6 +1005,10 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t b_frag[4]; if constexpr (sizeof(DTypeKV) == 1) { + + // TODO (yiakwy) : add FP8 support for KV Cache + assert(0 && "FP8 KV Cache is not supported."); + uint32_t b_frag_f8[2]; if (fy % 2 == 0) { v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); @@ -813,13 +1017,44 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast::cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + vec_cast::template cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); swap(b_frag[1], b_frag[2]); } else { + + #ifdef USE_ROCM + + b128_t* smem_ptr = v_smem->base + *v_smem_offset_r; + + float16x4 *b = reinterpret_cast(b_frag); + +#pragma unroll + for (int j=0; j < 4; j++) { + (*b)[j] = (rocwmma::float16_t)(*b)[j]; + } + + #else v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); + #endif // USE_ROCM + } #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + + #ifdef USE_ROCM + + float16x4 *b = reinterpret_cast(b_frag); + floatx4 *o = reinterpret_cast(o_frag[fx][fz]); + + if constexpr (std::is_same::value) { + float16x4 *a = reinterpret_cast(s_frag_f16[fx][fz]); + *o = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *o, 0, 0, 0); + } else { + float16x4 *a = reinterpret_cast(s_frag[fx][fz]); + *o = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *o, 0, 0, 0); + } + + #else + if constexpr (std::is_same::value) { mma::mma_sync_m16n16k16_row_col_f16f16f32( o_frag[fx][fy], (uint32_t*)(s_frag_f16[fx][fz]), b_frag); @@ -827,7 +1062,11 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, mma::mma_sync_m16n16k16_row_col_f16f16f32(o_frag[fx][fy], (uint32_t*)s_frag[fx][fz], b_frag); } + + #endif // USE_ROCM } + + // TODO (yiakwy) : fix if constexpr (sizeof(DTypeKV) == 1) { if (fy % 2 == 1) { *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fy / 2); @@ -985,13 +1224,13 @@ __device__ __forceinline__ void write_o_reg_gmem( #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { uint32_t o_frag_f16[4]; - vec_cast::cast<8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]); + vec_cast::template cast<8>((DTypeOut*)o_frag_f16, o_frag[fx][fy]); #ifdef FLASHINFER_STMATRIX_M8N8X4_ENABLED - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx % 16, fy * 2 + lane_idx / 16); o_smem->stmatrix_m8n8x4(o_smem_offset_w, o_frag_f16); #else - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( (warp_idx_x * num_frags_x + fx) * 16 + lane_idx / 4, fy * 2); ((uint32_t*)(o_smem->base + o_smem_offset_w))[lane_idx % 4] = o_frag_f16[0]; ((uint32_t*)(o_smem->base + o_smem_offset_w + 8 * channel_size_128b_out))[lane_idx % 4] = @@ -1003,7 +1242,7 @@ __device__ __forceinline__ void write_o_reg_gmem( } } - uint32_t o_smem_offset_w = o_smem->get_permuted_offset( + uint32_t o_smem_offset_w = o_smem->template get_permuted_offset( warp_idx_x * num_frags_x * 16 + lane_idx / 8, lane_idx % 8); #pragma unroll @@ -1083,8 +1322,10 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC const uint32_t bx = blockIdx.x, chunk_idx = blockIdx.y, kv_head_idx = blockIdx.z; const uint32_t num_kv_heads = gridDim.z, num_qo_heads = num_kv_heads * group_size; constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; + const tensor_info_t qkv_info(qo_len, kv_len, num_qo_heads, num_kv_heads, q_stride_n, q_stride_h, kv_stride_n, kv_stride_h, /*head_dim=*/num_frags_y * 16); + float alibi_slopes[num_frags_x][2]; const uint32_t num_chunks = gridDim.y; @@ -1098,13 +1339,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC const uint32_t window_left = (maybe_window_left >= 0) ? maybe_window_left : kv_len; constexpr uint32_t head_dim = num_frags_y * 16; - constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_q = head_dim / num_elems_per_128b(); // e.g.:64/8 constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); + constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); // e.g.: 64/4 extern __shared__ uint8_t smem[]; + // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; + // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag float o_frag[num_frags_x][num_frags_y][8]; DTypeQKAccum m[num_frags_x][2]; float d[num_frags_x][2]; @@ -1114,14 +1357,19 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC } init_states(o_frag, m, d); - // cooperative fetch q fragment from gmem to reg + // TODO (yiakwy) : to be used by load_q_global_smem, double check to compute offset of q + // cooperatively fetch q fragment from gmem to reg const uint32_t qo_packed_idx_base = (bx * num_warps_x + get_warp_idx_x()) * num_frags_x * 16; + constexpr SwizzleMode swizzle_mode_q = SwizzleMode::k128B; smem_t qo_smem(smem); + + // TODO (yiakwy) : to be used by load_q_global_smem, double check to compute offset of q DTypeQ* q_ptr_base = q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); + DTypeOut* o_ptr_base = partition_kv ? o + chunk_idx * num_qo_heads * head_dim + @@ -1129,10 +1377,15 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC (lane_idx % 8) * num_elems_per_128b()) : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); + if (threadIdx.x == 0 && threadIdx.z == 0) { + printf("[prefill kernel] channel_size_128b_q = %d\n", channel_size_128b_q); + } + + // NOTE(yiakwy) : FA2 outter loop (block level) load q first and iterate over sequence dimension inside a block load_q_global_smem( qo_packed_idx_base, qo_len, q_ptr_base, q_stride_n, q_stride_h, group_size, &qo_smem); @@ -1146,8 +1399,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC qo_packed_idx_base, qo_len, kv_len, group_size, &qo_smem, &q_smem_offset_r, rope_freq, sm_scale); } else { + if (threadIdx.x==0 && threadIdx.z==0) { + printf("[prefill kernel] skip q_smem_inplace_multiply_sm_scale.\n"); + } + // TODO (yiakwy) : recover + /* q_smem_inplace_multiply_sm_scale(&qo_smem, sm_scale); + */ } if constexpr (pos_encoding_mode == PosEncodingMode::kALiBi) { @@ -1202,14 +1461,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC v + qkv_info.get_kv_elem_offset( chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); @@ -1218,6 +1477,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); + // NOTE (yiakwy) : kv inner loop #pragma unroll 1 for (uint32_t iter = 0; iter < num_iterations; ++iter) { cp_async::wait_group<1>(); @@ -1270,7 +1530,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); - + // compute sfm*v compute_sfm_v( &v_smem, &v_smem_offset_r, s_frag, o_frag, d); @@ -1281,6 +1541,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC (iter + 1) * 16 * num_warps_z * num_frags_z, chunk_size); cp_async::commit_group(); } + cp_async::wait_group<0>(); block.sync(); @@ -1314,6 +1575,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC lse[(qo_idx * num_chunks + chunk_idx) * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); } else { + // TODO (yiakwy) : REMOVE + uint32_t warp_idx = get_warp_idx(); + // if (warp_idx == 0) { + printf("[write lse] (qo_idx=%d, qo_head_idx=%d), warp_idx=%d, (y, z)=(%d, %d), d[%d][%d]=%f, m[%d][%d]=%f", qo_idx, qo_head_idx, warp_idx, threadIdx.y, threadIdx.z, fx, j, d[fx][j], fx, j, float(m[fx][j])); + // } lse[qo_idx * num_qo_heads + qo_head_idx] = math::ptx_log2(d[fx][j]) + float(m[fx][j]); } @@ -1415,7 +1681,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg : o + qkv_info.get_o_elem_offset(o_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + // 32x4 -> 16x8 + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); @@ -1489,14 +1756,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithRagg num_warps_z * num_frags_z * sizeof(DTypeKV)) * 16 * head_dim); - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); DTypeKV* k_ptr = @@ -1719,7 +1986,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage : o + get_elem_offset_impl(o_indptr[request_idx], kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b(), num_qo_heads * head_dim, head_dim); - uint32_t q_smem_offset_r = qo_smem.get_permuted_offset( + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, lane_idx / 16); @@ -1772,14 +2039,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void BatchPrefillWithPage 16 * head_dim); size_t kv_offset[num_frags_z * (swizzle_mode_kv == SwizzleMode::k128B ? 4 : 2) / num_warps_x]; - uint32_t k_smem_offset_r = k_smem.get_permuted_offset( + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), - v_smem_offset_r = v_smem.get_permuted_offset( + v_smem_offset_r = v_smem.template get_permuted_offset( get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), - kv_smem_offset_w = k_smem.get_permuted_offset( + kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); const IdType last_indptr = paged_kv.indptr[paged_kv.batch_size]; @@ -1968,6 +2235,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched( cudaStream_t stream) { const float log2_rope_rcp_scale = -std::log2f(rope_scale); const float log2_rope_rcp_theta = -std::log2f(rope_theta); + // TODO (yiakwy) : REMOVE + // e.x.: q: (1/*qo_heads*/, 2/*qo_len*/, 64) kv: (1/*kv_heads*/, 2/*kv_len*/, 64) if (kv_len < qo_len && MASK_MODE == MaskMode::kCausal) { std::ostringstream err_msg; err_msg << "When mask_mode is set to MaskMode::kCausal, kv_len must be greater than or equal " @@ -1985,6 +2254,14 @@ cudaError_t SinglePrefillWithKVCacheDispatched( warp_layout = WarpLayout::k4x1x2; } else { auto compute_capacity = GetCudaComputeCapability(); + #ifdef USE_ROCM + // TODO (yiakwy) : tuning warp layout, ROCM 6.2 SDK output 9.4 + if (unpacked_qo_len > 16) { + warp_layout = WarpLayout::k4x1x1; + } else { + warp_layout = WarpLayout::k1x4x1; + } + #else if (compute_capacity.first >= 8) { // Ampere or newer if (unpacked_qo_len > 16) { @@ -1996,6 +2273,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( // NOTE(Zihao): not enough shared memory on Turing for 1x4x1 layout warp_layout = WarpLayout::k4x1x1; } + #endif } DISPATCH_WARP_LAYOUT(warp_layout, WARP_LAYOUT, { @@ -2011,6 +2289,7 @@ cudaError_t SinglePrefillWithKVCacheDispatched( &max_smem_per_sm, cudaDevAttrMaxSharedMemoryPerMultiprocessor, dev_id)); // we expect each sm execute two threadblocks // TODO(Zihao): fix the following computation + // TODO (yiakwy) : MI300X returns 64KB (i.e.: 2**16 addresable locations) for max_smem_per_sm, note for HEAD_DIM=64, DTypeQ=half (16 * HEAD_DIM * sizeof(DTypeQ) * 16) = 2**(4 + 6 + 1 + 4) = 2**15 const int num_ctas_per_sm = max_smem_per_sm > (16 * HEAD_DIM * sizeof(DTypeQ) * 16) ? 2 : 1; const int max_smem_per_threadblock = max_smem_per_sm / num_ctas_per_sm; @@ -2022,10 +2301,17 @@ cudaError_t SinglePrefillWithKVCacheDispatched( ? 2 : (8 / num_frags_x); // TODO(Zihao): fix the following computation + // NOTE(yiakwy) : for HEAD_DIM=64, DTypeQ=half and num_warps_z=4, max_num_frags_z_smem=32KB / 2**(4 + 6 + 1/*dtypeQ*/ + 1 + 2/*warp_size*/) = 2**15/(2**14 - delta) = 1 + /* const uint32_t max_num_frags_z_smem = (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / (2 * num_warps_z); + */ + const uint32_t max_num_frags_z_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ))) / + (2 * num_warps_z); + // TODO (yiakwy) : fix here // control num_frags_z for maximum warp occupancy DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { if constexpr (is_invalid_configuration( @@ -2039,8 +2325,8 @@ cudaError_t SinglePrefillWithKVCacheDispatched( " and report the issue to the developers."; throw std::invalid_argument(err_msg.str()); } else { - constexpr uint32_t num_threads = (num_warps_x * num_warps_z) * warp_size; - constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; + constexpr uint32_t num_threads = (num_warps_x * num_warps_z) * warp_size; // 4x1x64=256 + constexpr uint32_t num_rows_per_cta = num_frags_x * num_warps_x * 16; // 1x4x16=64 auto kernel = SinglePrefillWithKVCacheKernel 0) { uint32_t chunk_size = max(ceil_div(kv_len, max_num_kv_chunks), 256); - num_chunks = ceil_div(kv_len, chunk_size); + num_chunks = ceil_div(kv_len, chunk_size); } else { num_chunks = 0; } + // TODO(yiakwy) : REMOVE + std::cout << "qo_len : " << qo_len << std::endl; + std::cout << "kv_len : " << kv_len << std::endl; + + std::cout << "num_blocks_per_sm : " << num_blocks_per_sm << std::endl; + std::cout << "max_num_kv_chunks : " << max_num_kv_chunks << std::endl; + std::cout << "num_chunks : " << num_chunks << std::endl; + + std::cout << "num_rows_per_cta : " << num_rows_per_cta << std::endl; + std::cout << "num_threads : " << num_threads << std::endl; + std::cout << "num_warps_x : " << num_warps_x << std::endl; + std::cout << "num_warps_z : " << num_warps_z << std::endl; + if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv bool partition_kv = false; @@ -2331,9 +2630,15 @@ cudaError_t BatchPrefillWithPagedKVCacheDispatched( ? 2 : (8 / num_frags_x); // TODO(Zihao): fix the following computation + // NOTE (yiakwy) : fix max_num_frags_z_smem + /* const uint32_t max_num_frags_z_smem = (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ)) - num_frags_x * num_warps_x) / (2 * num_warps_z); + */ + const uint32_t max_num_frags_z_smem = + (max_smem_per_threadblock / (16 * HEAD_DIM * sizeof(DTypeQ))) / + (2 * num_warps_z); DISPATCH_NUM_FRAGS_Z(min(max_num_frags_z_smem, max_num_frags_z_reg), num_frags_z, { if constexpr (is_invalid_configuration( diff --git a/include/flashinfer/hip_cuda_type_utils.h b/include/flashinfer/hip_cuda_type_utils.h index 1bb26023..1fff1977 100644 --- a/include/flashinfer/hip_cuda_type_utils.h +++ b/include/flashinfer/hip_cuda_type_utils.h @@ -34,6 +34,17 @@ THE SOFTWARE. // CUDA DEVICE API Supported : https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Device_API_supported_by_HIP.html +// #if defined(_Float16) && !defined(float16_t) +// NOTE(yiakwy) : used by rocWMMA +// TODO(yiakwy) : unifying fp16/half definition + +#include + +// using float16_t = _Float16; +using float16_t = rocwmma::float16_t; + +// #endif + /*! \brief Struct to packet two 16 bit brain floating point numbers. */ using nv_bfloat162 = __hip_bfloat162; using __nv_bfloat162 = __hip_bfloat162; diff --git a/include/flashinfer/hip_defs.h b/include/flashinfer/hip_defs.h index 09475bc3..ff12e60b 100644 --- a/include/flashinfer/hip_defs.h +++ b/include/flashinfer/hip_defs.h @@ -30,7 +30,7 @@ using cudaDeviceAttr = hipDeviceAttribute_t; const cudaDeviceAttr cudaDevAttrMultiProcessorCount = hipDeviceAttribute_t::hipDeviceAttributeMultiprocessorCount; const cudaDeviceAttr cudaDevAttrMaxSharedMemoryPerMultiprocessor = hipDeviceAttribute_t::hipDeviceAttributeMaxSharedMemoryPerMultiprocessor; -// function alas +// function alias template inline static hipError_t cudaFuncSetAttribute(Func&& func, const hipFuncAttribute& attr, int value) { return hipFuncSetAttribute((void*)func, attr, value); diff --git a/include/flashinfer/hip_warp_sync_functions.h b/include/flashinfer/hip_warp_sync_functions.h index d7ec9bd4..135a7026 100644 --- a/include/flashinfer/hip_warp_sync_functions.h +++ b/include/flashinfer/hip_warp_sync_functions.h @@ -4,7 +4,7 @@ #include -// note in SDK we have this statement device_prop.warpSize +// note in SDK we have this value from statement device_prop.warpSize #ifndef __warpSize #define __warpSize 64 #endif diff --git a/include/flashinfer/mma.cuh b/include/flashinfer/mma.cuh index e3a25894..2ab5905a 100644 --- a/include/flashinfer/mma.cuh +++ b/include/flashinfer/mma.cuh @@ -17,11 +17,20 @@ #define FLASHINFER_MMA_CUH_ #ifdef USE_ROCM +#include #include "flashinfer/hip_cuda_type_utils.h" // CUDA API Portable interfaces #include "flashinfer/hip_defs.h" +#ifndef FULL_MASK +#define FULL_MASK 0xffffffff +#endif + +#include + +// using bfloat16x4 = __attribute__((__vector_size__(4 * sizeof(bfloat16_t)))) bfloat16_t; + #else #include @@ -29,6 +38,10 @@ #include #include +#ifndef FULL_MASK +#define FULL_MASK 0xffffffffffffffff +#endif + #endif // USE_ROCM #include @@ -84,7 +97,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) : "r"(smem_int_ptr)); #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("v_mfma_f32_8x8x4bf16 not supported, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -103,7 +120,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_left_half(uint32_t* R, T* smem_p : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM suppoort of ldmatrix_m8n8x4_left_half is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -122,7 +143,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_right_half(uint32_t* R, T* smem_ : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of ldmatrix_m8n8x4_right_half is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -141,7 +166,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans(uint32_t* R, T* smem_ptr) : "=r"(R[0]), "=r"(R[1]), "=r"(R[2]), "=r"(R[3]) : "r"(smem_int_ptr)); #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of ldmatrix_m8n8x4_trans is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -160,7 +189,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_left_half(uint32_t* R, T* : "=r"(R[0]), "=r"(R[1]) : "r"(smem_int_ptr)); #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of ldmatrix_m8n8x4_trans_left_half is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -180,6 +213,11 @@ __device__ __forceinline__ void ldmatrix_m8n8x4_trans_right_half(uint32_t* R, T* : "r"(smem_int_ptr)); #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of ldmatrix_m8n8x4_trans_right_half is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -203,10 +241,10 @@ __device__ __forceinline__ void stmatrix_m8n8x4(uint32_t* R, T* smem_ptr) { uint4 word; #pragma unroll for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { - word.x = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4); - word.y = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 1); - word.z = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 2); - word.w = __shfl_sync(0xffffffff, R[reg_id], (tx % 8) * 4 + 3); + word.x = __shfl_sync(FULL_MASK, R[reg_id], (tx % 8) * 4); + word.y = __shfl_sync(FULL_MASK, R[reg_id], (tx % 8) * 4 + 1); + word.z = __shfl_sync(FULL_MASK, R[reg_id], (tx % 8) * 4 + 2); + word.w = __shfl_sync(FULL_MASK, R[reg_id], (tx % 8) * 4 + 3); if (tx / 8 == reg_id) { *(uint4*)smem_ptr = word; } @@ -310,8 +348,12 @@ __device__ __forceinline__ void mma_sync_m16n16k32_row_col_f8f8f32(float* C, uin } } #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of mma_sync_m16n16k32_row_col_f8f8f32 is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT( "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); + #endif #endif } @@ -482,7 +524,11 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f32(float* C, u FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM support of mma_sync_m16n16k16_row_col_f16f16f32 is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); + #endif #endif } @@ -520,8 +566,12 @@ __device__ __forceinline__ void rowsum_f8f8f32(float* d, DType* s) { "r"(1010580540), "f"(d[0]), "f"(d[1])); } #else + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("ROCM fp8 mma instruction is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else FLASHINFER_RUNTIME_ASSERT( "fp8 mma instruction is only available for sm89, PTX 8.4+ and CUDA 12.4+"); + #endif #endif } @@ -584,7 +634,11 @@ __device__ __forceinline__ void rowsum_f16f16f32(float* d, DType* s) { FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); } #else - FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("v_mfma_f32_16x8x{8,16}_fp16 is not supported, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } @@ -704,7 +758,11 @@ __device__ __forceinline__ void mma_sync_m16n16k16_row_col_f16f16f16(uint32_t* C : "r"(A[2]), "r"(A[3]), "r"(B[3]), "r"(C[2]), "r"(C[3])); } #else - FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for mma instruction"); + #ifdef USE_ROCM + FLASHINFER_RUNTIME_ASSERT("v_mfma_f32_16x16x16_fp16 is pending, see https://rocmdocs.amd.com/projects/llvm-project/en/latest/LLVM/llvm/html/AMDGPU/AMDGPUAsmGFX940.html"); + #else + FLASHINFER_RUNTIME_ASSERT("Unsupported CUDA architecture for ldmatrix instruction"); + #endif #endif } diff --git a/include/flashinfer/permuted_smem.cuh b/include/flashinfer/permuted_smem.cuh index 90aff6f4..d7886161 100644 --- a/include/flashinfer/permuted_smem.cuh +++ b/include/flashinfer/permuted_smem.cuh @@ -24,6 +24,8 @@ #include +#include + #else #include @@ -75,6 +77,14 @@ struct smem_t { */ template static __device__ __forceinline__ uint32_t get_permuted_offset(uint32_t i, uint32_t j) { + + #ifdef USE_ROCM + + // TODO (yiakwy) : add swizzle mode + return i * stride + j; + + #else + if constexpr (swizzle_mode == SwizzleMode::k128B) { return i * stride + (j ^ (i % 8)); } else { @@ -82,11 +92,20 @@ struct smem_t { static_assert(stride == 4); return i * stride + (j ^ ((i / 2) % 4)); } + + #endif // USE_ROCM } template static __device__ __forceinline__ uint32_t advance_offset_by_column(uint32_t offset, uint32_t step_idx) { + #ifdef USE_ROCM + + // TODO(yiakwy) : add swizzle mode + return offset + step_size; + + #else + if constexpr (swizzle_mode == SwizzleMode::k128B) { static_assert(step_size == 2 || step_size == 4 || step_size % 8 == 0, "Unsupported step size"); @@ -103,10 +122,19 @@ struct smem_t { static_assert(step_size == 2, "Unsupported step size"); return (offset ^ 0x2) + (step_idx % 2 == 1) * 4; } + + #endif } template static __device__ __forceinline__ uint32_t advance_offset_by_row(uint32_t offset) { + #ifdef USE_ROCM + + // TODO(yiakwy) : add swizzle mode + return offset + step_size * row_stride; + + #else + if constexpr (swizzle_mode == SwizzleMode::k128B) { static_assert(step_size == 4 || step_size % 8 == 0, "Unsupported step size"); if constexpr (step_size == 4) { @@ -124,6 +152,8 @@ struct smem_t { return offset + step_size * row_stride; } } + + #endif // USE_ROCM } __device__ __forceinline__ void ldmatrix_m8n8x4(uint32_t offset, uint32_t* R) { diff --git a/include/flashinfer/utils.cuh b/include/flashinfer/utils.cuh index 1900f38d..f183d9d1 100644 --- a/include/flashinfer/utils.cuh +++ b/include/flashinfer/utils.cuh @@ -278,7 +278,11 @@ inline std::pair GetCudaComputeCapability() { int device_id = 0; hipGetDevice(&device_id); int major = 0, minor = 0; - hipDeviceComputeCapability(&major, &minor, device_id); + hipError_t err = hipDeviceComputeCapability(&major, &minor, device_id); + if(err != hipSuccess) + { + throw std::runtime_error("hip_api_call"); + } return std::make_pair(major, minor); } diff --git a/include/flashinfer/vec_dtypes.cuh b/include/flashinfer/vec_dtypes.cuh index 3f53639c..91e1319f 100644 --- a/include/flashinfer/vec_dtypes.cuh +++ b/include/flashinfer/vec_dtypes.cuh @@ -216,7 +216,7 @@ __device__ void fast_dequant_f8f16x4(uint32_t* input, uint2* output) { *(half2*)&(output->y) = __hmul2(*reinterpret_cast(&Out2), bias_reg); } else { constexpr uint32_t BIAS = (BIAS_OFFSET + 127) << 23; - const nv_bfloat162 bias_reg = __float22bfloat162_rn(*reinterpret_cast(&BIAS)); + const nv_bfloat162 bias_reg = __float2bfloat162_rn(*reinterpret_cast(&BIAS)); // Convert to bfloat162 and apply bias *(nv_bfloat162*)&(output->x) = __hmul2(*reinterpret_cast(&Out1), bias_reg); diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 08afb71b..2d603898 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -13,6 +13,8 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#include // isnan used + #include #include @@ -34,6 +36,10 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu std::vector v(kv_len * num_kv_heads * head_dim); std::vector o(qo_len * num_qo_heads * head_dim); + // TODO (yiakwy) : we will do a simple test + // q = torch.ones((H=1, N_CTX=2,D_HEAD=64), dtype=torch.float16, device="cuda", requires_grad=False) // kv_layout=1 + // k = q, v = q + // p = torch.matmul(q, k.transpose(1, 2)) // 2 x 2 matrix p[i][j] = 64 utils::vec_normal_(q); utils::vec_normal_(k); utils::vec_normal_(v); @@ -44,7 +50,7 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu thrust::device_vector v_d(v); thrust::device_vector o_d(o); thrust::device_vector tmp_d(16 * 1024 * 1024); - + cudaError_t status = flashinfer::SinglePrefillWithKVCache( thrust::raw_pointer_cast(q_d.data()), thrust::raw_pointer_cast(k_d.data()), thrust::raw_pointer_cast(v_d.data()), thrust::raw_pointer_cast(o_d.data()), @@ -85,13 +91,13 @@ void _TestSinglePrefillKernelCorrectness(size_t qo_len, size_t kv_len, size_t nu template void TestSinglePrefillKernelLongContextCorrectness(bool allow_fp16_qk_reduction) { - for (size_t qo_len : {1, 31, 63, 127}) { + for (size_t qo_len : {1}) { // for (size_t qo_len : {1, 31, 63, 127}) { for (size_t kv_len : {31717}) { for (size_t num_heads : {1}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool causal : {false, true}) { - for (size_t pos_encoding_mode : {0, 1}) { - for (size_t kv_layout : {0, 1}) { + for (size_t head_dim : {64}) { // for (size_t head_dim : {64, 128, 256}) { + for (bool causal : {false}) { // for (bool causal : {false, true}) { + for (size_t pos_encoding_mode : {0}) { // for (size_t pos_encoding_mode : {0, 1}) { + for (size_t kv_layout : {0}) {// for (size_t kv_layout : {0, 1}) { _TestSinglePrefillKernelCorrectness( qo_len, kv_len, num_heads, num_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), allow_fp16_qk_reduction); @@ -129,13 +135,13 @@ template void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) { float rtol = std::is_same::value ? 1e-2 : 1e-3; float atol = std::is_same::value ? 1e-2 : 1e-3; - for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { - for (size_t num_qo_heads : {32}) { - for (size_t num_kv_heads : {4, 8, 32}) { - for (size_t head_dim : {64, 128, 256}) { - for (bool causal : {false, true}) { - for (size_t pos_encoding_mode : {0, 1}) { - for (size_t kv_layout : {0, 1}) { + for (size_t qkv_len : {16}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + for (size_t num_qo_heads : {1}) { // for (size_t num_qo_heads : {32}) { + for (size_t num_kv_heads : {1}) { // for (size_t num_kv_heads : {4, 8, 32}) { + for (size_t head_dim : {64}) { // for (size_t head_dim : {64, 128, 256}) { + for (bool causal : {false}) { // for (bool causal : {false, true}) { + for (size_t pos_encoding_mode : {0}) {// for (size_t pos_encoding_mode : {0, 1}) { + for (size_t kv_layout : {1}) { // for (size_t kv_layout : {0, 1}) { _TestSinglePrefillKernelCorrectness( qkv_len, qkv_len, num_qo_heads, num_kv_heads, head_dim, causal, QKVLayout(kv_layout), PosEncodingMode(pos_encoding_mode), @@ -215,6 +221,7 @@ void TestSinglePrefillFP8KernelCorrectness(bool allow_fp16_qk_reduction) { } } +/* TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16) { TestSinglePrefillKernelLongContextCorrectness(false); } @@ -222,11 +229,13 @@ TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP1 TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessFP16QKHalfAccum) { TestSinglePrefillKernelLongContextCorrectness(true); } +*/ TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16) { TestSinglePrefillKernelShortContextCorrectness(false); } +/* TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelShortContextCorrectnessFP16QKHalfAccum) { TestSinglePrefillKernelShortContextCorrectness(true); } @@ -238,6 +247,7 @@ TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16) { TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelCorrectnessTestFP16QKHalfAccum) { TestSinglePrefillKernelCorrectness(true); } +*/ #ifdef FLASHINFER_ENABLE_BF16 TEST(FlashInferCorrectnessTest, TestSinglePrefillKernelLongContextCorrectnessBF16) { diff --git a/src/utils.h b/src/utils.h index 015808b7..6ad480c2 100644 --- a/src/utils.h +++ b/src/utils.h @@ -81,7 +81,8 @@ void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { std::mt19937 gen{rd()}; std::normal_distribution d{mean, std}; for (size_t i = 0; i < vec.size(); ++i) { - vec[i] = T(d(gen)); + // TODO (yiakwy) : RECOVER + vec[i] = T(1.f);//T(d(gen)); } } From 3d2d75d5c9eb772b27721c9c938b04cca806b6a2 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Fri, 11 Oct 2024 11:54:55 +0000 Subject: [PATCH 10/15] update compute_qk fragment layout --- include/flashinfer/attention/prefill.cuh | 148 ++++++++++++++--------- src/test_single_prefill.cu | 2 +- 2 files changed, 90 insertions(+), 60 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index c779ef37..629a58df 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -344,11 +344,11 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory // qsmem shape = (_, 128 Byte) - // -- frags x -> (but loaded into SMEM the next 16 rows) + // -- frags y -> // qsmem row/col 0 1 ... 7 warp_idx {0..3} // 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 60 61 62 63 0 | // 1 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 ... 124 125 126 127 0 | - // 2 . . . . . . . . . . . . . . . . ... . . . . 0 frags y + // 2 . . . . . . . . . . . . . . . . ... . . . . 0 frags x // 3 . . . . . . . . . . . . . . . . ... . . . . 0 | // ... . . . . . . . . . . . . . . . . ... . . . . 0 | // 0+4*3 . . . . . . . . . . . . . . . . ... . . . . 0 v @@ -589,10 +589,11 @@ __device__ __forceinline__ void compute_qk( } // NOTE(yiakwy) : each thread of 64=16x4 threads block, cooperatively loads 4 x consecutive fp16/bf16 data to cover 16x16 matrix frag - uint32_t a_frag[num_frags_x][num_frags_y][2]; - uint32_t b_frag[num_frags_x][num_frags_z][2]; + uint32_t a_frag[num_frags_x][2]; + uint32_t b_frag[2]; // hence + // TODO (yiakwy) : if we change blckDim.x from 32 to 64 uint32_t lane_id = threadIdx.x + threadIdx.z * 32; uint32_t lane_id_x = lane_id % 16; @@ -605,7 +606,8 @@ __device__ __forceinline__ void compute_qk( using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; - #define MTX_FRAG_LDA 64 + #define MTX_FRAG_LDA (head_dim) + #define MTX_FRAG_LDB (num_frags_z * 16) #else @@ -618,82 +620,90 @@ __device__ __forceinline__ void compute_qk( #ifdef USE_ROCM #pragma unroll - for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + + // load q +#pragma unroll + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - // TODO (yiakwy) : check if (lane_id >= 64) { continue; } - // load q -#pragma unroll - for (uint32_t fy = 0; fy < num_frags_y; ++fy) { - - // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim, then do sum - b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; - float16_t *s = reinterpret_cast(smem_ptr); + // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim + b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + float16_t *s = reinterpret_cast(smem_ptr); - float16x4 *a = reinterpret_cast(a_frag[fx][fy]); - - float ref_0 = (float)(*(s+(threadIdx.x / 8 + fy * 4) * 64 + threadIdx.x % 8 )); - float ref_1 = (float)(*(s+(threadIdx.x / 8 + fy * 4) * 64 + threadIdx.x % 8 + 7)); - printf("[compute_qk] s[%d, %d]=%f..%f\n", threadIdx.x / 8 + fy * 4, threadIdx.x % 8, ref_0, ref_1); + float16x4 *a = reinterpret_cast(a_frag[fx][fy]); - // TODO (yiakwy) : replaced with more efficient load instruction -#pragma unroll - for (uint32_t j=0; j < 4; j++) { - // NOTE (yiakwy) : loads 1 columns of data - uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4 + fy * 16; - s += offset; + #ifdef DEBUG + if (lane_id < 32) { + uint32_t nv_lane_id = threadIdx.x; + uint32_t nv_mtx_frag_thread_load_row_offset = nv_lane_id % 16 * 16; + uint32_t nv_mtx_frag_thread_load_col_offset = nv_lane_id / 16 * 8 + fy * 16; + uint32_t nv_mtx_frag_thread_load_offset = nv_mtx_frag_thread_load_row_offset * 64 + nv_mtx_frag_thread_load_col_offset; - (*a)[j] = *(s); + float ref_0 = (float)(*(s+ nv_mtx_frag_thread_load_offset)); + float ref_1 = (float)(*(s+ nv_mtx_frag_thread_load_offset + 8)); - // TODO (yiakwy) : REMOVE - if (fx==0 && fy== 0) { - // printf("[compute_qk] (fy=%d, lane_id_x=%d, lane_id_y=%d, j=%d), [compute_qk] a_frag[fx=%d][fy=%d][j=%d]=%f, s[%d]=%f\n", fy, lane_id_x, lane_id_y, j, fx, fy, j, (float)((*a)[j]), offset, (float)(*s)); - } - } + printf("[compute_qk] s[%d, %d]=%f..%f\n", nv_mtx_frag_thread_load_row_offset, nv_mtx_frag_thread_load_col_offset, ref_0, ref_1); } - *q_smem_offset_r = - q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); + #endif - // load k + // TODO (yiakwy) : replaced with more efficient load instruction #pragma unroll - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + for (uint32_t j=0; j < 4; j++) { + // NOTE (yiakwy) : loads 1 columns (16xfp16) of data + uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4; + s += offset; - if constexpr (sizeof(DTypeKV) == 1) { - assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); - } + (*a)[j] = *(s); + } + } // num_frags_x + + // NOTE(yiakwy) : next to 16 = 2x8 columns + *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) - + num_frags_x * 16 * channel_size_128b_q; + + // load k +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + + if (lane_id >= 64) { + continue; + } - b128_t* smem_ptr = k_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; - float16_t *s = reinterpret_cast(smem_ptr); + if constexpr (sizeof(DTypeKV) == 1) { + assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); + } - float16x4 *b = reinterpret_cast(b_frag[fx][fz]); + b128_t* smem_ptr = k_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + float16_t *s = reinterpret_cast(smem_ptr); - // TODO (yiakwy) : replaced with more efficient load instruction -#pragma unroll - for (int j=0; j < 4; j++) { - // NOTE (yiakwy) : loads 16 consecutive data of 1 row - s += lane_id_x + lane_id_y * MTX_FRAG_LDA * 4 + j * MTX_FRAG_LDA + fz * MTX_FRAG_LDA * 16; + float16x4 *b = reinterpret_cast(b_frag); - (*b)[j] = *s; - // TODO (yiakwy) : REMOVE - if (fx==0 && fz== 0) { - // printf("[compute_qk] (%d, %d, %d), [compute_qk] b_frag[%d][%d][%d]=%f\n", lane_id_x, j, lane_id_y, fx, fz, j, (float)((*b)[j])); - } - } + // TODO (yiakwy) : replaced with more efficient load instruction +#pragma unroll + for (int j=0; j < 4; j++) { + // NOTE (yiakwy) : loads 16 consecutive data of 1 row + s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB + fz * 16; - *k_smem_offset_r = - k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); + (*b)[j] = *s; + // TODO (yiakwy) : REMOVE + if (fy==0 && fz== 0) { + printf("[compute_qk] (lane_id_x=%d, lane_id_y=%d, j=%d), [compute_qk] b_frag[%d]=%f\n", lane_id_x, lane_id_y, j, j, (float)((*b)[j])); + } } + *k_smem_offset_r = + k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); + // compute - assert( num_frags_y == num_frags_z && "num_frags_y is not equal to num_frags_z"); - for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - float16x4 *a = reinterpret_cast(a_frag[fx][fz]); - float16x4 *b = reinterpret_cast(b_frag[fx][fz]); + for (uint32_t fx = 0; fx < num_frags_x; ++fx) { + float16x4 *a = reinterpret_cast(a_frag[fx]); + float16x4 *b = reinterpret_cast(b_frag); floatx4 *d = reinterpret_cast(s_frag[fx][fz]); @@ -705,12 +715,31 @@ __device__ __forceinline__ void compute_qk( // *d = __builtin_amdgcn_mfma_f16_16x16x16f16(*a, *b, *d, 0, 0, 0); } } + } + if constexpr (sizeof(DTypeKV) == 1) { + assert(0 && "FP8 KV Cache will be suppported soon."); + } else { + *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fy) - + num_frags_z * 16 * channel_size_128b_kv; + } } #else #pragma unroll - // NOTE(yiakwy) each thead read 2 elments and repeat 4 times (num_frags_y), threads cooperatively loads 16x64 + // NOTE(yiakwy) each thead read 2 elments and repeat 4xnum_frags_y times , threads cooperatively loads 16x64 + // + // frag_a: + // Dtype=fp16/bf16 + // cols 0 .. 15 16 .. 31 32 63 + // frag_x\frag_y rows 0 1 2 .. 4 + // 0 0 + // .. + // 15 + // 1 16 + // + //frag_b + // for (uint32_t fy = 0; fy < num_frags_y; ++fy) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -721,6 +750,7 @@ __device__ __forceinline__ void compute_qk( q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); } + // NOTE(yiakwy) : next to 16 = 2x8 columns *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) - num_frags_x * 16 * channel_size_128b_q; diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 2d603898..24520acd 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -135,7 +135,7 @@ template void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) { float rtol = std::is_same::value ? 1e-2 : 1e-3; float atol = std::is_same::value ? 1e-2 : 1e-3; - for (size_t qkv_len : {16}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + for (size_t qkv_len : {4}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {1}) { // for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {1}) { // for (size_t num_kv_heads : {4, 8, 32}) { for (size_t head_dim : {64}) { // for (size_t head_dim : {64, 128, 256}) { From 6d583d85ceb144071445986db0ab62ccde53b7ac Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Fri, 11 Oct 2024 13:41:26 +0000 Subject: [PATCH 11/15] update qk_compute layout and comments --- include/flashinfer/attention/prefill.cuh | 51 +++++++++++++----------- src/test_single_prefill.cu | 2 +- 2 files changed, 29 insertions(+), 24 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 629a58df..911df5f1 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -16,6 +16,8 @@ #ifndef FLASHINFER_PREFILL_CUH_ #define FLASHINFER_PREFILL_CUH_ +#include + #ifdef USE_ROCM #include @@ -373,13 +375,16 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { // load q fragment from gmem to smem - // NOTE (yiakwy) : qsmem[warp_idx_x * num_frags_x * 16 + lane_idx / 8 + j * 4, lane_idx % 8] = q[bz/*head*/, get_warp_idx_x<1, 4>() * 16 + lane_idx / 8 + j * 4/*seqlen*/, 0/*hdim*/] + (lane_idx % 8) * 8 - q_smem->template load_128b_async(q_smem_offset_w, q_ptr, - q_idx < qo_upper_bound); + // NOTE (yiakwy) : qsmem[warp_idx_x * num_frags_x * 16 + lane_idx / 8 + j * 4, lane_idx % 8] = q[bz/*head*/, warp_id_x * 16 + lane_idx / 8 + j * 4/*seqlen*/, 0/*hdim*/] + (lane_idx % 8) * 8 + if (qo_upper_bound >= 16) { + q_smem->template load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); + } else { + q_smem->template load_128b_async(q_smem_offset_w, q_ptr, q_idx < qo_upper_bound); + } - b128_t* smem_ptr = q_smem->base + (lane_idx / 8 + j * 4) * 8 + lane_idx % 8; - float16_t *s = reinterpret_cast(smem_ptr); // #ifdef DEBUG + b128_t* smem_ptr = q_smem->base + (lane_idx / 8 + fx * 16 + j * 4 ) * 8 + lane_idx % 8; + float16_t *s = reinterpret_cast(smem_ptr); printf("[load q from global] (x=%d,z=%d,j=%d), q_smem[%d, %d](%f..%f) = q[H=%d,N_CTX=%d, %d](%f..%f)\n", threadIdx.x, threadIdx.z, j, lane_idx / 8 + j * 4, lane_idx % 8, (float)(*(s)), (float)(*(s+7)), 0, lane_idx / 8 + j * 4, (lane_idx % 8) * 8, (float)q_ptr[0], (float)q_ptr[7]); // #endif @@ -594,7 +599,7 @@ __device__ __forceinline__ void compute_qk( // hence // TODO (yiakwy) : if we change blckDim.x from 32 to 64 - uint32_t lane_id = threadIdx.x + threadIdx.z * 32; + uint32_t lane_id = threadIdx.x + threadIdx.z * blockDim.x; uint32_t lane_id_x = lane_id % 16; uint32_t lane_id_y = lane_id / 16; @@ -619,6 +624,10 @@ __device__ __forceinline__ void compute_qk( // compute q*k^T #ifdef USE_ROCM + if (lane_id > 64) { + return; + } + #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { @@ -626,10 +635,6 @@ __device__ __forceinline__ void compute_qk( #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { - if (lane_id >= 64) { - continue; - } - // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; float16_t *s = reinterpret_cast(smem_ptr); @@ -648,7 +653,6 @@ __device__ __forceinline__ void compute_qk( printf("[compute_qk] s[%d, %d]=%f..%f\n", nv_mtx_frag_thread_load_row_offset, nv_mtx_frag_thread_load_col_offset, ref_0, ref_1); } - #endif // TODO (yiakwy) : replaced with more efficient load instruction @@ -660,6 +664,9 @@ __device__ __forceinline__ void compute_qk( (*a)[j] = *(s); } + + *q_smem_offset_r = + q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r); } // num_frags_x // NOTE(yiakwy) : next to 16 = 2x8 columns @@ -670,10 +677,6 @@ __device__ __forceinline__ void compute_qk( #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { - if (lane_id >= 64) { - continue; - } - if constexpr (sizeof(DTypeKV) == 1) { assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); } @@ -687,14 +690,9 @@ __device__ __forceinline__ void compute_qk( #pragma unroll for (int j=0; j < 4; j++) { // NOTE (yiakwy) : loads 16 consecutive data of 1 row - s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB + fz * 16; + s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB; (*b)[j] = *s; - - // TODO (yiakwy) : REMOVE - if (fy==0 && fz== 0) { - printf("[compute_qk] (lane_id_x=%d, lane_id_y=%d, j=%d), [compute_qk] b_frag[%d]=%f\n", lane_id_x, lane_id_y, j, j, (float)((*b)[j])); - } } *k_smem_offset_r = @@ -731,8 +729,8 @@ __device__ __forceinline__ void compute_qk( // // frag_a: // Dtype=fp16/bf16 - // cols 0 .. 15 16 .. 31 32 63 - // frag_x\frag_y rows 0 1 2 .. 4 + // cols 0 .. 15 16 .. 31 32 .. 63 + // frag_x\frag_y rows 0 1 2 .. 3 // 0 0 // .. // 15 @@ -740,6 +738,13 @@ __device__ __forceinline__ void compute_qk( // //frag_b // + // cols 0 .. 15 16 .. 31 32 .. 63 + // frag_z\frag_y rows 0 1 .. 2 .. 3 + // 0 0 + // .. + // 15 + // 1 16 + // for (uint32_t fy = 0; fy < num_frags_y; ++fy) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 24520acd..5640e59f 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -135,7 +135,7 @@ template void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) { float rtol = std::is_same::value ? 1e-2 : 1e-3; float atol = std::is_same::value ? 1e-2 : 1e-3; - for (size_t qkv_len : {4}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + for (size_t qkv_len : {2}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {1}) { // for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {1}) { // for (size_t num_kv_heads : {4, 8, 32}) { for (size_t head_dim : {64}) { // for (size_t head_dim : {64, 128, 256}) { From 4a16a90a55e0c0558371dc4270320ed2a6e3b9bd Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 14 Oct 2024 16:50:26 +0000 Subject: [PATCH 12/15] update kv fragment thread mapping --- include/flashinfer/attention/prefill.cuh | 167 +++++++++++++++++++---- src/test_single_prefill.cu | 2 +- 2 files changed, 140 insertions(+), 29 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 911df5f1..be382189 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -205,23 +205,52 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { + // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory + // kvsmem shape =(num_frags_z x 16, (num_frags_y / 4) * 64) + // -- num_frags_y --> + // kvsmem warps row/col 0 1 ... 7 + // 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 60 61 62 63 + // 0 0 + // 1 + // 2 + // 3 + // 1 0+4*1 + // .. .. + // 3 0+4*3 + // 1+4*3 + // 2+4*3 + // 3+4*3 + // uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps static_assert(num_frags_z * 4 % num_warps_x == 0); + + // NOTE (yiakwy) : for kv = (1/*head*/, 16/*seq*/, 64), at least 128 rows will be loaded #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x && ((warp_idx * 4 + lane_idx / 8 + i * 16) < kv_len); ++i) { // for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { #pragma unroll for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(T)); ++j) { + + // NOTE (yiakwy) : kvsmem[warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] = kv[0, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); + + b128_t* smem_ptr = smem.base + (warp_idx * 4 + lane_idx / 8 + i * 16) * 8 + lane_idx % 8 + j * 8; + float16_t *s = reinterpret_cast(smem_ptr); + printf("[produce_kv] (i=%d,j=%d,warp_idx=%d), kv_smem[%d, %d]=kv[H=0, N_CTX=%d/%d, %d](%f..%f)\n", i, j, warp_idx, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8, warp_idx * 4 + lane_idx / 8 + i * 16, kv_len, lane_idx % 8 + j * 8, (float)(*s), (float)(*(s+8))); + *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); } kv_idx += num_warps * 4; + + // NOTE (yiakwy) : reset columns offset, ahead to next 16 rows *smem_offset = smem.template advance_offset_by_row(*smem_offset) - sizeof(T) * num_frags_y; + // NOTE (yiakwy) : reset columns offset, ahead to next 16 rows *gptr += num_warps * 4 * kv_stride_n - sizeof(T) * num_frags_y * num_elems_per_128b(); } + // NOTE (yiakwy) : reset kv smem pointer *smem_offset -= num_warps_z * num_frags_z * 16 * channel_size_128b_kv; } else { uint32_t kv_idx = kv_idx_base + warp_idx * 8 + lane_idx / 4; @@ -344,8 +373,10 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, if (get_warp_idx_z() == 0) { + // TODO (yiakwy) : only half a warp concurrency if blockDim.x == 32 in ROCm platform + // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory - // qsmem shape = (_, 128 Byte) + // qsmem shape = (_, (num_frags_y / 4) * 64 /*hidden_size*/) // -- frags y -> // qsmem row/col 0 1 ... 7 warp_idx {0..3} // 0 0 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 ... 60 61 62 63 0 | @@ -598,13 +629,17 @@ __device__ __forceinline__ void compute_qk( uint32_t b_frag[2]; // hence - // TODO (yiakwy) : if we change blckDim.x from 32 to 64 + // TODO (yiakwy) : what if we change blckDim.x from 32 to 64 uint32_t lane_id = threadIdx.x + threadIdx.z * blockDim.x; uint32_t lane_id_x = lane_id % 16; uint32_t lane_id_y = lane_id / 16; + // TODO (yiakwy) : replace these variables later uint32_t warp_idx_x = get_warp_idx_x<1, 4>(); + uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); + + uint32_t warp_idx = get_warp_idx<1, 4>(); using float16_t = rocwmma::float16_t; @@ -619,14 +654,12 @@ __device__ __forceinline__ void compute_qk( // NOTE(yiakwy) : each thread of 32=8x4 threads block, cooperatively loads 2 x fp16/bf16 data, and repeat 4 (x4) times in 4 warps to cover 16x16 matrix frag uint32_t a_frag[num_frags_x][4], b_frag[4]; - #endif + #endif // USE_ROCM // compute q*k^T #ifdef USE_ROCM - if (lane_id > 64) { - return; - } + if (lane_id < 64U) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { @@ -639,30 +672,20 @@ __device__ __forceinline__ void compute_qk( b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; float16_t *s = reinterpret_cast(smem_ptr); - float16x4 *a = reinterpret_cast(a_frag[fx][fy]); - - #ifdef DEBUG - if (lane_id < 32) { - uint32_t nv_lane_id = threadIdx.x; - uint32_t nv_mtx_frag_thread_load_row_offset = nv_lane_id % 16 * 16; - uint32_t nv_mtx_frag_thread_load_col_offset = nv_lane_id / 16 * 8 + fy * 16; - uint32_t nv_mtx_frag_thread_load_offset = nv_mtx_frag_thread_load_row_offset * 64 + nv_mtx_frag_thread_load_col_offset; - - float ref_0 = (float)(*(s+ nv_mtx_frag_thread_load_offset)); - float ref_1 = (float)(*(s+ nv_mtx_frag_thread_load_offset + 8)); - - printf("[compute_qk] s[%d, %d]=%f..%f\n", nv_mtx_frag_thread_load_row_offset, nv_mtx_frag_thread_load_col_offset, ref_0, ref_1); - } - #endif + float16x4 *a = reinterpret_cast(a_frag[fx]); // TODO (yiakwy) : replaced with more efficient load instruction #pragma unroll for (uint32_t j=0; j < 4; j++) { - // NOTE (yiakwy) : loads 1 columns (16xfp16) of data + // NOTE (yiakwy) : 16 threads loads 4 columns (16x4fp16) of data cooperatively uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4; s += offset; (*a)[j] = *(s); + + if (j==0) { + printf("[compute_qk] (fx=%d, fy=%d) a_frag[%d][%d]=%f\n", fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j])); + } } *q_smem_offset_r = @@ -681,20 +704,25 @@ __device__ __forceinline__ void compute_qk( assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); } - b128_t* smem_ptr = k_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + b128_t* smem_ptr = k_smem->base + (warp_idx * 4) * channel_size_128b_q; float16_t *s = reinterpret_cast(smem_ptr); float16x4 *b = reinterpret_cast(b_frag); - // TODO (yiakwy) : replaced with more efficient load instruction + // TODO (yiakwy) : replaced with more efficient load inst #pragma unroll for (int j=0; j < 4; j++) { // NOTE (yiakwy) : loads 16 consecutive data of 1 row s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB; (*b)[j] = *s; + + if (j==0) { + printf("[compute_qk] (fz=%d, fy=%d) b_frag[%d][%d]=%f\n", fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); + } } + // NOTE(yiakwy) : k is still in row-major layout *k_smem_offset_r = k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r); @@ -707,6 +735,15 @@ __device__ __forceinline__ void compute_qk( if constexpr (std::is_same::value) { *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0); + + __asm__ volatile("s_barrier" ::); + if (fx == 0 && fy == 0 && fz == 0) { + + for (int reg_id=0; reg_id < 4; reg_id++) { + printf("[compute_qk] s_frag[fx=%d][fy=0][fz=%d][%d, %d] = %f\n", fx, fz, reg_id * 16 + lane_id_y * 4, lane_id_x, (*d)[reg_id]); + } + + } } else { // TODO (yiakwy) : device cast fp32 to fp16 assert(0 && "AMD v_mfma instruction does not support fp16 output."); @@ -722,10 +759,11 @@ __device__ __forceinline__ void compute_qk( } } + } // if lane_id < 64 #else #pragma unroll - // NOTE(yiakwy) each thead read 2 elments and repeat 4xnum_frags_y times , threads cooperatively loads 16x64 + // NOTE(yiakwy) each thead read 2 elments and repeat 4xnum_frags_y times , threads cooperatively loads 16x64 fp16 elements // // frag_a: // Dtype=fp16/bf16 @@ -737,7 +775,7 @@ __device__ __forceinline__ void compute_qk( // 1 16 // //frag_b - // + //Dtype=fp16/bf16 // cols 0 .. 15 16 .. 31 32 .. 63 // frag_z\frag_y rows 0 1 .. 2 .. 3 // 0 0 @@ -818,11 +856,23 @@ __device__ __forceinline__ void compute_qk( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + +#ifdef USE_ROCM + +#pragma unroll + for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { + s_frag[fx][fz][reg_id] = + apply_logits_post_hook(s_frag[fx][fz][reg_id], soft_cap); + } + +#else #pragma unroll for (uint32_t reg_id = 0; reg_id < 8; ++reg_id) { s_frag[fx][fz][reg_id] = apply_logits_post_hook(s_frag[fx][fz][reg_id], soft_cap); } +#endif // USE_ROCM + } } } else { @@ -831,11 +881,22 @@ __device__ __forceinline__ void compute_qk( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + +#ifdef USE_ROCM + + for (uint32_t reg_id = 0; reg_id < 2; ++reg_id) { + *(half2*)(&s_frag[fx][fz][reg_id * 2]) = apply_logits_post_hook( + *(half2*)(&s_frag[fx][fz][reg_id * 2]), soft_cap); + } + +#else #pragma unroll for (uint32_t reg_id = 0; reg_id < 4; ++reg_id) { *(half2*)(&s_frag[fx][fz][reg_id * 2]) = apply_logits_post_hook( *(half2*)(&s_frag[fx][fz][reg_id * 2]), soft_cap); } +#endif + } } } @@ -1401,10 +1462,22 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC smem_t qo_smem(smem); // TODO (yiakwy) : to be used by load_q_global_smem, double check to compute offset of q + #ifdef USE_ROCM + DTypeQ* q_ptr_base = + q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, 0/*threads related offset computed in function blocks*/); + #else DTypeQ* q_ptr_base = q + qkv_info.get_q_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); + #endif // USE_ROCM-q + #ifdef USE_ROCM + DTypeOut* o_ptr_base = + partition_kv + ? o + chunk_idx * num_qo_heads * head_dim + + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, 0) + : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, 0); + #else DTypeOut* o_ptr_base = partition_kv ? o + chunk_idx * num_qo_heads * head_dim + @@ -1412,9 +1485,18 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC (lane_idx % 8) * num_elems_per_128b()) : o + qkv_info.get_o_elem_offset(0, kv_head_idx * group_size, (lane_idx % 8) * num_elems_per_128b()); + #endif // USE_ROCM-o + + // TODO (yiakwy) : refactor + #ifdef USE_ROCM + // used by compute_qk for reading smem + uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( + get_warp_idx_x() * num_frags_x * 16, 0); + #else uint32_t q_smem_offset_r = qo_smem.template get_permuted_offset( get_warp_idx_x() * num_frags_x * 16 + lane_idx % 16, - lane_idx / 16); + lane_idx / 16); + #endif if (threadIdx.x == 0 && threadIdx.z == 0) { printf("[prefill kernel] channel_size_128b_q = %d\n", channel_size_128b_q); @@ -1488,6 +1570,30 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC : chunk_size) / (16 * num_warps_z * num_frags_z); + // TODO (yiakwy) : refactor + #ifdef USE_ROCM + + DTypeKV* k_ptr = + k + qkv_info.get_kv_elem_offset( + chunk_start + warp_idx * kv_frag_rows + 0/* nvgpu : (lane_idx / 8) */, kv_head_idx, + 0/* nvgpu : (lane_idx % 8 ) * 8 */); + DTypeKV* v_ptr = + v + qkv_info.get_kv_elem_offset( + chunk_start + warp_idx * kv_frag_rows + 0, kv_head_idx, + 0); + // NOTE (yiakwy) : _w is used for storing (produce_kv) and _r is used for reading (compute_qk) + // NOTE (yiakwy) : We reuse NV GPU uint4 loading layout + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( + get_warp_idx_z() * num_frags_z * 16 + + 0/* nvgpu : (lane_idx / 16) * 8 + lane_idx % 8 */, + 0/* nvgpu : (lane_idx % 16) / 8) => {0, 1}*/), + v_smem_offset_r = v_smem.template get_permuted_offset( + get_warp_idx_z() * num_frags_z * 16 + 0 /* nvgpu : lane_idx % 16 => {0..15} */, + 0/* nvgpu : lane_idx / 16 => {0, 1} */), + kv_smem_offset_w = k_smem.template get_permuted_offset( + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); + + #else DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, @@ -1505,6 +1611,8 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC lane_idx / 16), kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); + #endif // USE_ROCM + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); @@ -1519,11 +1627,14 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC block.sync(); if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { + // TODO (yiakwy) : recover + /* k_smem_inplace_apply_rotary( chunk_start + iter * 16 * num_warps_z * num_frags_z, &k_smem, &k_smem_offset_r, rope_freq); block.sync(); + */ } // compute attention score diff --git a/src/test_single_prefill.cu b/src/test_single_prefill.cu index 5640e59f..2d603898 100644 --- a/src/test_single_prefill.cu +++ b/src/test_single_prefill.cu @@ -135,7 +135,7 @@ template void TestSinglePrefillKernelShortContextCorrectness(bool allow_fp16_qk_reduction) { float rtol = std::is_same::value ? 1e-2 : 1e-3; float atol = std::is_same::value ? 1e-2 : 1e-3; - for (size_t qkv_len : {2}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { + for (size_t qkv_len : {16}) { // for (size_t qkv_len : {2, 3, 5, 7, 11, 13, 17, 19, 23, 29, 31, 37}) { for (size_t num_qo_heads : {1}) { // for (size_t num_qo_heads : {32}) { for (size_t num_kv_heads : {1}) { // for (size_t num_kv_heads : {4, 8, 32}) { for (size_t head_dim : {64}) { // for (size_t head_dim : {64, 128, 256}) { From 61dfb1b7d39d2ab321bd88e706d38526e3a38aa8 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Tue, 15 Oct 2024 10:20:37 +0000 Subject: [PATCH 13/15] update compute_qk logging info --- include/flashinfer/attention/prefill.cuh | 60 ++++++++++++++++++------ 1 file changed, 45 insertions(+), 15 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index be382189..2984850f 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -234,9 +234,12 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* // NOTE (yiakwy) : kvsmem[warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] = kv[0, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); - b128_t* smem_ptr = smem.base + (warp_idx * 4 + lane_idx / 8 + i * 16) * 8 + lane_idx % 8 + j * 8; + T* kv_base_r = *gptr; + T* kv_ptr = kv_base_r + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * kv_stride_n + /* lane_idx % 8 */ + j * 8; + + b128_t* smem_ptr = smem.base + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * 8 + /* lane_idx % 8 */ + j * 8; float16_t *s = reinterpret_cast(smem_ptr); - printf("[produce_kv] (i=%d,j=%d,warp_idx=%d), kv_smem[%d, %d]=kv[H=0, N_CTX=%d/%d, %d](%f..%f)\n", i, j, warp_idx, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8, warp_idx * 4 + lane_idx / 8 + i * 16, kv_len, lane_idx % 8 + j * 8, (float)(*s), (float)(*(s+8))); + printf("[produce_kv] (i=%d,j=%d,warp_idx=%d), kv_smem[%d, %d] (%f..%f) = kv[H=0, N_CTX=%d/%d, %d](%f..%f,%f)\n", i, j, warp_idx, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8, (float)(*s), (float)(*(s+7)), warp_idx * 4 + lane_idx / 8 + i * 16, kv_len, lane_idx % 8 + j * 8, (float)(*kv_ptr), (float)(*(kv_ptr+6)), (float)(*(kv_ptr+7))); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -647,6 +650,7 @@ __device__ __forceinline__ void compute_qk( using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; #define MTX_FRAG_LDA (head_dim) + // #define MTX_FRAG_LDB (head_dim) #define MTX_FRAG_LDB (num_frags_z * 16) #else @@ -683,8 +687,8 @@ __device__ __forceinline__ void compute_qk( (*a)[j] = *(s); - if (j==0) { - printf("[compute_qk] (fx=%d, fy=%d) a_frag[%d][%d]=%f\n", fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j])); + if (fx == 0 && fy == 0) { + printf("[compute_qk] (x=%d, y=%d, j=%d) (fx=%d, fy=%d) a_mtx_frag[%d, %d]=%f\n", lane_id_x, lane_id_y, j, fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j])); } } @@ -711,14 +715,15 @@ __device__ __forceinline__ void compute_qk( // TODO (yiakwy) : replaced with more efficient load inst #pragma unroll - for (int j=0; j < 4; j++) { + for (uint32_t j=0; j < 4; j++) { // NOTE (yiakwy) : loads 16 consecutive data of 1 row s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB; + // s += lane_id_x * MTX_FRAG_LDB + j + lane_id_y * 4; (*b)[j] = *s; - if (j==0) { - printf("[compute_qk] (fz=%d, fy=%d) b_frag[%d][%d]=%f\n", fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); + if (fy == 0 && fz == 0) { + printf("[compute_qk] (x=%d, y=%d, j=%d) (fz=%d, fy=%d) b_mtx_frag[%d, %d]=%f\n", lane_id_x, lane_id_y, j, fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); } } @@ -731,23 +736,21 @@ __device__ __forceinline__ void compute_qk( float16x4 *a = reinterpret_cast(a_frag[fx]); float16x4 *b = reinterpret_cast(b_frag); - floatx4 *d = reinterpret_cast(s_frag[fx][fz]); - if constexpr (std::is_same::value) { + floatx4 *d = reinterpret_cast(s_frag[fx][fz]); *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0); - __asm__ volatile("s_barrier" ::); + if (fx == 0 && fy == 0 && fz == 0) { - for (int reg_id=0; reg_id < 4; reg_id++) { - printf("[compute_qk] s_frag[fx=%d][fy=0][fz=%d][%d, %d] = %f\n", fx, fz, reg_id * 16 + lane_id_y * 4, lane_id_x, (*d)[reg_id]); + for (uint32_t reg_id=0; reg_id < 4; reg_id++) { + printf("[compute_qk] (x=%d, y=%d, reg_id=%d) s_frag[fx=%d][fy=0][fz=%d][%d, %d] = %f\n", lane_id_x, lane_id_y, reg_id, fx, fz, reg_id + lane_id_y * 4, lane_id_x, (*d)[reg_id]); } } } else { // TODO (yiakwy) : device cast fp32 to fp16 assert(0 && "AMD v_mfma instruction does not support fp16 output."); - // *d = __builtin_amdgcn_mfma_f16_16x16x16f16(*a, *b, *d, 0, 0, 0); } } } @@ -1499,7 +1502,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC #endif if (threadIdx.x == 0 && threadIdx.z == 0) { - printf("[prefill kernel] channel_size_128b_q = %d\n", channel_size_128b_q); + printf("[single prefill kernel] channel_size_128b_q = %d\n", channel_size_128b_q); } // NOTE(yiakwy) : FA2 outter loop (block level) load q first and iterate over sequence dimension inside a block @@ -1517,7 +1520,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC sm_scale); } else { if (threadIdx.x==0 && threadIdx.z==0) { - printf("[prefill kernel] skip q_smem_inplace_multiply_sm_scale.\n"); + printf("[single prefill kernel] skip q_smem_inplace_multiply_sm_scale.\n"); } // TODO (yiakwy) : recover /* @@ -1613,9 +1616,16 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); #endif // USE_ROCM + if (threadIdx.x==0 && threadIdx.z==0) { + printf("[single prefill kernel] ===== producing key =====\n"); + } produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); + + if (threadIdx.x==0 && threadIdx.z==0) { + printf("[single prefill kernel] ***** producing value *****\n"); + } produce_kv( v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, 0, chunk_size); cp_async::commit_group(); @@ -1638,6 +1648,9 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC } // compute attention score + if (threadIdx.x == 0 && threadIdx.z == 0) { + printf("[single prefill kernel] start calling compute_qk...\n"); + } compute_qk(&qo_smem, &q_smem_offset_r, &k_smem, &k_smem_offset_r, s_frag, logits_soft_cap); @@ -1670,22 +1683,39 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC update_mdo_states(s_frag, o_frag, m, d); block.sync(); + + // TODO (yiakwy) : REMOVE + if (threadIdx.x == 0 && threadIdx.z == 0) { + printf("[single prefill kernel] calling pdate_mdo_states completes."); + } + + // NOTE (yiakwy) : prepare the next loading + if (iter + 1 < num_iterations) { + produce_kv( k_smem, &kv_smem_offset_w, &k_ptr, kv_stride_n, (iter + 1) * 16 * num_warps_z * num_frags_z, chunk_size); cp_async::commit_group(); cp_async::wait_group<1>(); block.sync(); + + } // compute sfm*v compute_sfm_v( &v_smem, &v_smem_offset_r, s_frag, o_frag, d); block.sync(); + + // NOTE (yiakwy) : prepare the next loading + if (iter + 1 < num_iterations) { + produce_kv( v_smem, &kv_smem_offset_w, &v_ptr, kv_stride_n, (iter + 1) * 16 * num_warps_z * num_frags_z, chunk_size); cp_async::commit_group(); + + } } cp_async::wait_group<0>(); From 4856dcf144f917c2b7c2ff787c334f36c22ce7e6 Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Mon, 21 Oct 2024 11:39:51 +0000 Subject: [PATCH 14/15] update kv fragment warps mapping (add boundary checking, e.g. qkv shape=(1,16,64)) --- include/flashinfer/attention/prefill.cuh | 86 ++++++++++++++---------- src/utils.h | 2 +- 2 files changed, 52 insertions(+), 36 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index 2984850f..ec6bfbac 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -202,7 +202,10 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* constexpr uint32_t head_dim = num_frags_y * 16; constexpr uint32_t num_warps = num_warps_x * num_warps_z; constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); - const uint32_t warp_idx = get_warp_idx(), lane_idx = threadIdx.x; + // TODO (yiakwy) : compute it; + constexpr uint32_t kv_frag_cols = 8; + const uint32_t warp_idx = get_warp_idx(); + const uint32_t warp_idx_z = warp_idx, lane_idx = threadIdx.x; if constexpr (swizzle_mode == SwizzleMode::k128B) { // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory @@ -221,25 +224,26 @@ __device__ __forceinline__ void produce_kv(smem_t smem, uint32_t* // 2+4*3 // 3+4*3 // - uint32_t kv_idx = kv_idx_base + warp_idx * 4 + lane_idx / 8; + uint32_t kv_idx = kv_idx_base + warp_idx_z * 4/*kv_frag_rows*/ + lane_idx / 8; // NOTE(Zihao): num_frags_z * 4 / num_warps_x = num_warps_z * num_frags_z * 4 / num_warps static_assert(num_frags_z * 4 % num_warps_x == 0); + + T* kv_base_r = *gptr; + uint32_t kv_offset_r = ( lane_idx % kv_frag_cols ) * channel_size_128b_kv + (lane_idx / kv_frag_cols ) * kv_stride_n; // NOTE (yiakwy) : for kv = (1/*head*/, 16/*seq*/, 64), at least 128 rows will be loaded #pragma unroll - for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x && ((warp_idx * 4 + lane_idx / 8 + i * 16) < kv_len); ++i) { // for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { + for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x && ((warp_idx_z * 4/*kv_frag_rows*/ + lane_idx / 8/*kv_frag_cols*/ + i * 16) < kv_len); ++i) { // for (uint32_t i = 0; i < num_frags_z * 4 / num_warps_x; ++i) { #pragma unroll for (uint32_t j = 0; j < num_frags_y / (8 / sizeof(T)); ++j) { // NOTE (yiakwy) : kvsmem[warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] = kv[0, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8] - smem.template load_128b_async(*smem_offset, *gptr, kv_idx < kv_len); - - T* kv_base_r = *gptr; - T* kv_ptr = kv_base_r + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * kv_stride_n + /* lane_idx % 8 */ + j * 8; + smem.template load_128b_async(*smem_offset, kv_base_r + kv_offset_r, kv_idx < kv_len); - b128_t* smem_ptr = smem.base + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * 8 + /* lane_idx % 8 */ + j * 8; + T* kv_ptr = kv_base_r + kv_offset_r + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * kv_stride_n + /* lane_idx % 8 */ + j * 8; + b128_t* smem_ptr = smem.base + *smem_offset + (/* warp_idx * 4 + lane_idx / 8 */ + i * 16) * 8/* 64=8x8 */ + /* lane_idx % 8 */ + j * 8; float16_t *s = reinterpret_cast(smem_ptr); - printf("[produce_kv] (i=%d,j=%d,warp_idx=%d), kv_smem[%d, %d] (%f..%f) = kv[H=0, N_CTX=%d/%d, %d](%f..%f,%f)\n", i, j, warp_idx, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8, (float)(*s), (float)(*(s+7)), warp_idx * 4 + lane_idx / 8 + i * 16, kv_len, lane_idx % 8 + j * 8, (float)(*kv_ptr), (float)(*(kv_ptr+6)), (float)(*(kv_ptr+7))); + printf("[produce_kv] (i=%d,j=%d,warp_idx=%d, x = %d, z = %d), kv_smem[%d, %d] (%f..%f) = kv[H=0, N_CTX=%d/%d, %d](%f..%f,%f)\n", i, j, warp_idx, threadIdx.x, threadIdx.z, warp_idx * 4 + lane_idx / 8 + i * 16, lane_idx % 8 + j * 8, (float)(*s), (float)(*(s+7)), warp_idx * 4 + lane_idx / 8 + i * 16, kv_len, lane_idx % 8 + j * 8, (float)(*kv_ptr), (float)(*(kv_ptr+6)), (float)(*(kv_ptr+7))); *smem_offset = smem.template advance_offset_by_column<8>(*smem_offset, j); *gptr += 8 * num_elems_per_128b(); @@ -404,7 +408,10 @@ __device__ __forceinline__ void load_q_global_smem(uint32_t packed_offset, const uint32_t q_idx = q; // NOTE (yiakwy) : q_ptr = q[bz/*head*/, bx{0} * num_rows_per_cta{16} + warp_idx_x * num_frags_x * 16 + lane_idx / 8 + j * 4 /*seqlen*/, 0/*hdim*/] + (lane_idx % 8) * 8 + /* DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h; + */ + DTypeQ* q_ptr = q_ptr_base + q * q_stride_n + r * q_stride_h + lane_idx % 8 * 8; #pragma unroll for (uint32_t fyo = 0; fyo < num_frags_y / 4; ++fyo) { @@ -632,17 +639,21 @@ __device__ __forceinline__ void compute_qk( uint32_t b_frag[2]; // hence - // TODO (yiakwy) : what if we change blckDim.x from 32 to 64 - uint32_t lane_id = threadIdx.x + threadIdx.z * blockDim.x; + // TODO (yiakwy) : z={0,1} is used for lane mappping, z={2,3} used for warps mapping what if we change blckDim.x from 32 to 64 + uint32_t lane_id = ( threadIdx.x + threadIdx.z * blockDim.x ) % 64 ; uint32_t lane_id_x = lane_id % 16; uint32_t lane_id_y = lane_id / 16; // TODO (yiakwy) : replace these variables later + /* uint32_t warp_idx_x = get_warp_idx_x<1, 4>(); uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); uint32_t warp_idx = get_warp_idx<1, 4>(); + */ + uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); + uint32_t warp64_idx_z = warp_idx_z / 2; using float16_t = rocwmma::float16_t; @@ -650,8 +661,7 @@ __device__ __forceinline__ void compute_qk( using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; #define MTX_FRAG_LDA (head_dim) - // #define MTX_FRAG_LDB (head_dim) - #define MTX_FRAG_LDB (num_frags_z * 16) + #define MTX_FRAG_LDB (head_dim) #else @@ -663,7 +673,8 @@ __device__ __forceinline__ void compute_qk( // compute q*k^T #ifdef USE_ROCM - if (lane_id < 64U) { + // if (lane_id < 64U) { + if (warp64_idx_z * num_frags_z * 16U < 16U/*kv_len*/ ) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { @@ -673,7 +684,7 @@ __device__ __forceinline__ void compute_qk( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim - b128_t* smem_ptr = q_smem->base + (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + b128_t* smem_ptr = q_smem->base + *q_smem_offset_r;// (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; float16_t *s = reinterpret_cast(smem_ptr); float16x4 *a = reinterpret_cast(a_frag[fx]); @@ -683,12 +694,11 @@ __device__ __forceinline__ void compute_qk( for (uint32_t j=0; j < 4; j++) { // NOTE (yiakwy) : 16 threads loads 4 columns (16x4fp16) of data cooperatively uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4; - s += offset; - (*a)[j] = *(s); + (*a)[j] = *(s + offset); if (fx == 0 && fy == 0) { - printf("[compute_qk] (x=%d, y=%d, j=%d) (fx=%d, fy=%d) a_mtx_frag[%d, %d]=%f\n", lane_id_x, lane_id_y, j, fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j])); + printf("[compute_qk] (x=%d, y=%d, z=%d) (lane_id_x=%d, lane_id_y=%d, j=%d) (fx=%d, fy=%d) a_mtx_frag[%d, %d]=%f, *(s)=%f\n", threadIdx.x, threadIdx.y, threadIdx.z, lane_id_x, lane_id_y, j, fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j]), (float)(*(s+offset))); } } @@ -708,7 +718,7 @@ __device__ __forceinline__ void compute_qk( assert(0 && "KV Cache with FP8 data type is not supported in ROCM"); } - b128_t* smem_ptr = k_smem->base + (warp_idx * 4) * channel_size_128b_q; + b128_t* smem_ptr = k_smem->base + *k_smem_offset_r; float16_t *s = reinterpret_cast(smem_ptr); float16x4 *b = reinterpret_cast(b_frag); @@ -717,13 +727,13 @@ __device__ __forceinline__ void compute_qk( #pragma unroll for (uint32_t j=0; j < 4; j++) { // NOTE (yiakwy) : loads 16 consecutive data of 1 row - s += lane_id_x + lane_id_y * MTX_FRAG_LDB * 4 + j * MTX_FRAG_LDB; - // s += lane_id_x * MTX_FRAG_LDB + j + lane_id_y * 4; + uint32_t offset = lane_id_x + (lane_id_y * 4 + j) * MTX_FRAG_LDB; + // uint32_t offset = lane_id_x * MTX_FRAG_LDB + j + lane_id_y * 4; - (*b)[j] = *s; + (*b)[j] = *(s+offset); if (fy == 0 && fz == 0) { - printf("[compute_qk] (x=%d, y=%d, j=%d) (fz=%d, fy=%d) b_mtx_frag[%d, %d]=%f\n", lane_id_x, lane_id_y, j, fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); + printf("[compute_qk] (x=%d, y=%d, z=%d) (lane_id_x=%d, lane_id_y=%d, j=%d) (fz=%d, fy=%d) b_mtx_frag[%d, %d]=%f\n", threadIdx.x, threadIdx.y, threadIdx.z, lane_id_x, lane_id_y, j, fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); } } @@ -762,7 +772,8 @@ __device__ __forceinline__ void compute_qk( } } - } // if lane_id < 64 + // } // if lane_id < 64 + } // if warp64_idx_z * num_frags_z * 16 < kv_len #else #pragma unroll @@ -1579,23 +1590,28 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( chunk_start + warp_idx * kv_frag_rows + 0/* nvgpu : (lane_idx / 8) */, kv_head_idx, - 0/* nvgpu : (lane_idx % 8 ) * 8 */); + 0/* nvgpu : (lane_idx % 8 ) * 8 */); DTypeKV* v_ptr = v + qkv_info.get_kv_elem_offset( chunk_start + warp_idx * kv_frag_rows + 0, kv_head_idx, 0); - // NOTE (yiakwy) : _w is used for storing (produce_kv) and _r is used for reading (compute_qk) - // NOTE (yiakwy) : We reuse NV GPU uint4 loading layout + + // NOTE (yiakwy) : _w is used for storing (produce_kv) and _r is used for reading (compute_qk), (32x2, warp_idz) + // NOTE (yiakwy) : We reuse NV GPU uint4 loading layout for writing + uint32_t warp_idx_z = get_warp_idx_z(); + uint32_t warp64_idx_z = warp_idx_z / 2; /*(32, 1, 2) threads to form a warp*/ + uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + - 0/* nvgpu : (lane_idx / 16) * 8 + lane_idx % 8 */, - 0/* nvgpu : (lane_idx % 16) / 8) => {0, 1}*/), + warp64_idx_z * num_frags_z * 16 + + 0/* nvgpu ldmatrix layout : (lane_idx / 16) * 8 + lane_idx % 8 */, + 0/* nvgpu ldmatrix layout : (lane_idx % 16) / 8) => {0, 1}*/), v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + 0 /* nvgpu : lane_idx % 16 => {0..15} */, - 0/* nvgpu : lane_idx / 16 => {0, 1} */), + warp64_idx_z * num_frags_z * 16 + + 0 /* nvgpu ldmatrix layout : lane_idx % 16 => {0..15} */, + 0/* nvgpu ldmatrix layout : lane_idx / 16 => {0, 1} */), kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); - + #else DTypeKV* k_ptr = k + qkv_info.get_kv_elem_offset( @@ -1606,11 +1622,11 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC chunk_start + warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, kv_head_idx, (lane_idx % kv_frag_cols) * num_elems_per_128b()); uint32_t k_smem_offset_r = k_smem.template get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + + get_warp_idx_z() * num_frags_z * 16 + 8 * (lane_idx / 16) + lane_idx % 8, (lane_idx % 16) / 8), v_smem_offset_r = v_smem.template get_permuted_offset( - get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, + get_warp_idx_z() * num_frags_z * 16 + lane_idx % 16, lane_idx / 16), kv_smem_offset_w = k_smem.template get_permuted_offset( warp_idx * kv_frag_rows + lane_idx / kv_frag_cols, lane_idx % kv_frag_cols); diff --git a/src/utils.h b/src/utils.h index 6ad480c2..40dfc230 100644 --- a/src/utils.h +++ b/src/utils.h @@ -82,7 +82,7 @@ void vec_normal_(std::vector& vec, float mean = 0.f, float std = 1.f) { std::normal_distribution d{mean, std}; for (size_t i = 0; i < vec.size(); ++i) { // TODO (yiakwy) : RECOVER - vec[i] = T(1.f);//T(d(gen)); + vec[i] = T(1.f);//T(i);//T(d(gen)); } } From addf8b541604348387e8ddc0f54bb3ecc4aa808e Mon Sep 17 00:00:00 2001 From: yiakwy-xpu-ml-framework-team Date: Fri, 25 Oct 2024 17:34:39 +0000 Subject: [PATCH 15/15] add support of prefill (part 2). add support update_mdo_states for output product registers --- include/flashinfer/attention/prefill.cuh | 261 ++++++++++++++++++----- 1 file changed, 205 insertions(+), 56 deletions(-) diff --git a/include/flashinfer/attention/prefill.cuh b/include/flashinfer/attention/prefill.cuh index ec6bfbac..1de513c6 100644 --- a/include/flashinfer/attention/prefill.cuh +++ b/include/flashinfer/attention/prefill.cuh @@ -628,11 +628,10 @@ __device__ __forceinline__ void compute_qk( #ifdef USE_ROCM - // TODO (yiakwy) : REMOVE - if (threadIdx.x == 0 && threadIdx.z == 0) { - printf("[compute_qk] channel_size_128b_q=%d, channel_size_128b_kv=%d\n", channel_size_128b_q, channel_size_128b_kv); - printf("[compute_qk] num_frags_x=%d, num_frags_y=%d, num_frags_z=%d\n", num_frags_x, num_frags_y, num_frags_z); - } + using float16_t = rocwmma::float16_t; + + using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; + using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; // NOTE(yiakwy) : each thread of 64=16x4 threads block, cooperatively loads 4 x consecutive fp16/bf16 data to cover 16x16 matrix frag uint32_t a_frag[num_frags_x][2]; @@ -642,23 +641,13 @@ __device__ __forceinline__ void compute_qk( // TODO (yiakwy) : z={0,1} is used for lane mappping, z={2,3} used for warps mapping what if we change blckDim.x from 32 to 64 uint32_t lane_id = ( threadIdx.x + threadIdx.z * blockDim.x ) % 64 ; + // TODO (yiakwy) : CONSTANTS uint32_t lane_id_x = lane_id % 16; uint32_t lane_id_y = lane_id / 16; - // TODO (yiakwy) : replace these variables later - /* - uint32_t warp_idx_x = get_warp_idx_x<1, 4>(); - uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); - - uint32_t warp_idx = get_warp_idx<1, 4>(); - */ + // TODO (yiakwy) : CONSTANTS uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); uint32_t warp64_idx_z = warp_idx_z / 2; - - using float16_t = rocwmma::float16_t; - - using float16x4 = __attribute__((__vector_size__(4 * sizeof(float16_t)))) float16_t; - using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; #define MTX_FRAG_LDA (head_dim) #define MTX_FRAG_LDB (head_dim) @@ -684,7 +673,7 @@ __device__ __forceinline__ void compute_qk( for (uint32_t fx = 0; fx < num_frags_x; ++fx) { // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim - b128_t* smem_ptr = q_smem->base + *q_smem_offset_r;// (warp_idx_x * num_frags_x * 16) * channel_size_128b_q; + b128_t* smem_ptr = q_smem->base + *q_smem_offset_r; float16_t *s = reinterpret_cast(smem_ptr); float16x4 *a = reinterpret_cast(a_frag[fx]); @@ -697,9 +686,11 @@ __device__ __forceinline__ void compute_qk( (*a)[j] = *(s + offset); + #if defined(DEBUG_PREFILL) || defined(DEBUG_PREFILL_COMPUTE_QK) if (fx == 0 && fy == 0) { printf("[compute_qk] (x=%d, y=%d, z=%d) (lane_id_x=%d, lane_id_y=%d, j=%d) (fx=%d, fy=%d) a_mtx_frag[%d, %d]=%f, *(s)=%f\n", threadIdx.x, threadIdx.y, threadIdx.z, lane_id_x, lane_id_y, j, fx, fy, lane_id_x, j + lane_id_y * 4, (float)((*a)[j]), (float)(*(s+offset))); } + #endif } *q_smem_offset_r = @@ -728,13 +719,14 @@ __device__ __forceinline__ void compute_qk( for (uint32_t j=0; j < 4; j++) { // NOTE (yiakwy) : loads 16 consecutive data of 1 row uint32_t offset = lane_id_x + (lane_id_y * 4 + j) * MTX_FRAG_LDB; - // uint32_t offset = lane_id_x * MTX_FRAG_LDB + j + lane_id_y * 4; (*b)[j] = *(s+offset); + #if defined(DEBUG_PREFILL) || defined(DEBUG_PREFILL_COMPUTE_QK) if (fy == 0 && fz == 0) { printf("[compute_qk] (x=%d, y=%d, z=%d) (lane_id_x=%d, lane_id_y=%d, j=%d) (fz=%d, fy=%d) b_mtx_frag[%d, %d]=%f\n", threadIdx.x, threadIdx.y, threadIdx.z, lane_id_x, lane_id_y, j, fz, fy, lane_id_y * 4 + j, lane_id_x, (float)((*b)[j])); } + #endif } // NOTE(yiakwy) : k is still in row-major layout @@ -749,15 +741,20 @@ __device__ __forceinline__ void compute_qk( if constexpr (std::is_same::value) { floatx4 *d = reinterpret_cast(s_frag[fx][fz]); *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0); - __asm__ volatile("s_barrier" ::); - if (fx == 0 && fy == 0 && fz == 0) { + // __asm__ volatile("s_barrier" ::); + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_barrier(); + + // #if defined(DEBUG_PREFILL) || defined(DEBUG_PREFILL_COMPUTE_QK) + if (fx == 0 && fy == 3 && fz == 0) { for (uint32_t reg_id=0; reg_id < 4; reg_id++) { - printf("[compute_qk] (x=%d, y=%d, reg_id=%d) s_frag[fx=%d][fy=0][fz=%d][%d, %d] = %f\n", lane_id_x, lane_id_y, reg_id, fx, fz, reg_id + lane_id_y * 4, lane_id_x, (*d)[reg_id]); + printf("[compute_qk] (lane_id_x=%d, lane_id_y=%d, reg_id=%d) s_frag[fx=%d][fy=0][fz=%d][%d, %d] = %f\n", lane_id_x, lane_id_y, reg_id, fx, fz, reg_id + lane_id_y * 4, lane_id_x, (*d)[reg_id]); } } + // #endif } else { // TODO (yiakwy) : device cast fp32 to fp16 assert(0 && "AMD v_mfma instruction does not support fp16 output."); @@ -772,8 +769,10 @@ __device__ __forceinline__ void compute_qk( } } - // } // if lane_id < 64 } // if warp64_idx_z * num_frags_z * 16 < kv_len + + // NOTE(yiakwy) : we have threads not in USE, so we must synchrose the whole threads block before prceeding + __syncthreads(); #else #pragma unroll @@ -982,6 +981,15 @@ template ::value) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -990,15 +998,78 @@ __device__ __forceinline__ void update_mdo_states(DTypeQKAccum (*s_frag)[num_fra float m_prev = m[fx][j]; #pragma unroll for (uint32_t fz = 0; fz < num_frags_z; ++fz) { + + // NOTE(yiakwy) : should we reuse j ? + #ifdef USE_ROCM + + for (uint32_t i=0; i < 2; i++) { + // TODO (yiakwy) : check s_smem swizzle strategy + s_smem[j * 2 + i + lane_id_y * 4][lane_id_x] = s_frag[fx][fz][j * 2 + i]; + } + + // __asm__ volatile("s_barrier" ::); + __builtin_amdgcn_s_waitcnt(0); + __builtin_amdgcn_s_barrier(); + + // NOTE(yiakwy) at this moment, only half of 16x16 matrix filled + // rows / cols 0 1 2 3 .. 7 + // 0 0 1 2 3 4 5 6 .. 14 15 + // 1 16 17 18 19 20 21 22 .. 30 31 + // - - - - - - - .. - - + // - - - - - - - .. - - + // 4 64 65 66 67 68 69 70 .. 71 72 + // 5 ... + // ... + + // NOTE(yiakwy) : now we mimic CUDA mma rules (2 rows of registers per thread) to avoid update signature of m, d + // NOTE(yiakwy) : design decision, for 16x16 (implementation) fragment each thread process 4 elements, i.e. 2 elements per row (row 0, row 8 for example), 8 threads per row + // each row is mapped to 8 rows {0, 1, 4, 5, 8, 9, 12, 13} + // maybe we could have a good math, but let's get thing done quickly + constexpr uint32_t rows_map[8] = {0, 1, 4, 5, 8, 9, 12, 13}; + uint32_t reduceop_lane_id_x = rows_map[lane_id / 8] + j * 2; + uint32_t reduceop_lane_id_y = (lane_id % 8) * 2; + float m_local = max(s_smem[reduceop_lane_id_x][reduceop_lane_id_y], s_smem[reduceop_lane_id_x][reduceop_lane_id_y + 1]); + m[fx][j] = max(m[fx][j], m_local); + + if (fx == 0 && fz == 0) { + for (uint32_t i=0; i < 2; i++) { + printf("[update_mdo_states] (x = %d, y = %d, z = %d) , frag (fx=%d, fz=%d) (reduceop_lane_id_x=%d, reduceop_lane_id_y=%d, reg_id=%d) s_smem[%d][%d] = %f, m[%d][%d]= %f\n", threadIdx.x, threadIdx.y, threadIdx.z, fx, fz, reduceop_lane_id_x, reduceop_lane_id_y, j * 2 + i, reduceop_lane_id_x, reduceop_lane_id_y, s_smem[reduceop_lane_id_x][reduceop_lane_id_y + i], fx, j, m[fx][j]); + } + } + + #else + float m_local = max(max(s_frag[fx][fz][j * 2 + 0], s_frag[fx][fz][j * 2 + 1]), max(s_frag[fx][fz][j * 2 + 4], s_frag[fx][fz][j * 2 + 5])); m[fx][j] = max(m[fx][j], m_local); + + #endif // USE_ROCM + } - m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x2)); - m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x1)); + #ifdef USE_ROCM + m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x4)); // NOTE (yiakwy) : 8 -> 4 + #endif // USE_ROCM + m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x2)); // NOTE (yiakwy) : 4 -> 2 + m[fx][j] = max(m[fx][j], math::shfl_xor_sync(m[fx][j], 0x1)); // NOTE (yiakwy) : 2 -> 1 float o_scale = math::ptx_exp2(m_prev - m[fx][j]); d[fx][j] *= o_scale; + + #ifdef USE_ROCM + +#pragma unroll + for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + o_frag[fx][fy][j * 2 + 0] *= o_scale; + o_frag[fx][fy][j * 2 + 1] *= o_scale; + } +#pragma unroll + for (uint32_t fz = 0; fz < num_frags_z; ++ fz) { + // TODO (yiakwy) : check s_smem swizzle strategy + s_frag[fx][fz][j * 2 + 0] = math::ptx_exp2(s_smem[j * 2 + 0 + lane_id_y * 4][lane_id_x] - m[fx][j]); + s_frag[fx][fz][j * 2 + 1] = math::ptx_exp2(s_smem[j * 2 + 1 + lane_id_y * 4][lane_id_x] - m[fx][j]); + } + + #else #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { o_frag[fx][fy][j * 2 + 0] *= o_scale; @@ -1013,9 +1084,16 @@ __device__ __forceinline__ void update_mdo_states(DTypeQKAccum (*s_frag)[num_fra s_frag[fx][fz][j * 2 + 4] = math::ptx_exp2(s_frag[fx][fz][j * 2 + 4] - m[fx][j]); s_frag[fx][fz][j * 2 + 5] = math::ptx_exp2(s_frag[fx][fz][j * 2 + 5] - m[fx][j]); } + #endif // USE_ROCM } } } else if constexpr (std::is_same::value) { + + #ifdef USE_ROCM + // TODO (yiakwy) : remove assert + assert(0 && "[update_mdo_state] half output for accumulator is not supported yet, defaults to fp32 mixed precision!"); + #endif + #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { half m_prev[2]; @@ -1064,11 +1142,36 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); #ifdef USE_ROCM + + using float16_t = rocwmma::float16_t; + using float16x4 = __attribute__((__vector_size__(4 * sizeof(rocwmma::float16_t)))) rocwmma::float16_t; using floatx4 = __attribute__((__vector_size__(4 * sizeof(float)))) float; - #endif + + // TODO (yiakwy) : CONSTANTS + uint32_t lane_id = ( threadIdx.x + threadIdx.z * blockDim.x ) % 64; + uint32_t lane_id_x = lane_id % 16; + uint32_t lane_id_y = lane_id / 16; + + // TODO (yiakwy) : CONSTANTS + uint32_t warp_idx_z = get_warp_idx_z<1, 4>(); + uint32_t warp64_idx_z = warp_idx_z / 2; + + // NOTE(yiakwy) : only floatx4 of s_frag is used + + #define MTX_FRAG_LDA 16 + + DTypeQ s_frag_f16[num_frags_x][num_frags_z][4]; + + // NOTE(yiakwy) : we will write thread private memory to this to synchronize data cross lanes + __shared__ DTypeQKAccum s_smem[16][16]; + + #else DTypeQ s_frag_f16[num_frags_x][num_frags_z][8]; + + #endif // USE_ROCM + if constexpr (std::is_same::value) { #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { @@ -1085,14 +1188,14 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #ifdef USE_ROCM - // TODO (yiakwy) : check feasibility with v_mfma_fp32_m16xn16xk16_fp16 #pragma unroll - for (int i=0; i < 8; ++i) { + // NOTE(yiakwy) : registers points to 4 consecutive rows S[reg_id + lane_id_y * 4][lane_id_x] + for (int i=0; i < 4/*rows of s frag*/; ++i) { if constexpr (std::is_same::value) { - // device cast from half to float + // NOTE(yiakwy) : device cast from half to float, accumulated cross lanes d[fx][i] += (float)s_frag_f16[fx][fz][i]; } else { - // device cast from half to float + // NOTE(yiakwy) : device cast from float to half d[fx][i] += (float)s_frag[fx][fz][i]; } } @@ -1113,57 +1216,81 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, for (uint32_t fz = 0; fz < num_frags_z; ++fz) { #pragma unroll for (uint32_t fy = 0; fy < num_frags_y; ++fy) { + + if (warp64_idx_z * num_frags_z * 16U < 16U/*kv_len*/ ) { + uint32_t b_frag[4]; if constexpr (sizeof(DTypeKV) == 1) { - // TODO (yiakwy) : add FP8 support for KV Cache - assert(0 && "FP8 KV Cache is not supported."); + #ifdef USE_ROCM + // TODO (yiakwy) : add FP8 support for KV Cache + assert(0 && "FP8 KV Cache is not supported."); + #endif - uint32_t b_frag_f8[2]; - if (fy % 2 == 0) { - v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); - } else { - v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); - } - b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); - b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); - vec_cast::template cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); - swap(b_frag[1], b_frag[2]); + uint32_t b_frag_f8[2]; + if (fy % 2 == 0) { + v_smem->ldmatrix_m8n8x4_trans_left_half(*v_smem_offset_r, b_frag_f8); + } else { + v_smem->ldmatrix_m8n8x4_trans_right_half(*v_smem_offset_r, b_frag_f8); + } + b_frag_f8[0] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[0]); + b_frag_f8[1] = frag_layout_swizzle_16b_to_8b_trans(b_frag_f8[1]); + vec_cast::template cast<8>((DTypeQ*)b_frag, (DTypeKV*)b_frag_f8); + swap(b_frag[1], b_frag[2]); } else { - + #ifdef USE_ROCM b128_t* smem_ptr = v_smem->base + *v_smem_offset_r; + float16_t *s = reinterpret_cast(smem_ptr); float16x4 *b = reinterpret_cast(b_frag); #pragma unroll for (int j=0; j < 4; j++) { - (*b)[j] = (rocwmma::float16_t)(*b)[j]; + + uint32_t offset = lane_id_x + (lane_id_y * 4 + j) * MTX_FRAG_LDB; + + (*b)[j] = (float16_t)(*(s + offset)); } #else v_smem->ldmatrix_m8n8x4_trans(*v_smem_offset_r, b_frag); #endif // USE_ROCM - } + } // load v from global + #pragma unroll for (uint32_t fx = 0; fx < num_frags_x; ++fx) { #ifdef USE_ROCM + for (uint32_t i=0; i < 4; i++) { + s_smem[i+lane_id_y * 4][lane_id_x] = s_frag_f16[fx][fz][i]; + } + + __asm__ volatile("s_barrier" ::); + float16x4 *b = reinterpret_cast(b_frag); - floatx4 *o = reinterpret_cast(o_frag[fx][fz]); + floatx4 *o = reinterpret_cast(o_frag[fx][fy]); if constexpr (std::is_same::value) { - float16x4 *a = reinterpret_cast(s_frag_f16[fx][fz]); + float16x4 *a = reinterpret_cast(s_smem + lane_id_x * 16 + lane_id_y * 4); *o = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *o, 0, 0, 0); } else { - float16x4 *a = reinterpret_cast(s_frag[fx][fz]); + float16x4 *a = reinterpret_cast(s_smem + lane_id_x * 16 + lane_id_y * 4); *o = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *o, 0, 0, 0); } - #else + __asm__ volatile("s_barrier" ::); + + if (fz == 0 && fy == 0 && fx == 0) { + for (uint32_t reg_id = 0; reg_id < 4; reg_id++) { + printf("[compute_sfm_v] (lane_id_x=%d, lane_id_y=%d, reg_id=%d) o_frag[fx=%d][fy=%d][%d, %d] = %f\n", lane_id_x, lane_id_y, reg_id, fx, fy, reg_id + lane_id_y * 4, lane_id_x, (*o)[reg_id]); + } + } + + #else // USE_ROCM if constexpr (std::is_same::value) { mma::mma_sync_m16n16k16_row_col_f16f16f32( @@ -1184,10 +1311,13 @@ __device__ __forceinline__ void compute_sfm_v(smem_t* v_smem, } else { *v_smem_offset_r = v_smem->template advance_offset_by_column<2>(*v_smem_offset_r, fy); } + + } // if warp64_idx_z * num_frags_z * 16U < kv_len + + *v_smem_offset_r = + v_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*v_smem_offset_r) - + sizeof(DTypeKV) * num_frags_y; } - *v_smem_offset_r = - v_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*v_smem_offset_r) - - sizeof(DTypeKV) * num_frags_y; } *v_smem_offset_r -= 16 * num_frags_z * channel_size_128b_kv; } @@ -1453,14 +1583,30 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC constexpr uint32_t channel_size_128b_kv = head_dim / num_elems_per_128b(); constexpr uint32_t channel_size_128b_out = head_dim / num_elems_per_128b(); // e.g.: 64/4 - extern __shared__ uint8_t smem[]; + extern __shared__ uint8_t smem[]; // NOTE(yaikwy) : e.g. 128 (num_frags x 4 x 16) x 64 + + #ifdef USE_ROCM + + // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag + DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; + // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag + float o_frag[num_frags_x][num_frags_y][8]; + + DTypeQKAccum m[num_frags_x][2]; + __shared__ float d[num_frags_x][2]; + + #else // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag DTypeQKAccum s_frag[num_frags_x][num_frags_z][8]; // NOTE(yiakwy) : e.g. 1x4 fragments, threads cooperatively ld/st 16 bf16/half elements in each frag float o_frag[num_frags_x][num_frags_y][8]; + DTypeQKAccum m[num_frags_x][2]; float d[num_frags_x][2]; + + #endif + float rope_freq[num_frags_y / 2][4]; if constexpr (pos_encoding_mode == PosEncodingMode::kRoPELlama) { init_rope_freq(rope_freq, log2_rope_rcp_scale, log2_rope_rcp_theta); @@ -1702,7 +1848,7 @@ __launch_bounds__(num_warps_x* num_warps_z* warp_size) void SinglePrefillWithKVC // TODO (yiakwy) : REMOVE if (threadIdx.x == 0 && threadIdx.z == 0) { - printf("[single prefill kernel] calling pdate_mdo_states completes."); + printf("[single prefill kernel] calling udate_mdo_states completes.\n"); } // NOTE (yiakwy) : prepare the next loading @@ -2556,8 +2702,11 @@ cudaError_t SinglePrefillWithKVCacheDispatched( std::cout << "num_rows_per_cta : " << num_rows_per_cta << std::endl; std::cout << "num_threads : " << num_threads << std::endl; - std::cout << "num_warps_x : " << num_warps_x << std::endl; - std::cout << "num_warps_z : " << num_warps_z << std::endl; + std::cout << "num_warps_x (threads block) : " << num_warps_x << std::endl; + std::cout << "num_warps_z (threads block) : " << num_warps_z << std::endl; + std::cout << "num_x_frags : " << num_frags_x << std::endl; + std::cout << "num_y_frags : " << num_frags_y << std::endl; + std::cout << "num_z_frags : " << num_frags_z << std::endl; if (num_chunks <= 1 || tmp == nullptr) { // Enough parallelism, do not split-kv