Draft
Changes from all commits
58 commits
e54cb5a  intial commit (Chi-Chu319, Oct 6, 2025)
191f179  unified attention rename (Chi-Chu319, Oct 9, 2025)
436eb3a  transform q tensor view (juuso-oskari, Oct 10, 2025)
df60493  refactor (Chi-Chu319, Oct 10, 2025)
1f4648d  refactor. and fixed q transformation (Chi-Chu319, Oct 10, 2025)
bc6385f  Some refactor (Chi-Chu319, Oct 13, 2025)
36a65b1  refactor (Chi-Chu319, Oct 13, 2025)
2d6dab2  refactor the q tensor view transformation (juuso-oskari, Oct 13, 2025)
49ce980  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (juuso-oskari, Oct 13, 2025)
af94aaf  refactor the q tensor view transformation (juuso-oskari, Oct 13, 2025)
55fc6d7  kv tensor view (Chi-Chu319, Oct 13, 2025)
96fde33  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (Chi-Chu319, Oct 13, 2025)
16129a7  stride fix (Chi-Chu319, Oct 13, 2025)
b721f79  fix (juuso-oskari, Oct 13, 2025)
81a02ff  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (juuso-oskari, Oct 13, 2025)
6ba25b7  add commenting (juuso-oskari, Oct 13, 2025)
be58d51  o ptr and window (Chi-Chu319, Oct 13, 2025)
cd35428  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (Chi-Chu319, Oct 13, 2025)
6a7fa95  kv tensor view and initial window (Chi-Chu319, Oct 13, 2025)
b37c356  fix q window origin (juuso-oskari, Oct 14, 2025)
c3d27ab  fix q window (juuso-oskari, Oct 14, 2025)
e1120ff  pipeline api (Chi-Chu319, Oct 14, 2025)
96b208f  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (Chi-Chu319, Oct 14, 2025)
c87f2e3  o window change (Chi-Chu319, Oct 14, 2025)
ec29289  kv paging (Chi-Chu319, Oct 14, 2025)
b940a75  Comments (Chi-Chu319, Oct 14, 2025)
4d232d5  fix seq_len -> cur_batch_query_len (juuso-oskari, Oct 14, 2025)
72fe8b3  merge (juuso-oskari, Oct 14, 2025)
853fa21  Example boostrap (Chi-Chu319, Oct 15, 2025)
63c17b7  correct masking by transforming y_idx = y_idx / num_queries_per_kv (juuso-oskari, Oct 16, 2025)
498a97a  merge (juuso-oskari, Oct 16, 2025)
6293257  use correct mask in kernel (juuso-oskari, Oct 16, 2025)
aa4908a  fix mask (juuso-oskari, Oct 16, 2025)
072de38  comment (juuso-oskari, Oct 16, 2025)
9940bd0  fix order in mask caller (juuso-oskari, Oct 16, 2025)
af9167a  example (Chi-Chu319, Oct 17, 2025)
995c670  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (Chi-Chu319, Oct 17, 2025)
f4e8f79  fixing args (Chi-Chu319, Oct 17, 2025)
3f963d4  modified cmake files at unified attention example. Now cmake works, b… (juuso-oskari, Oct 20, 2025)
9fda954  Compiling fix (Chi-Chu319, Oct 20, 2025)
97e7527  fixing compile errors... (juuso-oskari, Oct 20, 2025)
d68a541  fixing compile errors... (juuso-oskari, Oct 20, 2025)
f72b994  More compilation fixes (Chi-Chu319, Oct 20, 2025)
e144872  change to BLOCK_M in shape definitions (juuso-oskari, Oct 23, 2025)
3c0e6d3  fixing bugs (juuso-oskari, Oct 23, 2025)
0d2a9ba  fixed example (Chi-Chu319, Oct 23, 2025)
3bcef59  block table stride fix (Chi-Chu319, Oct 23, 2025)
5bf72d2  fixing bugs (juuso-oskari, Oct 23, 2025)
e03ed35  fix the vector max (Chi-Chu319, Oct 23, 2025)
3fe5d79  Merge branch 'tianxing/unified-attention' of https://github.com/ROCm/… (Chi-Chu319, Oct 23, 2025)
6ea56be  removed redundent code (Chi-Chu319, Oct 23, 2025)
3bb29bf  Fixed pipeline args (Chi-Chu319, Oct 23, 2025)
ebf1c4c  const blockq (Chi-Chu319, Oct 23, 2025)
d18f8e4  Fixed block Q with M (Chi-Chu319, Oct 23, 2025)
89cfdb3  Fixed block Q with M (Chi-Chu319, Oct 23, 2025)
22c5c20  Debugging window size (Chi-Chu319, Oct 24, 2025)
d5c8315  fixed window creation number<>{} (Chi-Chu319, Oct 28, 2025)
98f15ee  Added block q (Chi-Chu319, Nov 5, 2025)
example/ck_tile/01_fmha/fmha_fwd_v3_impl.hpp (5 changes: 2 additions & 3 deletions)
@@ -57,9 +57,8 @@ struct fmha_fwd_v3_kernel_traits
 {
     static constexpr auto date_type = DataType;
     static constexpr bool is_variable_seqlen = IsVariableSeqlen;
-    static constexpr bool is_masking = IsMasking;
-
-    // M0 N0 K0 N1 K1
+    static constexpr bool is_masking = IsMasking;
+    // M0 N0 K0 N1 K1
     using fmha_block_tile = sequence<256, 32, 128, 128, 32, 128>;
     using fmha_warp_gemm_shape = sequence<32, 32, 16>;
     using fmha_block_warps = sequence<8, 1, 1>;
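For readers skimming the hunk, this is how the touched region of fmha_fwd_v3_kernel_traits reads after the change: the blank line is dropped and is_masking sits directly above the block-tile comment. A minimal self-contained sketch only; the _sketch suffix, the bool-only template parameters, and the stand-in sequence are assumptions made so the snippet compiles on its own (the real traits struct, date_type member, and ck_tile sequence live in the surrounding headers).

#include <cstddef>

// Stand-in for ck_tile's compile-time integer list.
template <std::size_t... Is>
struct sequence {};

// Hypothetical reduction of fmha_fwd_v3_kernel_traits after this hunk.
template <bool IsVariableSeqlen, bool IsMasking>
struct fmha_fwd_v3_kernel_traits_sketch
{
    static constexpr bool is_variable_seqlen = IsVariableSeqlen;
    static constexpr bool is_masking = IsMasking;
    // M0 N0 K0 N1 K1
    using fmha_block_tile      = sequence<256, 32, 128, 128, 32, 128>;
    using fmha_warp_gemm_shape = sequence<32, 32, 16>;
    using fmha_block_warps     = sequence<8, 1, 1>;
};

// Example instantiation: variable sequence lengths with masking enabled.
using traits = fmha_fwd_v3_kernel_traits_sketch<true, true>;
static_assert(traits::is_masking, "masking trait propagates");

int main() { return 0; }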
example/ck_tile/01_unified_attention/CMakeLists.txt (222 changes: 222 additions & 0 deletions)
@@ -0,0 +1,222 @@
# Commented out: FMHA fwd/bwd instance generation and codegen commands not used by unified_attention
#
# set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# # Currently only gfx9 archs are supported by FMHA
# list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
# if(NOT INST_TARGETS)
# message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
# return()
# endif()
#
# # validate user-specified fmha_fwd API list
# set(FMHA_FWD_KNOWN_APIS "fwd;fwd_splitkv;fwd_appendkv;pagedkv_prefill")
# set(FMHA_FWD_ENABLE_APIS "fwd" CACHE STRING
# "semicolon-separated list of APIs to generate (${FMHA_FWD_KNOWN_APIS}) & link, or \"all\".")
# if(BUILD_TESTING)
# # Build instances of all APIs for tests
# set(FMHA_FWD_ENABLE_APIS "all")
# endif()
# if(FMHA_FWD_ENABLE_APIS STREQUAL "all")
# set(FMHA_FWD_ENABLE_APIS ${FMHA_FWD_KNOWN_APIS})
# endif()
#
# foreach(api ${FMHA_FWD_ENABLE_APIS})
# if(NOT "${api}" IN_LIST FMHA_FWD_KNOWN_APIS)
# message(FATAL_ERROR "${api} isn't a known api: ${FMHA_FWD_KNOWN_APIS}.")
# endif()
# endforeach()
#
# # "fwd" is a must-have api for the fmha_fwd example, add it if not specified
# if(NOT "fwd" IN_LIST FMHA_FWD_ENABLE_APIS)
# list(PREPEND FMHA_FWD_ENABLE_APIS "fwd")
# endif()
#
# file(GLOB_RECURSE CODE_GEN_SCRIPTS CONFIGURE_DEPENDS
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
# ${CMAKE_CURRENT_LIST_DIR}/codegen/*.py
# )
# set_directory_properties(PROPERTIES CMAKE_CONFIGURE_DEPENDS "${CODE_GEN_SCRIPTS}")
#
# string(REPLACE ";" "," FMHA_FWD_APIS "${FMHA_FWD_ENABLE_APIS}")
# set(FMHA_FWD_CODE_GEN_COMMON_ARGS
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
# --api ${FMHA_FWD_APIS}
# --optdim 32,64,128,256
# )
# set(FMHA_BWD_CODE_GEN_COMMON_ARGS
# ${CMAKE_CURRENT_LIST_DIR}/generate.py
# --api bwd
# --receipt 3
# --optdim 32,64,96,128,256
# )
#
# if(BUILD_TESTING)
# list(APPEND FMHA_FWD_CODE_GEN_COMMON_ARGS --filter *_nlogits*_nskip*,*@*_nlogits*_nbias*,*,*_nlogits*_nskip*_pagedkv)
# endif()
#
# execute_process(
# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt
# RESULT_VARIABLE ret
# )
# if(ret AND NOT ret EQUAL 0)
# message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of FWD kernels via Python.")
# endif()
#
# execute_process(
# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
# --list_blobs ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt
# RESULT_VARIABLE ret
# )
# if(ret AND NOT ret EQUAL 0)
# message(FATAL_ERROR "CK Tile FMHA FAILED to generate a list of BWD kernels via Python.")
# endif()
#
# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/fwd_blob_list.txt FMHA_FWD_GEN_BLOBS)
# file(STRINGS ${CMAKE_CURRENT_BINARY_DIR}/bwd_blob_list.txt FMHA_BWD_GEN_BLOBS)
#
# add_custom_command(
# OUTPUT ${FMHA_FWD_GEN_BLOBS}
# COMMAND ${Python3_EXECUTABLE} ${FMHA_FWD_CODE_GEN_COMMON_ARGS}
# --output_dir ${CMAKE_CURRENT_BINARY_DIR}
# DEPENDS ${CODE_GEN_SCRIPTS}
# )
#
# add_custom_command(
# OUTPUT ${FMHA_BWD_GEN_BLOBS}
# COMMAND ${Python3_EXECUTABLE} ${FMHA_BWD_CODE_GEN_COMMON_ARGS}
# --output_dir ${CMAKE_CURRENT_BINARY_DIR}
# DEPENDS ${CODE_GEN_SCRIPTS}
# )
#
# set(FMHA_FWD_INSTANCES "tile_fmha_fwd_instances")
# set(FMHA_BWD_INSTANCES "tile_fmha_bwd_instances")
#
# message(DEBUG "adding instances ${FMHA_FWD_INSTANCES}")
# add_library(${FMHA_FWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
# target_include_directories(${FMHA_FWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# target_sources(${FMHA_FWD_INSTANCES} PRIVATE ${FMHA_FWD_GEN_BLOBS})
# set_source_files_properties(${FMHA_FWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
# set_property(TARGET ${FMHA_FWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
#
# message(DEBUG "adding instances ${FMHA_BWD_INSTANCES}")
# add_library(${FMHA_BWD_INSTANCES} OBJECT EXCLUDE_FROM_ALL)
# target_include_directories(${FMHA_BWD_INSTANCES} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
# target_sources(${FMHA_BWD_INSTANCES} PRIVATE ${FMHA_BWD_GEN_BLOBS})
# set_source_files_properties(${FMHA_BWD_GEN_BLOBS} PROPERTIES LANGUAGE HIP)
# set_property(TARGET ${FMHA_BWD_INSTANCES} PROPERTY HIP_ARCHITECTURES ${INST_TARGETS})
#
# set(FMHA_FWD_PRIVATE_COMPILE_OPTIONS)
# set(FMHA_BWD_PRIVATE_COMPILE_OPTIONS)
# set(FMHA_FWD_INTERFACE_COMPILE_OPTIONS)
# set(FMHA_BWD_INTERFACE_COMPILE_OPTIONS)
#
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-undefined-func-template)
#
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -Wno-float-equal)
#
# if(NOT DEFINED FMHA_FWD_FAST_EXP2)
# set(FMHA_FWD_FAST_EXP2 ON)
# endif()
#
# if(FMHA_FWD_FAST_EXP2)
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=1 -fgpu-flush-denormals-to-zero)
# else()
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_FAST_EXP2=0)
# endif()
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -fgpu-flush-denormals-to-zero)
#
# if("fwd_splitkv" IN_LIST FMHA_FWD_ENABLE_APIS)
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=1)
# else()
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_SPLITKV_API=0)
# endif()
#
# if("fwd_appendkv" IN_LIST FMHA_FWD_ENABLE_APIS)
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=1)
# else()
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_APPENDKV_API=0)
# endif()
#
# if("pagedkv_prefill" IN_LIST FMHA_FWD_ENABLE_APIS)
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=1)
# else()
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FMHA_FWD_PAGEDKV_API=0)
# endif()
#
# if(CK_USE_OCP_FP8)
# list(APPEND FMHA_FWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
# list(APPEND FMHA_FWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_USE_OCP_FP8)
# endif()
#
# list(APPEND FMHA_BWD_PRIVATE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
# list(APPEND FMHA_BWD_INTERFACE_COMPILE_OPTIONS -DCK_TILE_FLOAT_TO_BFLOAT16_DEFAULT=3)
#
# target_compile_options(${FMHA_FWD_INSTANCES}
# PRIVATE ${FMHA_FWD_PRIVATE_COMPILE_OPTIONS}
# INTERFACE ${FMHA_FWD_INTERFACE_COMPILE_OPTIONS})
# target_compile_options(${FMHA_BWD_INSTANCES}
# PRIVATE ${FMHA_BWD_PRIVATE_COMPILE_OPTIONS}
# INTERFACE ${FMHA_BWD_INTERFACE_COMPILE_OPTIONS})
#
# set(EXAMPLE_FMHA_FWD "tile_example_fmha_fwd")
# set(EXAMPLE_FMHA_BWD "tile_example_fmha_bwd")
#
# message(DEBUG "adding example ${EXAMPLE_FMHA_FWD}")
# add_executable(${EXAMPLE_FMHA_FWD} EXCLUDE_FROM_ALL example_fmha_fwd.cpp)
# target_link_libraries(${EXAMPLE_FMHA_FWD} ${FMHA_FWD_INSTANCES})
# target_include_directories(${EXAMPLE_FMHA_FWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
#
# message(DEBUG "adding example ${EXAMPLE_FMHA_BWD}")
# add_executable(${EXAMPLE_FMHA_BWD} EXCLUDE_FROM_ALL example_fmha_bwd.cpp)
# target_link_libraries(${EXAMPLE_FMHA_BWD} ${FMHA_BWD_INSTANCES})
# target_include_directories(${EXAMPLE_FMHA_BWD} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
#
# set_property(GLOBAL PROPERTY RULE_MESSAGES OFF)

# --- Unified Attention target (kept) ---

#
set(INST_TARGETS ${SUPPORTED_GPU_TARGETS})
# Currently only gfx9 archs are supported by FMHA
list(FILTER INST_TARGETS INCLUDE REGEX "gfx9")
if(NOT INST_TARGETS)
message(WARNING "Skipping Tile Engine FMHA compilation: No supported GPU targets (gfx9) found in SUPPORTED_GPU_TARGETS: ${SUPPORTED_GPU_TARGETS}")
return()
endif()

set(EXAMPLE_FMHA_FWD_V3 "tile_example_unified_attention")
message(DEBUG "adding example ${EXAMPLE_FMHA_FWD_V3}")

add_executable(${EXAMPLE_FMHA_FWD_V3} EXCLUDE_FROM_ALL example_unified_attention.cpp)
target_include_directories(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${CMAKE_CURRENT_LIST_DIR})
file(GLOB FMHA_FWD_V3_INSTANCES CONFIGURE_DEPENDS
"${CMAKE_CURRENT_LIST_DIR}/instances/*.cpp"
)
target_sources(${EXAMPLE_FMHA_FWD_V3} PRIVATE
unified_attention.cpp
${FMHA_FWD_V3_INSTANCES}
)

set(EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-fgpu-flush-denormals-to-zero
-Wno-undefined-func-template
--save-temps
)
set(EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS)

check_cxx_compiler_flag("-mllvm --amdgpu-disable-packed-fp32=1" HAS_DISABLE_PACKED_FP32)
if(HAS_DISABLE_PACKED_FP32)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS
-mllvm --amdgpu-disable-packed-fp32=1
)
list(APPEND EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS
-DCK_TILE_DISABLE_PACKED_FP32=1
)
endif()
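The -DCK_TILE_DISABLE_PACKED_FP32=1 definition mirrors the compiler-side -mllvm --amdgpu-disable-packed-fp32=1 switch for source code. A minimal sketch of how such a guard could be consumed, assuming nothing beyond the macro itself; the use_packed_fp32 constant and the fallback choice are illustrative, not the actual ck_tile implementation.

#include <cstdio>

// CK_TILE_DISABLE_PACKED_FP32 comes from this CMakeLists when the
// --amdgpu-disable-packed-fp32 flag is supported by the compiler.
#if defined(CK_TILE_DISABLE_PACKED_FP32) && CK_TILE_DISABLE_PACKED_FP32
constexpr bool use_packed_fp32 = false; // prefer scalar fp32 math paths
#else
constexpr bool use_packed_fp32 = true;  // packed fp32 instructions allowed
#endif

int main()
{
    std::printf("packed fp32 enabled: %s\n", use_packed_fp32 ? "yes" : "no");
    return 0;
}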

target_compile_options(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_OPTIONS})
target_compile_definitions(${EXAMPLE_FMHA_FWD_V3} PRIVATE ${EXAMPLE_FMHA_FWD_V3_COMPILE_DEFINITIONS})
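Since the executable is declared EXCLUDE_FROM_ALL, a default build skips it; it has to be requested explicitly, e.g. by building the tile_example_unified_attention target from the usual composable_kernel CMake tree.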