Skip to content

Commit

Permalink
Add vectorized rms_norm support for Navi31
Browse files Browse the repository at this point in the history
- Supports the vectorized rms_norm_kernel on Navi31 (gfx1100).
  • Loading branch information
hyoon1 committed Nov 12, 2024
1 parent 8f3bf8b commit 867c12d
Showing 1 changed file with 5 additions and 4 deletions.
9 changes: 5 additions & 4 deletions csrc/layernorm_kernels.cu
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,9 +27,10 @@ using __nv_bfloat162 = __hip_bfloat162;
#include "quantization/fp8/nvidia/quant_utils.cuh"
#endif

-#if defined(__HIPCC__) && (defined(__gfx90a__) || defined(__gfx940__) || \
-defined(__gfx941__) || defined(__gfx942__))
-#define __HIP__MI300_MI250__
+#if defined(__HIPCC__) && \
+(defined(__gfx90a__) || defined(__gfx940__) || defined(__gfx941__) || \
+defined(__gfx942__) || defined(__gfx1100__))
+#define __HIP__MI300_MI250_Navi31__
#endif

namespace vllm {
Expand Down Expand Up @@ -72,7 +73,7 @@ struct __align__(16) vec8_t {
__device__ scalar_t sum() const { return x + y + z + w + u + v + s + t; }
};

-#ifdef __HIP__MI300_MI250__
+#ifdef __HIP__MI300_MI250_Navi31__

// TODO(woosuk): Further optimize this kernel.
template <typename scalar_t>
Expand Down

0 comments on commit 867c12d

Please sign in to comment.