Skip to content

Commit 8010c98

Browse files
author
ssjia
committed
[ET-VK][ez] Update SDPA test to be able to test different SDPA modes
Title says it all! The purpose of this diff is twofold: 1. Test SDPA both as a fused operator (sdpa_with_kv_cache) and as the decomposed update_cache and custom_sdpa ops, in order to detect possible regressions in supporting older models. 2. Make it easier to debug issues with SDPA by exposing a mode that tests only the attention weight computation. Additionally, update the SDPA op to use buffer storage for the cache tensors when the projected tensors use buffer storage; a small change is included to ensure that cache tensors use the same storage type as the input tensors. Differential Revision: [D86226135](https://our.internmc.facebook.com/intern/diff/D86226135/) [ghstack-poisoned]
1 parent 1d91ec8 commit 8010c98

File tree

3 files changed

+210
-48
lines changed

3 files changed

+210
-48
lines changed

backends/vulkan/op_registry.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def register_dequantize_for_conv2d_op():
630630
@update_features("llama::sdpa_with_kv_cache")
631631
def register_sdpa_with_kv_cache_op():
632632
return OpFeatures(
633-
inputs_storage=utils.WIDTH_PACKED_TEXTURE,
633+
inputs_storage=utils.CONTIGUOUS_ANY,
634634
supports_resize=True,
635635
supports_prepacking=True,
636636
)

backends/vulkan/runtime/graph/ops/impl/SDPA.cpp

Lines changed: 46 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -526,10 +526,11 @@ void sdpa_with_kv_cache_impl(
526526

527527
(void)sequence_len;
528528

529-
const ValueRef k_cache = prepack_standard(
530-
graph, k_cache_data, utils::kTexture3D, utils::kWidthPacked);
531-
const ValueRef v_cache = prepack_standard(
532-
graph, v_cache_data, utils::kTexture3D, utils::kWidthPacked);
529+
utils::StorageType cache_storage = graph.storage_type_of(q_projected);
530+
const ValueRef k_cache =
531+
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
532+
const ValueRef v_cache =
533+
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
533534

534535
update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
535536
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
@@ -547,10 +548,51 @@ void sdpa_with_kv_cache_impl(
547548
out});
548549
}
549550

551+
void compute_attn_weight_with_kv_cache_impl(
552+
ComputeGraph& graph,
553+
const std::vector<ValueRef>& args) {
554+
int arg_idx = 0;
555+
const ValueRef q_projected = args[arg_idx++];
556+
const ValueRef k_projected = args[arg_idx++];
557+
const ValueRef v_projected = args[arg_idx++];
558+
const ValueRef k_cache_data = args[arg_idx++];
559+
const ValueRef v_cache_data = args[arg_idx++];
560+
const ValueRef input_pos_symint = args[arg_idx++];
561+
const ValueRef sequence_len = args[arg_idx++];
562+
const ValueRef attn_mask = args[arg_idx++];
563+
(void)attn_mask;
564+
const ValueRef dropout_p = args[arg_idx++];
565+
(void)dropout_p;
566+
const ValueRef is_causal = args[arg_idx++];
567+
(void)is_causal;
568+
const ValueRef scale = args[arg_idx++];
569+
(void)scale;
570+
571+
// Output tensors
572+
const ValueRef out = args[arg_idx++];
573+
574+
(void)sequence_len;
575+
576+
utils::StorageType cache_storage = graph.storage_type_of(q_projected);
577+
const ValueRef k_cache =
578+
prepack_standard(graph, k_cache_data, cache_storage, utils::kWidthPacked);
579+
const ValueRef v_cache =
580+
prepack_standard(graph, v_cache_data, cache_storage, utils::kWidthPacked);
581+
582+
update_cache_impl(graph, {k_projected, k_cache, input_pos_symint, -1});
583+
update_cache_impl(graph, {v_projected, v_cache, input_pos_symint, -1});
584+
585+
add_sdpa_compute_attn_weights_node(
586+
graph, q_projected, k_cache, input_pos_symint, out);
587+
}
588+
550589
REGISTER_OPERATORS {
551590
VK_REGISTER_OP(sdpa_with_kv_cache.default, sdpa_with_kv_cache_impl);
552591
VK_REGISTER_OP(update_cache.default, update_cache_impl);
553592
VK_REGISTER_OP(llama.custom_sdpa.default, sdpa_impl);
593+
VK_REGISTER_OP(
594+
testing.compute_attn_weight_with_kv_cache.default,
595+
compute_attn_weight_with_kv_cache_impl);
554596
}
555597

556598
} // namespace vkcompute

backends/vulkan/test/op_tests/sdpa_test.cpp

Lines changed: 163 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,24 @@
2323
#include <cassert>
2424
#include <iostream>
2525

