Commit 1d91ec8

ssjia committed
[ET-VK][ez] Refactor yaml configs for SDPA shaders
Title says it all! Use the new combos codegen API which makes it easier to express generating storage type combinations.

Differential Revision: [D86226138](https://our.internmc.facebook.com/intern/diff/D86226138/)

[ghstack-poisoned]
1 parent bde6b11 commit 1d91ec8

7 files changed: 38 additions and 15 deletions

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_coop.yaml

Lines changed: 7 additions & 3 deletions
@@ -12,10 +12,14 @@ sdpa_compute_attn_weights_coop:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_attn_weights_coop_texture3d_texture3d
-    - NAME: sdpa_compute_attn_weights_coop_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_attn_weights_coop
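
To make the effect of the `combination` block concrete, here is a rough sketch of how one base `NAME` could expand into a shader name per listed storage combo and `DTYPE`. The helper `expand_variant_names` is hypothetical and is not the codegen's implementation. It assumes suffixes are appended in the order the parameters appear in the YAML (combo values first, then dtype), which matches the suffix order used in SDPA.cpp for the KV cache update shader but is not confirmed by this diff for every shader.

```cpp
#include <string>
#include <vector>

// Hypothetical helper, not part of the real codegen: expands a base shader
// name into one name per (storage combo, dtype) pair.
// Assumption: suffixes follow the YAML order (combo values, then DTYPE)
// and are joined with underscores.
std::vector<std::string> expand_variant_names(
    const std::string& base_name,
    const std::vector<std::vector<std::string>>& combos,
    const std::vector<std::string>& dtypes) {
  std::vector<std::string> names;
  for (const auto& combo : combos) {
    // Append each storage value of this combo to the base name.
    std::string with_storage = base_name;
    for (const auto& value : combo) {
      with_storage += "_" + value;
    }
    // Emit one variant per dtype for this combo.
    for (const auto& dtype : dtypes) {
      names.push_back(with_storage + "_" + dtype);
    }
  }
  return names;
}

// The combos and DTYPE values listed for sdpa_compute_attn_weights_coop.
std::vector<std::string> attn_weights_coop_variants() {
  return expand_variant_names(
      "sdpa_compute_attn_weights_coop",
      {{"texture3d", "texture3d"},
       {"buffer", "texture3d"},
       {"buffer", "buffer"}},
      {"float", "half"});
}
```

Under these assumptions the three combos and two dtypes give six names. Only the listed pairs are expanded, so no `[texture3d, buffer]` variant is produced, as a full cross product of the two storage parameters would.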

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_attn_weights_tiled.yaml

Lines changed: 7 additions & 3 deletions
@@ -13,10 +13,14 @@ sdpa_compute_attn_weights_tiled:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, K_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_attn_weights_tiled_texture3d_texture3d
-    - NAME: sdpa_compute_attn_weights_tiled_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_attn_weights_tiled

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_coop.yaml

Lines changed: 7 additions & 3 deletions
@@ -12,10 +12,14 @@ sdpa_compute_out_coop:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_out_coop_texture3d_texture3d
-    - NAME: sdpa_compute_out_coop_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_out_coop

backends/vulkan/runtime/graph/ops/glsl/sdpa_compute_out_tiled.yaml

Lines changed: 7 additions & 3 deletions
@@ -13,10 +13,14 @@ sdpa_compute_out_tiled:
     TILE_K4: 1
     TILE_N4: 1
   generate_variant_forall:
+    combination:
+      parameter_names: [IO_STORAGE, V_CACHE_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [buffer, texture3d]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: float
       - VALUE: half
   shader_variants:
-    - NAME: sdpa_compute_out_tiled_texture3d_texture3d
-    - NAME: sdpa_compute_out_tiled_buffer_texture3d
-      IO_STORAGE: buffer
+    - NAME: sdpa_compute_out_tiled

backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.glsl

Lines changed: 2 additions & 0 deletions
@@ -5,6 +5,8 @@
 #define IN_VEC4_T ${texel_load_type(DTYPE, INPUT_STORAGE)}
 #define T ${buffer_scalar_type(DTYPE)}
 
+$if OUTPUT_STORAGE == "buffer":
+  #define OUTPUT_BUFFER
 $if INPUT_STORAGE == "buffer":
   #define INPUT_BUFFER

backends/vulkan/runtime/graph/ops/glsl/sdpa_kv_cache_update.yaml

Lines changed: 7 additions & 3 deletions
@@ -10,10 +10,14 @@ sdpa_kv_cache_update:
     INPUT_STORAGE: texture3d
     OUTPUT_STORAGE: texture3d
   generate_variant_forall:
+    combination:
+      parameter_names: [OUTPUT_STORAGE, INPUT_STORAGE]
+      combos:
+        - parameter_values: [texture3d, texture3d]
+        - parameter_values: [texture3d, buffer]
+        - parameter_values: [buffer, buffer]
     DTYPE:
       - VALUE: half
       - VALUE: float
   shader_variants:
-    - NAME: sdpa_kv_cache_update_texture3d
-    - NAME: sdpa_kv_cache_update_buffer
-      INPUT_STORAGE: buffer
+    - NAME: sdpa_kv_cache_update

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 1 addition & 0 deletions
@@ -282,6 +282,7 @@ void add_sdpa_kv_cache_update_node(
     const ValueRef projected,
     const ValueRef cache) {
   std::string kernel_name("sdpa_kv_cache_update");
+  add_storage_type_suffix(kernel_name, graph.storage_type_of(cache));
   add_storage_type_suffix(kernel_name, graph.storage_type_of(projected));
   add_dtype_suffix(kernel_name, graph.dtype_of(projected));
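
The added line matters because the shader name now carries two storage suffixes: the cache (`OUTPUT_STORAGE`) first, then the projected input (`INPUT_STORAGE`), followed by the dtype. The following self-contained sketch mimics that ordering. `StorageType`, `DType`, `append_storage_suffix`, `append_dtype_suffix` and `is_generated_combo` are hypothetical stand-ins written for this illustration; the real enum and suffix helpers live in the Vulkan backend and may differ in detail, and the exact suffix strings are assumed from the variant names visible in the YAML.

```cpp
#include <string>

// Hypothetical stand-ins for the runtime's storage and dtype enums.
enum class StorageType { Buffer, Texture3D };
enum class DType { Float, Half };

// Hypothetical equivalent of add_storage_type_suffix (suffix strings assumed).
void append_storage_suffix(std::string& kernel_name, StorageType storage) {
  kernel_name += (storage == StorageType::Buffer) ? "_buffer" : "_texture3d";
}

// Hypothetical equivalent of add_dtype_suffix (suffix strings assumed).
void append_dtype_suffix(std::string& kernel_name, DType dtype) {
  kernel_name += (dtype == DType::Float) ? "_float" : "_half";
}

// Mirrors the suffix order in add_sdpa_kv_cache_update_node after this
// commit: cache storage, then projected storage, then dtype.
std::string kv_cache_update_kernel_name(
    StorageType cache, StorageType projected, DType dtype) {
  std::string kernel_name("sdpa_kv_cache_update");
  append_storage_suffix(kernel_name, cache);
  append_storage_suffix(kernel_name, projected);
  append_dtype_suffix(kernel_name, dtype);
  return kernel_name;
}

// True if (cache, projected) is one of the three combos listed in
// sdpa_kv_cache_update.yaml. A buffer cache fed by a texture3d input is
// the only pair that is not listed there.
bool is_generated_combo(StorageType cache, StorageType projected) {
  return !(cache == StorageType::Buffer &&
           projected == StorageType::Texture3D);
}
```

For example, a texture3d cache with a buffer input and `half` dtype would yield `sdpa_kv_cache_update_texture3d_buffer_half` under these assumptions. A check like `is_generated_combo` is one way a caller could guard against requesting a name for which no variant is listed in the YAML.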
