
[WIP][AMDGPU] try rocm POC #491

Draft · wants to merge 15 commits into main from porting_kernels_amd_ck_hip_version
Conversation

@yiakwy-xpu-ml-framework-team (Author) commented Sep 4, 2024

Description

For benchmarking purposes, I need to make sure FlashInfer compiles and works on the AMD GPU platform, hence I ported the programs to AMD GPUs.

CPU : EPYC 9534
GPU : MI30X (MI300X, MI308X)
ARCH : gfx942

Test

BUILD

cd build && cmake .. && make [TARGET]

Test & benchmark

** Total test list (with benchmarks) **
[] test_sum_all_reduce
[] test_attn_all_reduce
[x] test_fast_dequant (FP8 tests suppressed; let's do it in the next PR)
[x] test_fast_div
[x] test_norm
[x] test_sampling
[] test_cascade
[x] test_page
[-] test_single_prefill : TODO fix precision, investigating
[-] test_single_decode : TODO fix precision, probably wrong magic number and wave size
[] test_batch_prefill
[] test_batch_decode

  • test_norm (screenshot, 2024-09-09 13:31)
  • bench_norm (screenshot, 2024-09-09 13:29)
  • test_fast_div (screenshot, 2024-09-23 15:19)
  • test_single_decode (screenshot, 2024-09-22 20:00)

Investigating ...

  • test_sampling (screenshot, 2024-09-23 18:51)
  • bench_sampling (screenshot, 2024-09-23 20:11)

full logs:
full_bench_sampling_log.txt

  • test_page (screenshot, 2024-09-23 21:08)

@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the title add rocm support [AMDGPU] add rocm support Sep 4, 2024
@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the title [AMDGPU] add rocm support [WIP][AMDGPU] add rocm support Sep 4, 2024
@zhyncs zhyncs requested a review from yzh119 September 4, 2024 08:18
@zhyncs (Member) commented Sep 4, 2024

hence I ported the programs into AMD GPUs

Hi @yiakwy-xpu-ml-framework-team Nice work! Can the unit test run as expected? Recently, FlashInfer has started supporting JIT. This PR will likely need to wait for JIT support before it's reviewed.

@yiakwy-xpu-ml-framework-team (Author):
ref https://github.com/flashinfer-ai/flashinfer/tree/main/python/tests

@zhyncs Thanks for the quick review! I am tackling the nvbench suites; tests and benchmark info will be added soon.

@@ -36,14 +36,18 @@ macro(flashinfer_option variable description value)
if("${__value}" MATCHES ";")
# list values directly pass through
__flashinfer_option(${variable} "${description}" "${__value}")
message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}")
Review comment (Member):
Is this just for debugging?

Reply (Author):
No worries, this will be removed. It is just for debugging, since I found the function does not work as expected: it should override default values from either <config.cmake> or the CMake specification with command-line arguments.

Follow-up (Author):

Values used in this test:

# config.cmake
# Whether to compile fp8 kernels or not.
set(FLASHINFER_ENABLE_FP8 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING OFF)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
set(FLASHINFER_DECODE ON)
# Whether to compile page kernel tests/benchmarks or not.
set(FLASHINFER_PAGE ON)
# Whether to compile cascade kernel tests/benchmarks or not.
set(FLASHINFER_CASCADE ON)
# Whether to compile sampling kernel tests/benchmarks or not.
set(FLASHINFER_SAMPLING ON)
# Whether to compile normalization kernel tests/benchmarks or not.
set(FLASHINFER_NORMALIZATION ON)
# Whether to compile fastdiv tests
set(FLASHINFER_FASTDIV_TEST ON)
# Whether to compile fastdequant tests
set(FLASHINFER_FASTDEQUANT_TEST ON)
# Whether to compile distributed tests
set(FLASHINFER_DISTRIBUTED OFF)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
set(FLASHINFER_GEN_MASK_MODES 0 1 2)

# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
# it's new in CMake 3.24; if you are using an older version of CMake or you want to use a different value, you can
# set its value here. Supported CUDA architectures include 80;86;89;90
# NOTE(Zihao): using "native" might be slow because whenever compiling a CUDA file with `-arch=native`, nvcc will spawn
# a `__nvcc_device_query` process to get the architecture of the host's GPU, which could stall the compilation process.
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES native)

@yiakwy-xpu-ml-framework-team (Author) commented Sep 6, 2024

@zhyncs test_norm and the hip_cuda defs have been added. I guess more effort is needed to port the other functions; I expect one or two more days.

Summary

Sample test

Verification of the other binaries is a work in progress.

test_norm:

(screenshot, 2024-09-06 12:33)

I found we have re-written many warp-sync level functions; I think these functions should already be included in the SDK (CUDA/HIP) as C primitives that call inline asm. Maintaining this code outside of the SDK is risky.

Currently I added AMD versions to replace the PTX versions; do we have better choices? Writing asm requires a deep dive into the MI300 ISA:

namespace amdgpu {

// ROCM exp c primitive, which computes 2^x in fp8/fp16/bf16/fp32
template<typename T>
__forceinline__ __device__ T exp2(T);

template<typename T>
__forceinline__ __device__ T log2(T);

template<typename T>
__forceinline__ __device__ T rcp(T);

template<typename T>
__forceinline__ __device__ T shfl_xor_sync(T, int);

template<typename T>
__forceinline__ __device__ T rsqrt(T);

// specializations

// TODO (yiakwy) : add equivalent asm version for fast exp computation (polynomial approx)
template<>
inline __device__ float exp2(float x) {
  return exp2f(x);
}

template<>
inline __device__ half exp2(half x) {
  return hexp2(x);
}

template<>
__forceinline__ __device__ float log2(float x) {
  return log2f(x);
}

template<>
inline __device__ half log2(half x) {
  return hlog2(x);
}

template<>
__forceinline__ __device__ float rcp(float x) {
  // TODO (yiakwy) : __frcp_rn is not supported in ROCM 6.2
  return 1.f / x;
}

// TODO (yiakwy) : verify; see details from here : https://rocm.docs.amd.com/projects/HIP/en/develop/reference/kernel_language.html
template<>
__forceinline__ __device__ float shfl_xor_sync(float x, int lane_mask) {
  // note AMD uses an 8-byte mask (i.e. a 64-bit integer) to allow all 64 threads to participate
  // TODO (yiakwy) : SDK compatibility checking ...
  return __shfl_xor_sync(0xffffffffffffffff, x, lane_mask);
}

template<>
__forceinline__ __device__ half2 shfl_xor_sync(half2 x, int lane_mask) {
  // note AMD uses an 8-byte mask (i.e. a 64-bit integer)
  return __shfl_xor_sync(0xffffffffffffffff, x, lane_mask);
}

template<>
__forceinline__ __device__ float rsqrt(float x) {
  return rsqrtf(x);
}

} // amdgpu

/*!
 * \brief Wrapper of PTX ex2.approx instruction, which computes 2^x
 * \param x input
 */
__forceinline__ __device__ float ptx_exp2(float x) {
  return amdgpu::exp2(x);
}
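
For illustration, here is a minimal sketch (not part of this PR) of how a call site can keep the CUDA-style code path while the amdgpu:: wrappers above route to HIP intrinsics. It assumes the 64-lane wavefront of gfx942 and the float specialization of amdgpu::shfl_xor_sync shown above; warp_allreduce_sum is a hypothetical helper name.

__forceinline__ __device__ float warp_allreduce_sum(float v) {
  // butterfly reduction over the 64-lane wavefront; after the loop every
  // lane holds the sum of all 64 inputs
#pragma unroll
  for (int lane_mask = 32; lane_mask > 0; lane_mask >>= 1) {
    v += amdgpu::shfl_xor_sync(v, lane_mask);
  }
  return v;
}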

Roadmap

I disabled the FP8-related functions, since NVIDIA FP8 is very different from AMD FP8; details can be found in the OpenXLA project, where FP8 formats from different vendors (AMD, Graphcore) are incorporated.

AMD FP8 has a different signature from __nv_fp8_storage_t, __nv_fp8_e5m2, __nv_fp8_e4m3 in its treatment of infinity, NaN, bias, and other special values.

This enables the best accuracy in low-precision computing.
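
To make the difference concrete, here is a host-side sketch of my own (not from this PR) that decodes the same 8-bit pattern under the OCP E4M3 interpretation used by __nv_fp8_e4m3 and under the FNUZ E4M3 variant used on MI300-class hardware. The assumptions baked in: FNUZ uses an exponent bias of 8 (vs 7), has no infinities, and reserves only the 0x80 pattern as NaN.

#include <cmath>
#include <cstdint>

float decode_e4m3(uint8_t v, bool fnuz) {
  const int sign = v >> 7;
  const int exp  = (v >> 3) & 0xF;
  const int man  = v & 0x7;
  const int bias = fnuz ? 8 : 7;                       // FNUZ shifts the bias by one

  if (fnuz && v == 0x80) return NAN;                   // FNUZ: the only NaN, no infinities
  if (!fnuz && exp == 0xF && man == 0x7) return NAN;   // OCP E4M3: NaN, no infinity

  const float mag = (exp == 0)
      ? std::ldexp(man / 8.0f, 1 - bias)               // subnormal
      : std::ldexp(1.0f + man / 8.0f, exp - bias);     // normal
  return sign ? -mag : mag;                            // max normal: 448 (OCP) vs 240 (FNUZ)
}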

Later I will add fp8 support.

- resolve nvbench problem

- add hip cuda defs and port test_norm

- add test_norm & bench_norm
@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team force-pushed the porting_kernels_amd_ck_hip_version branch from 3039894 to 553037f Compare September 9, 2024 05:46
@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the title [WIP][AMDGPU] add rocm support [WIP][AMDGPU] try rocm support POC Sep 13, 2024
@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team marked this pull request as draft September 13, 2024 07:31
@yiakwy-xpu-ml-framework-team yiakwy-xpu-ml-framework-team changed the title [WIP][AMDGPU] try rocm support POC [WIP][AMDGPU] try rocm POC Sep 13, 2024
@yzh119 (Collaborator) commented Sep 25, 2024

Sorry for the late reply, I'll take a look at this after #507 gets merged.

@LeshengJin @spectrometerHBH would you mind helping review this PR?

@yiakwy-xpu-ml-framework-team (Author) commented Oct 10, 2024

Thank you @yzh119, and I hope these notes on the ongoing and future work are useful for @LeshengJin @spectrometerHBH.

How to port the work without rewriting the fundamentals

Runnable flash attention ops exist in several open-source repos for HPC and AI. Benchmarks on H100 show FlashInfer has strong performance in FP8 and an almost 10 percent improvement for the fp16/bf16 implementation.

So instead of rewriting some ops on top of an existing solution (which would not be an open-source collaboration), there must be goodness we can learn and borrow from the FlashInfer authors.

So I aim to port by reusing most of the FlashInfer code as an add-on feature (and I guess the ROCm team doesn't have enough resources to maintain it as a private repo).

I illustrate with the NV MMA instruction; some assumptions must hold so that the work can continue:

  • Global to LDS (SMEM) loading can be reused (verified in the following very simple cases)
  • We only focus on matrix fragment loading and computing
  • Swizzling impacts performance, but we will focus on accuracy first

SMEM layout when loading from global

I verified this in a very simple case:

// kv_layout == 0
q : (1/*HEADS*/, {2,16}/*seqlens*/, 64 /*hidden size, 128 bit for half*/)
// assume q_heads == kv_heads and replicate the data
k : (1/*HEADS*/, {2,16}/*seqlens*/, 64 /*hidden size, 128 bit for half*/)
v : (1/*HEADS*/, {2,16}/*seqlens*/, 64 /*hidden size, 128 bit for half*/)

// these inputs produce a 1x2x2 output from the compute_qk function, which can easily be verified at first glance.
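
As a sanity check for this tiny case, the expected compute_qk tile can be reproduced with a straightforward host-side reference. This is a sketch of my own (not PR code), assuming row-major q/k buffers and ignoring the kernel's softmax scaling:

#include <cstddef>
#include <vector>

// q, k : row-major [seq x dim] buffers; returns the [seq x seq] score tile
std::vector<float> reference_qk(const std::vector<float>& q,
                                const std::vector<float>& k,
                                std::size_t seq = 2, std::size_t dim = 64) {
  std::vector<float> s(seq * seq, 0.0f);
  for (std::size_t i = 0; i < seq; ++i)
    for (std::size_t j = 0; j < seq; ++j)
      for (std::size_t d = 0; d < dim; ++d)
        s[i * seq + j] += q[i * dim + d] * k[j * dim + d];
  return s;  // for seq = 2 this is the 2x2 tile to compare against compute_qk
}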

By digging into the code and adding some test instrumentation, we found a 16x64 tile (num_frags_x == 1) produced in SMEM from global memory (the kernel launch code does not directly show how threads and SMEM are mapped).

    // NOTE(yiakwy) : each thread of a 32 threads block, cooperatively load 128 bit (uint4/float4/halfx8) data from system memory to shared memory
    // qsmem shape = (_, 128 Byte)
    // -- frags y -> (but loaded into SMEM the next 16 rows)
    // qsmem row/col 0                       1                       ... 7               warp_idx {0..3}  
    //       0       0  1  2  3  4  5  6  7  8  9  10 11 12 13 14 15 ... 60  61  62  63  0                |
    //       1       64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 ... 124 125 126 127 0                |
    //       2       .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0               frags x
    //       3       .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0                |
    //       ...     .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0                |
    //       0+4*3   .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0                v
    //       1+4*3   .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0
    //       2+4*3   .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0
    //       3+4*3   .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  .  ... .   .   .   .   0
    //  qsmem is (num_frags_x x 16) x 64 (128 bit) matrix fragment    

here is the snapshot on MI300X:

(screenshot, 2024-10-09 19:03)

That means the code, running with 64-thread warps, does not change the original layout designed for 32-thread warps.

Matrix fragment loading by rewriting the thread layout

The mma code relies on the NV ldmatrix and mma instructions, which MI300X does not have. NV uses ldmatrix_m8n8x4 and fragments to prepare a 16x16 matrix fragment; I use thread-private memory (which the compiler later places in registers) to mimic that work:

  if (warp64_idx_z * num_frags_z * 16U < 16U/*kv_len*/ ) {

#pragma unroll
  for (uint32_t fy = 0; fy < num_frags_y; ++fy) {

    // load q
#pragma unroll
    for (uint32_t fx = 0; fx < num_frags_x; ++fx) {

      // NOTE (yiakwy) : q_smem has shape of (num_frags_x, 16, 8x8), v_mfma_m16n16k16_fp16 will be applied 4 times along feat dim
      b128_t* smem_ptr = q_smem->base + *q_smem_offset_r;
      float16_t *s = reinterpret_cast<float16_t *>(smem_ptr);

      float16x4 *a = reinterpret_cast<float16x4 *>(a_frag[fx]);

      // TODO (yiakwy) : replace with a more efficient load instruction
#pragma unroll
      for (uint32_t j=0; j < 4; j++) {
        // NOTE (yiakwy) : 16 threads load 4 columns (16x4 fp16) of data cooperatively
        uint32_t offset = lane_id_x * MTX_FRAG_LDA + j + lane_id_y * 4;

        (*a)[j] = *(s + offset);
      }

      *q_smem_offset_r =
              q_smem->template advance_offset_by_row<16, channel_size_128b_q>(*q_smem_offset_r);
    } // num_frags_x

    // NOTE(yiakwy) : next to 16 = 2x8 columns
    *q_smem_offset_r = q_smem->template advance_offset_by_column<2>(*q_smem_offset_r, fy) -
                       num_frags_x * 16 * channel_size_128b_q;

    // load k
#pragma unroll
    for (uint32_t fz = 0; fz < num_frags_z; ++fz) {

      if constexpr (sizeof(DTypeKV) == 1) {
        assert(0 && "KV Cache with FP8 data type is not supported in ROCM");
      }

      b128_t* smem_ptr = k_smem->base + *k_smem_offset_r;
      float16_t *s = reinterpret_cast<float16_t *>(smem_ptr);

      float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);

      // TODO (yiakwy) : replace with a more efficient load instruction
#pragma unroll
      for (uint32_t j=0; j < 4; j++) {
        // NOTE (yiakwy) : load 16 consecutive elements of one row
        uint32_t offset = lane_id_x + (lane_id_y * 4 + j) * MTX_FRAG_LDB;

        (*b)[j] = *(s+offset);
      }

      // NOTE(yiakwy) : k is still in row-major layout
      *k_smem_offset_r =
          k_smem->template advance_offset_by_row<16, channel_size_128b_kv>(*k_smem_offset_r);

      // compute
      for (uint32_t fx = 0; fx < num_frags_x; ++fx) {
        float16x4 *a = reinterpret_cast<float16x4 *>(a_frag[fx]);
        float16x4 *b = reinterpret_cast<float16x4 *>(b_frag);

        if constexpr (std::is_same<DTypeQKAccum, float>::value) {
          floatx4 *d = reinterpret_cast<floatx4 *>(s_frag[fx][fz]);
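          // NOTE : v_mfma_f32_16x16x16f16 multiplies the 16x16 fp16 A/B fragments held
          // across the 64-lane wavefront and accumulates into a 16x16 fp32 tile,
          // 4 accumulator values per lane (64 lanes x 4 = 256 elements)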
          *d = __builtin_amdgcn_mfma_f32_16x16x16f16(*a, *b, *d, 0, 0, 0);

          // __asm__ volatile("s_barrier" ::);
          __builtin_amdgcn_s_waitcnt(0);
          __builtin_amdgcn_s_barrier();
        } else {
          // TODO (yiakwy) : device cast fp32 to fp16
          assert(0 && "AMD v_mfma instruction does not support fp16 output.");
        }
      }
    }
    if constexpr (sizeof(DTypeKV) == 1) {
      assert(0 && "FP8 KV Cache will be suppported soon.");
    } else {
      *k_smem_offset_r = k_smem->template advance_offset_by_column<2>(*k_smem_offset_r, fy) -
                         num_frags_z * 16 * channel_size_128b_kv;
    }
  }

  } // if warp64_idx_z * num_frags_z * 16 < kv_len

  // NOTE(yiakwy) : some threads are not in use, so we must synchronize the whole thread block before proceeding
  __syncthreads();

Then I did the same thing for k and v. I disabled swizzling. The last thing to do was making the matrix fragments.

After applying the same technique to compute_sfm, the code compiles.

What I found about the abstraction of mma::mma_sync_m16n16k16_row_col_f16f16f32

Ideally, I would just need to add mma::mma_sync_m16n16k16_row_col_f16f16f32 behind a platform switch. But for the moment, I cannot use it directly.

This is simply because mma needs threads to work cooperatively to load data into thread-local memory (placed in registers by the compiler). This means mma::mma_sync_m16n16k16_row_col_f16f16f32 relies on the warp's thread layout (32 threads on NVIDIA, where typically each thread loads 2 elements in an 8x4 layout; 64 threads on AMD, where typically each thread loads 4 elements in a 16x4 layout, column by column and row by row).
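
To make the 16x4 layout concrete, here is a distilled, standalone sketch of the per-lane gather used in the loops above. The names and the lane decomposition are my assumptions, not the PR's exact code: lane = threadIdx.x & 63 (assuming the wavefront maps to consecutive threadIdx.x), lane_id_x = lane % 16 selects the row, lane_id_y = lane / 16 selects the group of four columns, and the tile is a row-major 16x16 fp16 fragment with leading dimension lda.

#include <hip/hip_runtime.h>
#include <cstdint>

typedef _Float16 float16_t;
typedef float16_t float16x4 __attribute__((ext_vector_type(4)));

__device__ void load_a_fragment_16x16(const float16_t* tile, uint32_t lda,
                                      float16x4& a_frag) {
  const uint32_t lane      = threadIdx.x & 63;  // lane within the 64-wide wavefront
  const uint32_t lane_id_x = lane % 16;         // which of the 16 rows this lane reads
  const uint32_t lane_id_y = lane / 16;         // which group of 4 columns it reads
#pragma unroll
  for (uint32_t j = 0; j < 4; ++j) {
    a_frag[j] = tile[lane_id_x * lda + lane_id_y * 4 + j];  // 4 fp16 values per lane
  }
}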

The next steps

The Q and KV layouts, especially the fragment layouts, have many variants. I am drawing a picture to fully understand this.

I think the team makes full use of uint4 to load data, to maximize global-to-SMEM throughput, and has put great effort into matrix swizzling; it produces exactly the same efficient sequence of ids for threads loading from / storing to LDS (SMEM) memory as I saw before.

And its fragment layout is very different. The lane id is independent from threadIdx.y and threadIdx.z, which are used for matrix fragment ids within a warp (not directly derived from thread ids based on the wavefront size). I am trying to draw a picture to fully explain this.

@yiakwy-xpu-ml-framework-team (Author) commented Oct 25, 2024

The fragments are now decoded in this way:

  • thread-related offsets are handled in each computing block
  • warps are halved (warp64_id_x, warp64_id_z), and threads are remapped to the 16x4 layout

Flash attention passes thread-private memory between different functions without storing to SMEM and synchronizing, which should be very fast, but:

  • the AMD v_mfma instruction computes a (column major) x b (row major) and accumulates into the output (row major); that means when we compute s = qk and o = sv, we need to remap s from column major (4 registers) to row major.

Finally, MI300 has 8 XCDs, and they share no L2 cache with each other due to the chiplet architecture. So we need to make sure contiguous fragments are mapped to the same CUs.
