[WIP][AMDGPU] try rocm POC #491
base: main
Conversation
Hi @yiakwy-xpu-ml-framework-team Nice work! Can the unit test run as expected? Recently, FlashInfer has started supporting JIT. This PR will likely need to wait for JIT support before it's reviewed.
@zhyncs Thanks for the quick review! I am tackling the nvbench suites. Tests and benchmark info will be added soon.
@@ -36,14 +36,18 @@ macro(flashinfer_option variable description value)
  if("${__value}" MATCHES ";")
    # list values directly pass through
    __flashinfer_option(${variable} "${description}" "${__value}")
    message(STATUS "1 : creating ${variable} option, description : ${description}, value : ${__value}")
Is this just for debugging?
No worries, this will be removed; it is just for debugging. I found the macro does not work as expected: it should override the default values from <config.cmake> or the CMake specification with command-line arguments.
Values used in this test:
# config.cmake
# Whether to compile fp8 kernels or not.
set(FLASHINFER_ENABLE_FP8 ON)
# Whether to compile bf16 kernels or not.
set(FLASHINFER_ENABLE_BF16 ON)
# Whether to compile tvm bindings or not.
set(FLASHINFER_TVM_BINDING OFF)
# Whether to compile prefill kernel tests/benchmarks or not.
set(FLASHINFER_PREFILL ON)
# Whether to compile decode kernel tests/benchmarks or not.
set(FLASHINFER_DECODE ON)
# Whether to compile page kernel tests/benchmarks or not.
set(FLASHINFER_PAGE ON)
# Whether to compile cascade kernel tests/benchmarks or not.
set(FLASHINFER_CASCADE ON)
# Whether to compile sampling kernel tests/benchmarks or not.
set(FLASHINFER_SAMPLING ON)
# Whether to compile normalization kernel tests/benchmarks or not.
set(FLASHINFER_NORMALIZATION ON)
# Whether to compile fastdiv tests
set(FLASHINFER_FASTDIV_TEST ON)
# Whether to compile fastdequant tests
set(FLASHINFER_FASTDEQUANT_TEST ON)
# Whether to compile distributed tests
set(FLASHINFER_DISTRIBUTED OFF)
# The following configurations can impact the binary
# size of the generated library
set(FLASHINFER_GEN_LOGITS_POST_HOOKS 0)
set(FLASHINFER_GEN_HEAD_DIMS 64 128 256)
set(FLASHINFER_GEN_KV_LAYOUTS 0 1)
set(FLASHINFER_GEN_POS_ENCODING_MODES 0 1 2)
set(FLASHINFER_GEN_ALLOW_FP16_QK_REDUCTIONS "false" "true")
set(FLASHINFER_GEN_MASK_MODES 0 1 2)
# Set target cuda architectures for tests/benchmarks, defaults to native.
# "native" is a special value for CMAKE_CUDA_ARCHITECTURES which means use the architectures of the host's GPU.
# It's new in CMake 3.24; if you are using an older version of CMake or you want to use a different value, you can
# set its value here. Supported CUDA architectures include 80;86;89;90
# NOTE(Zihao): using "native" might be slow because whenever a CUDA file is compiled with `-arch=native`, nvcc will spawn
# a `__nvcc_device_query` process to get the architecture of the host's GPU, which could stall the compilation process.
# So it's recommended to set it to a specific value if you know the architecture of the target GPU.
# Example:
# set(FLASHINFER_CUDA_ARCHITECTURES 80)
set(FLASHINFER_CUDA_ARCHITECTURES native)
@zhyncs test_norm and hip/cuda defs added. I guess more effort is needed to port the other functions; I expect another one or two days.

Summary

Sample test

Verification of the other binaries is in progress.

test_norm: I found we have re-written many warp-sync level functions. These functions should really be included in the SDK (CUDA/HIP) as C primitives that call inline asm; maintaining this code outside the SDK is quite dangerous. Currently I added AMD versions to replace the PTX versions; do we have better choices? Writing the asm requires a deep dive into the MI300 ISA.
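As a rough illustration of the SDK-intrinsic route (a minimal sketch of my own, not code from this PR; the names and platform guard are assumptions), a warp-level sum can be written with `__shfl_xor_sync` on CUDA and `__shfl_xor` on HIP instead of hand-written shfl PTX, with the loop bound switching between the 32-lane warp and the 64-lane wavefront:

```cpp
// Hedged sketch: replace a PTX-based warp reduction with SDK shuffle intrinsics.
// Not the PR's implementation; just the general shape of the substitution.
#ifdef __HIP_PLATFORM_AMD__
#include <hip/hip_runtime.h>
#else
#include <cuda_runtime.h>
#endif

__device__ __forceinline__ float warp_sum(float v) {
#ifdef __HIP_PLATFORM_AMD__
  // 64-lane wavefront on gfx942; HIP's __shfl_xor takes no explicit lane mask.
  for (int offset = 32; offset > 0; offset >>= 1) v += __shfl_xor(v, offset);
#else
  // 32-lane warp on NVIDIA; the intrinsic replaces the inline shfl.sync PTX.
  for (int offset = 16; offset > 0; offset >>= 1) v += __shfl_xor_sync(0xffffffffu, v, offset);
#endif
  return v;
}
```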
Roadmap

I disabled the FP8-related functions, since NVIDIA FP8 is very different from AMD FP8; details can be found in the OpenXLA project, where FP8 formats from different vendors (AMD, Graphcore) are incorporated. AMD FP8 has different semantics from __nv_fp8_storage_t, __nv_fp8_e5m2, and __nv_fp8_e4m3 with respect to infinities, NaN, bias, and other special values; this enables the best accuracy in low-precision computing. I will add FP8 support later.
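For readers unfamiliar with the format differences, here is a small host-side sketch (my own reading of the public OCP E4M3 and E4M3FNUZ specs, not code from either SDK or this PR) showing why the same bit pattern decodes differently: the FNUZ variant used on gfx942 has bias 8, no infinities, and a single NaN encoding at 0x80, whereas the NVIDIA-style OCP E4M3 has bias 7 and NaN at S.1111.111:

```cpp
// Hedged sketch: decode an 8-bit pattern as OCP E4M3 vs. E4M3FNUZ.
#include <cmath>
#include <cstdint>
#include <cstdio>

float decode_e4m3_ocp(uint8_t b) {   // bias 7, NaN = S.1111.111, no infinities
  int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
  if (e == 0xF && m == 0x7) return NAN;
  float sign = s ? -1.f : 1.f;
  if (e == 0) return sign * std::ldexp(m / 8.f, 1 - 7);   // subnormal
  return sign * std::ldexp(1.f + m / 8.f, e - 7);
}

float decode_e4m3_fnuz(uint8_t b) {  // bias 8, NaN only at 0x80, no infinities
  if (b == 0x80) return NAN;
  int s = b >> 7, e = (b >> 3) & 0xF, m = b & 0x7;
  float sign = s ? -1.f : 1.f;
  if (e == 0) return sign * std::ldexp(m / 8.f, 1 - 8);   // subnormal
  return sign * std::ldexp(1.f + m / 8.f, e - 8);
}

int main() {
  // Same bits, different values: 0x7E decodes to 448 as OCP E4M3 but 224 as FNUZ.
  std::printf("0x7E: ocp=%g fnuz=%g\n", decode_e4m3_ocp(0x7E), decode_e4m3_fnuz(0x7E));
}
```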
- resolve nvbench problem
- add hip cuda defs and port test_norm
- add test_norm & bench_norm
force-pushed from 3039894 to 553037f
Sorry for the late reply, I'll take a look at this after #507 gets merged. @LeshengJin @spectrometerHBH would you mind helping review this PR?
Thank you @yzh119, and I hope these notes on the ongoing and future work are useful for @LeshengJin @spectrometerHBH.

How to port the work without rewriting the fundamentals

Runnable flash attention ops exist in several open-source repos for HPC and AI. Benchmarks on H100 show FlashInfer has strong FP8 performance and almost a 10 percent improvement for the fp16/bf16 implementation. So instead of rewriting these ops inside an existing solution (which would not be an open-source collaboration), there is plenty we can learn and borrow from the FlashInfer authors. My aim is to port by reusing most of the FlashInfer code as an add-on feature (and I guess the ROCm team doesn't have enough resources to maintain it as a private repo). I illustrate this with the NV MMA instructions; a few assumptions must hold so that the work can continue.
SMEM layout when loading from global

I verified this in a very simple case for
By digging into the code and adding some test instrumentation, we found a 16x64 tile (num_frags_x == 1) produced in SMEM from global memory (the kernel launch code does not directly show how threads and SMEM are mapped).
Here is the snapshot on MI300X: it shows that running the code with a 64-thread warp does not affect the original layout designed for a 32-thread warp.

Matrix fragment loading by rewriting the thread layout

The mma code relies on the NV ldmatrix and mma instructions, which MI300X does not have. NV uses ldmatrix_m8n8x4 and fragments to prepare a 16x16 matrix fragment; I use thread-private memory (which the compiler later promotes to registers) to mimic the work:
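The snippet itself is not reproduced in this thread; as a stand-in, here is a minimal sketch of the idea (my own, assuming the "64 lanes, 4 elements per lane" fragment layout described further down; the exact lane-to-element mapping should be checked against AMD's matrix-instruction documentation):

```cpp
// Hedged sketch: each of the 64 lanes copies its slice of a 16x16 half tile
// from LDS into a thread-private array, which the compiler keeps in registers.
// Not the PR's code; an illustration of emulating ldmatrix without the instruction.
#include <hip/hip_runtime.h>
#include <hip/hip_fp16.h>

__device__ __forceinline__ void load_fragment_16x16_f16(
    const __half* smem_tile,  // 16x16 tile in LDS, row-major, row stride 16
    __half frag[4]) {         // per-lane private fragment (4 halves)
  const int lane = threadIdx.x & 63;  // lane id within the 64-wide wavefront
  const int row  = lane & 15;         // lanes 0..15 cover the 16 rows
  const int k0   = (lane >> 4) * 4;   // each group of 16 lanes covers 4 columns
#pragma unroll
  for (int j = 0; j < 4; ++j) frag[j] = smem_tile[row * 16 + k0 + j];
}
```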
Then I did the same thing for K and V, and I disabled swizzle. The last thing to do is building the matrix fragments; applying the same technique to compute_sfm, the code compiles.

What I found about the abstraction of mma::mma_sync_m16n16k16_row_col_f16f16f32

Ideally, I would just need to add a platform-controlled variant of mma::mma_sync_m16n16k16_row_col_f16f16f32. But for the moment I cannot use it directly, simply because mma needs threads to work cooperatively to load data into thread-local memory (promoted to registers by the compiler). This means mma::mma_sync_m16n16k16_row_col_f16f16f32 relies on a warp-level thread layout (32 threads on NV, where typically each thread loads 2 elements in an 8x4 layout; 64 threads on AMD, where typically each thread loads 4 elements in a 16x4 layout, column by column and row by row).

The next things

The Q and KV layouts, especially the fragment layouts, have many variants, and I am drawing a picture to fully understand them. I think the team uses uint4 loads throughout to maximize global-to-SMEM throughput and has put great effort into matrix swizzling; it produces exactly the same, efficient sequence of ids for threads loading from / storing to LDS (SMEM) memory as I saw before. But its fragment layout is very different: the lane id is independent of threadIdx.y and threadIdx.z, which are used for matrix fragment ids within a warp (not directly derived from thread ids based on the wave size). I am trying to draw a picture to fully explain this.
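For reference, a hedged sketch of what a platform-guarded counterpart of mma::mma_sync_m16n16k16_row_col_f16f16f32 could look like on CDNA hardware using the MFMA compiler builtin; the vector typedefs and arch guard are my assumptions, and, as noted above, the per-lane fragment layout is not a drop-in match for the NV version:

```cpp
// Hedged sketch: 16x16x16 f16 -> f32 matrix-multiply-accumulate via MFMA.
// Illustrative only; not this PR's implementation.
#include <hip/hip_runtime.h>

typedef _Float16 half4   __attribute__((ext_vector_type(4)));  // per-lane A/B fragment
typedef float    float4v __attribute__((ext_vector_type(4)));  // per-lane accumulator

__device__ __forceinline__ void mma_16x16x16_f16f16f32(
    float4v& acc, const half4& a_frag, const half4& b_frag) {
#if defined(__gfx90a__) || defined(__gfx942__)
  // cbsz / abid / blgp modifiers left at 0 (no broadcast).
  acc = __builtin_amdgcn_mfma_f32_16x16x16f16(a_frag, b_frag, acc, 0, 0, 0);
#else
  // Other targets would fall back to the existing NV mma path.
  (void)acc; (void)a_frag; (void)b_frag;
#endif
}
```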
…tput product registers
The fragments are now decoded in this way:
Flash attention passes thread-private memory between functions without storing to SMEM and synchronizing. This should be very fast, but
Finally, MI300 has 8 XCDs, and because of the chiplet architecture they do not share an L2 cache with each other. So we need to make sure contiguous fragments are mapped to the same CUs.
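To make the XCD point concrete, here is a hypothetical block-index swizzle (not part of this PR). It assumes workgroups are dispatched round-robin across the 8 XCDs, as AMD's public MI300 material describes; renumbering them so that consecutive logical tiles stay on one XCD keeps neighbouring fragments behind the same L2:

```cpp
// Hedged sketch: remap blockIdx.x so consecutive logical tiles run on one XCD.
// Assumes round-robin workgroup dispatch across XCDs; illustrative only.
#include <hip/hip_runtime.h>

__device__ __forceinline__ int xcd_contiguous_block_id(int num_blocks) {
  constexpr int kNumXCDs = 8;                  // MI300X
  const int bid = blockIdx.x;
  if (num_blocks % kNumXCDs != 0) return bid;  // keep identity for ragged grids
  const int blocks_per_xcd = num_blocks / kNumXCDs;
  const int xcd  = bid % kNumXCDs;             // XCD this block lands on
  const int slot = bid / kNumXCDs;             // dispatch slot within that XCD
  // Transposed numbering: ids [xcd * blocks_per_xcd, ...) all execute on one XCD.
  return xcd * blocks_per_xcd + slot;
}
```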
Description
For benchmarking purposes, I need to make sure FlashInfer compiles and works on the AMD GPU platform, hence I ported the programs to AMD GPUs.
CPU : EPYC 9534
GPU : MI30X (MI300X, MI308X)
ARCH : gfx942
Test
BUILD
Test & benchmark
**Total test list (with benchmark)**
[ ] test_sum_all_reduce
[ ] test_attn_all_reduce
[x] test_fast_dequant (suppress FP8 tests; let's do it in the next PR)
[x] test_fast_div
[x] test_norm
[x] test_sampling
[ ] test_cascade
[x] test_page
[-] test_single_prefill : TODO fix precision, investigating
[-] test_single_decode : TODO fix precision, probably wrong magic number and wave size
[ ] test_batch_prefill
[ ] test_batch_decode
Investigating ...
full logs:
full_bench_sampling_log.txt