File tree Expand file tree Collapse file tree 7 files changed +38
-15
lines changed
backends/vulkan/runtime/graph/ops Expand file tree Collapse file tree 7 files changed +38
-15
lines changed Original file line number Diff line number Diff line change @@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop:
1212 TILE_K4 : 1
1313 TILE_N4 : 1
1414 generate_variant_forall :
15+ combination :
16+ parameter_names : [IO_STORAGE, K_CACHE_STORAGE]
17+ combos :
18+ - parameter_values : [texture3d, texture3d]
19+ - parameter_values : [buffer, texture3d]
20+ - parameter_values : [buffer, buffer]
1521 DTYPE :
1622 - VALUE : float
1723 - VALUE : half
1824 shader_variants :
19- - NAME : sdpa_compute_attn_weights_coop_texture3d_texture3d
20- - NAME : sdpa_compute_attn_weights_coop_buffer_texture3d
21- IO_STORAGE : buffer
25+ - NAME : sdpa_compute_attn_weights_coop
Original file line number Diff line number Diff line change @@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled:
1313 TILE_K4 : 1
1414 TILE_N4 : 1
1515 generate_variant_forall :
16+ combination :
17+ parameter_names : [IO_STORAGE, K_CACHE_STORAGE]
18+ combos :
19+ - parameter_values : [texture3d, texture3d]
20+ - parameter_values : [buffer, texture3d]
21+ - parameter_values : [buffer, buffer]
1622 DTYPE :
1723 - VALUE : float
1824 - VALUE : half
1925 shader_variants :
20- - NAME : sdpa_compute_attn_weights_tiled_texture3d_texture3d
21- - NAME : sdpa_compute_attn_weights_tiled_buffer_texture3d
22- IO_STORAGE : buffer
26+ - NAME : sdpa_compute_attn_weights_tiled
Original file line number Diff line number Diff line change @@ -12,10 +12,14 @@ sdpa_compute_out_coop:
1212 TILE_K4 : 1
1313 TILE_N4 : 1
1414 generate_variant_forall :
15+ combination :
16+ parameter_names : [IO_STORAGE, V_CACHE_STORAGE]
17+ combos :
18+ - parameter_values : [texture3d, texture3d]
19+ - parameter_values : [buffer, texture3d]
20+ - parameter_values : [buffer, buffer]
1521 DTYPE :
1622 - VALUE : float
1723 - VALUE : half
1824 shader_variants :
19- - NAME : sdpa_compute_out_coop_texture3d_texture3d
20- - NAME : sdpa_compute_out_coop_buffer_texture3d
21- IO_STORAGE : buffer
25+ - NAME : sdpa_compute_out_coop
Original file line number Diff line number Diff line change @@ -13,10 +13,14 @@ sdpa_compute_out_tiled:
1313 TILE_K4 : 1
1414 TILE_N4 : 1
1515 generate_variant_forall :
16+ combination :
17+ parameter_names : [IO_STORAGE, V_CACHE_STORAGE]
18+ combos :
19+ - parameter_values : [texture3d, texture3d]
20+ - parameter_values : [buffer, texture3d]
21+ - parameter_values : [buffer, buffer]
1622 DTYPE :
1723 - VALUE : float
1824 - VALUE : half
1925 shader_variants :
20- - NAME : sdpa_compute_out_tiled_texture3d_texture3d
21- - NAME : sdpa_compute_out_tiled_buffer_texture3d
22- IO_STORAGE : buffer
26+ - NAME : sdpa_compute_out_tiled
Original file line number Diff line number Diff line change 55#define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
66#define T ${buffer_scalar_type(DTYPE)}
77
8+ $if OUTPUT_STORAGE == "buffer":
9+ #define OUTPUT_BUFFER
810 $if INPUT_STORAGE == "buffer":
911 #define INPUT_BUFFER
1012
Original file line number Diff line number Diff line change @@ -10,10 +10,14 @@ sdpa_kv_cache_update:
1010 INPUT_STORAGE : texture3d
1111 OUTPUT_STORAGE : texture3d
1212 generate_variant_forall :
13+ combination :
14+ parameter_names : [OUTPUT_STORAGE, INPUT_STORAGE]
15+ combos :
16+ - parameter_values : [texture3d, texture3d]
17+ - parameter_values : [texture3d, buffer]
18+ - parameter_values : [buffer, buffer]
1319 DTYPE :
1420 - VALUE : half
1521 - VALUE : float
1622 shader_variants :
17- - NAME : sdpa_kv_cache_update_texture3d
18- - NAME : sdpa_kv_cache_update_buffer
19- INPUT_STORAGE : buffer
23+ - NAME : sdpa_kv_cache_update
Original file line number Diff line number Diff line change @@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node(
282282 const ValueRef projected,
283283 const ValueRef cache) {
284284 std::string kernel_name("sdpa_kv_cache_update");
285+ add_storage_type_suffix(kernel_name, graph.storage_type_of(cache));
285286 add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
286287 add_dtype_suffix(kernel_name, graph.dtype_of(projected));
287288
You can’t perform that action at this time.
0 commit comments