[WIP][AMDGPU] try rocm POC #491

Draft: wants to merge 15 commits into base: main
4 changes: 4 additions & 0 deletions .gitmodules
@@ -1,6 +1,10 @@
[submodule "3rdparty/nvbench"]
path = 3rdparty/nvbench
url = https://github.com/NVIDIA/nvbench.git
[submodule "3rdparty/hipbench"]
path = 3rdparty/hipbench
# url = https://github.com/ROCm/hipBench.git
url = https://github.com/yiakwy-xpu-ml-framework-team/hipbench
[submodule "3rdparty/googletest"]
path = 3rdparty/googletest
url = https://github.com/google/googletest.git
287 changes: 261 additions & 26 deletions CMakeLists.txt

Large diffs are not rendered by default.

2 changes: 1 addition & 1 deletion cmake/config.cmake
@@ -40,4 +40,4 @@ set(FLASHINFER_GEN_MASK_MODES 0 1 2)
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES native)
set(FLASHINFER_CUDA_ARCHITECTURES native)
2 changes: 2 additions & 0 deletions cmake/modules/FindThrust.cmake
@@ -33,7 +33,9 @@ find_path( THRUST_INCLUDE_DIR
/usr/include/cuda
/usr/local/include
/usr/local/cuda/include
/opt/rocm/include
${CUDA_INCLUDE_DIRS}
${HIP_INCLUDE_DIRS}
NAMES thrust/version.h
DOC "Thrust headers"
)
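Note: rocThrust ships the same thrust/ headers under /opt/rocm/include, which is why the extra search path above is sufficient. A minimal smoke test for the headers this module finds, assuming NVIDIA Thrust (nvcc) or rocThrust (hipcc) is on the resolved include path; the file name is illustrative, not part of this PR:

// thrust_smoke.cu: builds with nvcc (NVIDIA Thrust) or hipcc (rocThrust),
// assuming THRUST_INCLUDE_DIR from FindThrust.cmake is on the include path.
#include <thrust/device_vector.h>
#include <thrust/reduce.h>

#include <cstdio>

int main() {
  thrust::device_vector<int> ones(1024, 1);               // 1024 elements, all 1
  int sum = thrust::reduce(ones.begin(), ones.end(), 0);  // reduction runs on the GPU
  std::printf("thrust reduce: %d (expected 1024)\n", sum);
  return sum == 1024 ? 0 : 1;
}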
4 changes: 4 additions & 0 deletions cmake/utils/Utils.cmake
@@ -36,14 +36,18 @@ macro(flashinfer_option variable description value)
if("${__value}" MATCHES ";")
# list values directly pass through
__flashinfer_option(${variable} "${description}" "${__value}")
message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}")
Member:
Is this just for debugging?

Author:
No worries, this will be removed; it is only here for debugging. I found that the macro does not work as expected: it should override the default values (from config.cmake or the CMake specification) with command-line arguments.

Author:
Values used in this test:

# config.cmake
# Whether to compile fp8 kernels or not.
set(FLASHINFER_ENABLE_FP8 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING OFF)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
set(FLASHINFER_DECODE ON)
# Whether to compile page kernel tests/benchmarks or not.
set(FLASHINFER_PAGE ON)
# Whether to compile cascade kernel tests/benchmarks or not.
set(FLASHINFER_CASCADE ON)
# Whether to compile sampling kernel tests/benchmarks or not.
set(FLASHINFER_SAMPLING ON)
# Whether to compile normalization kernel tests/benchmarks or not.
set(FLASHINFER_NORMALIZATION ON)
# Whether to compile fastdiv tests
set(FLASHINFER_FASTDIV_TEST ON)
# Whether to compile fastdequant tests
set(FLASHINFER_FASTDEQUANT_TEST ON)
# Whether to compile distributed tests
set(FLASHINFER_DISTRIBUTED OFF)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
set(FLASHINFER_GEN_MASK_MODES 0 1 2)

# Set target CUDA architectures for tests/benchmarks; defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
# It's new in CMake 3.24; if you are using an older version of CMake or you want to use a different value, you can
# set its value here. Supported CUDA architectures include 80;86;89;90.
# NOTE(Zihao): using "native" might be slow because whenever a CUDA file is compiled with `-arch=native`, nvcc spawns
# a `__nvcc_device_query` process to get the architecture of the host's GPU, which can stall the compilation process.
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES native)

elseif(DEFINED ${__value})
if(${__value})
__flashinfer_option(${variable} "${description}" ON)
message(STATUS "2 : creating ${variable} option, description : ${description}, value : ON")
else()
__flashinfer_option(${variable} "${description}" OFF)
message(STATUS "3 : creating ${variable} option, description : ${description}, value : OFF")
endif()
else()
__flashinfer_option(${variable} "${description}" "${__value}")
message(STATUS "4 : creating ${variable} option, description : ${description}, value : ${__value}")
endif()
else()
unset(${variable} CACHE)
8 changes: 8 additions & 0 deletions include/flashinfer/attention/cascade.cuh
@@ -16,7 +16,15 @@
#ifndef FLASHINFER_CASCADE_CUH_
#define FLASHINFER_CASCADE_CUH_

#ifdef USE_ROCM

#include <hip/hip_cooperative_groups.h>
// Portable interfaces over the CUDA API
#include "flashinfer/hip_defs.h"

#else
#include <cooperative_groups.h>
#endif // USE_ROCM

#include "../cp_async.cuh"
#include "../math.cuh"
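Note: flashinfer/hip_defs.h is referenced here but its contents are not shown in this diff. A minimal sketch of what such a portability shim typically contains; the symbols covered below are assumptions for illustration, not the actual contents of the header:

// Hypothetical sketch of a hip_defs.h-style shim: alias the cuda* runtime
// symbols to their hip* equivalents so shared code keeps the CUDA spellings.
#ifdef USE_ROCM
#include <hip/hip_runtime.h>

using cudaError_t = hipError_t;
using cudaStream_t = hipStream_t;

#define cudaSuccess hipSuccess
#define cudaGetErrorString hipGetErrorString
#define cudaMemcpyAsync hipMemcpyAsync
#define cudaStreamSynchronize hipStreamSynchronize
#else
#include <cuda_runtime.h>
#endif  // USE_ROCM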
16 changes: 15 additions & 1 deletion include/flashinfer/attention/decode.cuh
@@ -15,14 +15,28 @@
*/
#ifndef FLASHINFER_DECODE_CUH_
#define FLASHINFER_DECODE_CUH_

#ifdef USE_ROCM

#include <hip/hip_cooperative_groups.h>
#include <hip/pipeline.h>

#include "flashinfer/hip_cuda_type_utils.h"
// Portable interfaces over the CUDA API
#include "flashinfer/hip_defs.h"

#else
#include <cooperative_groups.h>
#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
// cuda::pipeline is used by the kernels in this file
#include <cuda/pipeline>
#endif // USE_ROCM

#include <cstddef>

#include <iostream>
#include <optional>
#include <random>
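Note: both cooperative-groups headers expose the same cooperative_groups namespace, which is why the kernel bodies can remain unchanged across backends. An illustrative device-side sketch of the grid-wide synchronization pattern (not code from this PR):

// Compiles as CUDA or HIP. grid.sync() is only valid when the kernel is
// launched with cudaLaunchCooperativeKernel / hipLaunchCooperativeKernel.
#ifdef USE_ROCM
#include <hip/hip_cooperative_groups.h>
#else
#include <cooperative_groups.h>
#endif

namespace cg = cooperative_groups;

__global__ void TwoPhaseKernel(const float* in, float* tmp, float* out, int n) {
  cg::grid_group grid = cg::this_grid();
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) tmp[i] = in[i] * 2.0f;             // phase 1: each thread writes its slot
  grid.sync();                                  // grid-wide barrier: all of tmp is now visible
  if (i < n) out[i] = tmp[i] + tmp[n - 1 - i];  // phase 2: read another block's result
}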
14 changes: 14 additions & 0 deletions include/flashinfer/attention/handler.cuh
@@ -16,9 +16,23 @@
#ifndef FLASHINFER_ATTENTION_HANDLER_CUH_
#define FLASHINFER_ATTENTION_HANDLER_CUH_

#ifdef USE_ROCM

#include <hip/hip_runtime_api.h>
// Portable interfaces over the CUDA API
#include "flashinfer/hip_defs.h"

#include <hip/driver_types.h>

#else

#include <cuda_runtime_api.h>

// Note: this header is part of the NVIDIA SDK
#include <driver_types.h>

#endif // USE_ROCM

#include <algorithm>
#include <cstddef>
#include <sstream>
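Note: once driver_types.h / hip/driver_types.h provide a common cudaError_t, status checks in the handler can stay backend-agnostic. A hedged sketch of such a check macro; the macro name is invented for illustration and may differ from FlashInfer's actual one:

// Hypothetical portable status-check macro; under USE_ROCM it relies on a
// shim like flashinfer/hip_defs.h aliasing cudaError_t etc. to HIP types.
#ifdef USE_ROCM
#include "flashinfer/hip_defs.h"
#else
#include <cuda_runtime_api.h>
#endif

#include <sstream>
#include <stdexcept>

#define PORTABLE_CUDA_CALL(expr)                                  \
  do {                                                            \
    cudaError_t _err = (expr);                                    \
    if (_err != cudaSuccess) {                                    \
      std::ostringstream _oss;                                    \
      _oss << #expr << " failed: " << cudaGetErrorString(_err);   \
      throw std::runtime_error(_oss.str());                       \
    }                                                             \
  } while (0)

// Usage: PORTABLE_CUDA_CALL(cudaStreamSynchronize(stream));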