26+
//
27+
// SDPA Mode Enum
28+
//
29+
30+
enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY };
31+
32+
std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) {
33+
switch (mode) {
34+
case SDPAMode::DECOMPOSED:
35+
return os << "DECOMPOSED";
36+
case SDPAMode::FUSED:
37+
return os << "FUSED";
38+
case SDPAMode::ATTN_WEIGHT_ONLY:
39+
return os << "ATTN_WEIGHT_ONLY";
40+
}
41+
return os;
42+
}
43+
2644
namespace torch {
2745
namespace executor {
2846
namespace native {
@@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten(
7492
const int64_t seq_len,
7593
// @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
7694
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
77-
const std::optional<at::Tensor> attn_mask,
95+
const std::optional<at::Tensor>& attn_mask,
7896
const double dropout_p,
7997
const bool is_causal,
8098
// @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
@@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl(
161179
at::Tensor& value_cache,
162180
const int64_t start_pos,
163181
const int64_t seq_len,
164-
const std::optional<at::Tensor> __attn_mask_ignored,
182+
const std::optional<at::Tensor>& __attn_mask_ignored,
165183
const double dropout_p,
166184
const bool is_causal,
167-
const std::optional<double> scale) {
185+
const std::optional<double> scale,
186+
SDPAMode mode = SDPAMode::DECOMPOSED) {
168187
at::Tensor attn_mask =
169188
construct_attention_mask(q_projected, key_cache, start_pos);
170189

@@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl(
202221
float scale_factor = 1.0 / sqrt(q_transposed.size(-1));
203222
at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;
204223

224+
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
225+
return attn_weight;
226+
}
227+
205228
at::Tensor attn_weight_softmax = at::softmax(attn_weight, -1);
206229
at::Tensor out = at::matmul(attn_weight_softmax, v_transposed);
207230

@@ -268,7 +291,8 @@ void test_vulkan_sdpa(
268291
const int num_kv_heads,
269292
const int batch_size,
270293
vkcompute::utils::StorageType storage_type,
271-
at::ScalarType dtype = at::kFloat) {
294+
at::ScalarType dtype = at::kFloat,
295+
SDPAMode mode = SDPAMode::DECOMPOSED) {
272296
// compute the max sequence length
273297
int max_seq_len = start_input_pos;
274298
for (int i = 0; i < sequence_lens.size(); ++i) {
@@ -296,6 +320,9 @@ void test_vulkan_sdpa(
296320

297321
// Get reference output
298322
at::Tensor out = at::empty_like(q);
323+
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
324+
out = at::empty({batch_size, num_heads, init_seq_len, init_seq_len});
325+
}
299326

300327
// Build Vulkan SDPA graph
301328
using namespace vkcompute;
@@ -330,22 +357,87 @@ void test_vulkan_sdpa(
330357
const ValueRef r_out = graph.add_tensor(
331358
out.sizes().vec(), from_at_scalartype(out.scalar_type()), storage_type);
332359

333-
VK_GET_OP_FN("sdpa_with_kv_cache.default")
334-
(graph,
335-
{
336-
r_q.value,
337-
r_k.value,
338-
r_v.value,
339-
r_k_cache_data,
340-
r_v_cache_data,
341-
r_input_pos_symint,
342-
kDummyValueRef, // sequence_len
343-
kDummyValueRef, // attn_mask
344-
kDummyValueRef, // dropout_p
345-
kDummyValueRef, // is_causal
346-
kDummyValueRef, // scale
347-
r_out,
348-
});
360+
switch (mode) {
361+
case SDPAMode::DECOMPOSED: {
362+
const ValueRef r_k_cache = graph.add_tensor(
363+
k_cache_data.sizes().vec(),
364+
from_at_scalartype(k_cache_data.scalar_type()),
365+
storage_type);
366+
const ValueRef r_v_cache = graph.add_tensor(
367+
v_cache_data.sizes().vec(),
368+
from_at_scalartype(v_cache_data.scalar_type()),
369+
storage_type);
370+
const ValueRef r_dummy_out = graph.add_tensor(
371+
{1}, from_at_scalartype(out.scalar_type()), utils::kBuffer);
372+
VK_GET_OP_FN("update_cache.default")
373+
(graph,
374+
{
375+
r_k.value,
376+
r_k_cache,
377+
r_input_pos_symint,
378+
r_dummy_out,
379+
});
380+
VK_GET_OP_FN("update_cache.default")
381+
(graph,
382+
{
383+
r_v.value,
384+
r_v_cache,
385+
r_input_pos_symint,
386+
r_dummy_out,
387+
});
388+
VK_GET_OP_FN("llama.custom_sdpa.default")
389+
(graph,
390+
{
391+
r_q.value,
392+
r_k_cache,
393+
r_v_cache,
394+
r_input_pos_symint,
395+
kDummyValueRef, // attn_mask
396+
kDummyValueRef, // dropout_p
397+
kDummyValueRef, // is_causal
398+
kDummyValueRef, // scale
399+
r_out,
400+
});
401+
} break;
402+
case SDPAMode::FUSED:
403+
VK_GET_OP_FN("sdpa_with_kv_cache.default")
404+
(graph,
405+
{
406+
r_q.value,
407+
r_k.value,
408+
r_v.value,
409+
r_k_cache_data,
410+
r_v_cache_data,
411+
r_input_pos_symint,
412+
kDummyValueRef, // sequence_len
413+
kDummyValueRef, // attn_mask
414+
kDummyValueRef, // dropout_p
415+
kDummyValueRef, // is_causal
416+
kDummyValueRef, // scale
417+
r_out,
418+
});
419+
break;
420+
case SDPAMode::ATTN_WEIGHT_ONLY:
421+
VK_GET_OP_FN("testing.compute_attn_weight_with_kv_cache.default")
422+
(graph,
423+
{
424+
r_q.value,
425+
r_k.value,
426+
r_v.value,
427+
r_k_cache_data,
428+
r_v_cache_data,
429+
r_input_pos_symint,
430+
kDummyValueRef, // sequence_len
431+
kDummyValueRef, // attn_mask
432+
kDummyValueRef, // dropout_p
433+
kDummyValueRef, // is_causal
434+
kDummyValueRef, // scale
435+
r_out,
436+
});
437+
break;
438+
default:
439+
VK_THROW("Unsupported SDPA mode");
440+
}
349441

350442
ValueRef staging_out = graph.set_output_tensor(r_out);
351443

@@ -378,7 +470,7 @@ void test_vulkan_sdpa(
378470
v = at::rand_like(k);
379471

380472
at::Tensor reference_out = sdpa_reference_impl(
381-
q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {});
473+
q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0, true, {}, mode);
382474

383475
graph.set_symint(r_input_pos_symint, input_pos);
384476
graph.resize_input(0, q.sizes().vec());
@@ -393,15 +485,38 @@ void test_vulkan_sdpa(
393485

394486
graph.execute();
395487

396-
out = at::empty_like(q);
488+
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
489+
const int context_len = input_pos + seq_len;
490+
const int context_len_align_up4 = (context_len + 3) & ~3;
491+
const int seq_len_align_up4 = (seq_len + 3) & ~3;
492+
493+
out = at::empty(
494+
{batch_size, num_heads, seq_len_align_up4, context_len_align_up4},
495+
q.options());
496+
} else {
497+
out = at::empty_like(q);
498+
}
397499
EXTRACT_TENSOR(out);
398500

501+
if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
502+
// Index vk_out to only include the relevant seq_len and context_len
503+
// dimensions
504+
int context_len = input_pos + seq_len;
505+
vk_out = vk_out.index(
506+
{at::indexing::Slice(),
507+
at::indexing::Slice(),
508+
at::indexing::Slice(0, seq_len),
509+
at::indexing::Slice(0, context_len)});
510+
}
511+
399512
const bool output_correct = at::allclose(reference_out, vk_out);
400513
if (!output_correct) {
401514
// Print only differing tensor elements side by side for easier comparison
402515
auto ref_flat = reference_out.flatten();
403516
auto vk_flat = vk_out.flatten();
404517
auto numel = ref_flat.numel();
518+
std::cout << "While testing " << mode << " mode with " << storage_type
519+
<< " storage" << std::endl;
405520
std::cout << "reference_out\tvk_out\tindex" << std::endl;
406521
int first_diff_idx = -1;
407522
auto sizes = reference_out.sizes();
@@ -466,27 +581,32 @@ void test_vulkan_sdpa(
466581
const int num_kv_heads,
467582
const int batch_size,
468583
at::ScalarType dtype = at::kFloat) {
469-
// Test texture
470-
test_vulkan_sdpa(
471-
start_input_pos,
472-
sequence_lens,
473-
head_dim,
474-
num_heads,
475-
num_kv_heads,
476-
batch_size,
477-
vkcompute::utils::kTexture3D,
478-
dtype);
479-
480-
// Test buffer
481-
test_vulkan_sdpa(
482-
start_input_pos,
483-
sequence_lens,
484-
head_dim,
485-
num_heads,
486-
num_kv_heads,
487-
batch_size,
488-
vkcompute::utils::kBuffer,
489-
dtype);
584+
for (SDPAMode mode :
585+
{SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) {
586+
// Test texture
587+
test_vulkan_sdpa(
588+
start_input_pos,
589+
sequence_lens,
590+
head_dim,
591+
num_heads,
592+
num_kv_heads,
593+
batch_size,
594+
vkcompute::utils::kTexture3D,
595+
dtype,
596+
mode);
597+
598+
// Test buffer
599+
test_vulkan_sdpa(
600+
start_input_pos,
601+
sequence_lens,
602+
head_dim,
603+
num_heads,
604+
num_kv_heads,
605+
batch_size,
606+
vkcompute::utils::kBuffer,
607+
dtype,
608+
mode);
609+
}
490610
}
491611

492612
TEST(VulkanSDPATest, test_sdpa_op_small_params) {

0 commit comments

Comments
 (0)