2323#include < cassert>
2424#include < iostream>
2525
26+ //
27+ // SDPA Mode Enum
28+ //
29+
/// Mode tag used by the SDPA tests in this file to select which code path a
/// given test run exercises.
enum class SDPAMode { DECOMPOSED, FUSED, ATTN_WEIGHT_ONLY };

/// Writes the enumerator name of `mode` to `os` so that test diagnostics can
/// report which mode was being exercised.
///
/// @param os   Destination stream.
/// @param mode Mode to print.
/// @return `os`, to allow chaining.
std::ostream& operator<<(std::ostream& os, const SDPAMode& mode) {
  // No `default:` label on purpose: the compiler can then warn when a new
  // enumerator is added without a matching case here.
  switch (mode) {
    case SDPAMode::DECOMPOSED:
      return os << " DECOMPOSED";
    case SDPAMode::FUSED:
      return os << " FUSED";
    case SDPAMode::ATTN_WEIGHT_ONLY:
      return os << " ATTN_WEIGHT_ONLY";
  }
  // Reached only for a value that is not a named enumerator (e.g. obtained
  // through a cast). Emit an explicit marker instead of silently printing
  // nothing, which would leave a blank mode name in failure messages.
  return os << " UNKNOWN(" << static_cast<int>(mode) << ")";
}
43+
2644namespace torch {
2745namespace executor {
2846namespace native {
@@ -74,7 +92,7 @@ at::Tensor sdpa_with_kv_cache_aten(
7492 const int64_t seq_len,
7593 // @lint-ignore CLANGTIDY facebook-hte-ConstantArgumentPassByValue
7694 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
77- const std::optional<at::Tensor> attn_mask,
95+ const std::optional<at::Tensor>& attn_mask,
7896 const double dropout_p,
7997 const bool is_causal,
8098 // @lint-ignore CLANGTIDY facebook-hte-ParameterMightThrowOnCopy
@@ -161,10 +179,11 @@ at::Tensor sdpa_reference_impl(
161179 at::Tensor& value_cache,
162180 const int64_t start_pos,
163181 const int64_t seq_len,
164- const std::optional<at::Tensor> __attn_mask_ignored,
182+ const std::optional<at::Tensor>& __attn_mask_ignored,
165183 const double dropout_p,
166184 const bool is_causal,
167- const std::optional<double > scale) {
185+ const std::optional<double > scale,
186+ SDPAMode mode = SDPAMode::DECOMPOSED) {
168187 at::Tensor attn_mask =
169188 construct_attention_mask (q_projected, key_cache, start_pos);
170189
@@ -202,6 +221,10 @@ at::Tensor sdpa_reference_impl(
202221 float scale_factor = 1.0 / sqrt (q_transposed.size (-1 ));
203222 at::Tensor attn_weight = attn_weight_prescale * scale_factor + attn_mask;
204223
224+ if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
225+ return attn_weight;
226+ }
227+
205228 at::Tensor attn_weight_softmax = at::softmax (attn_weight, -1 );
206229 at::Tensor out = at::matmul (attn_weight_softmax, v_transposed);
207230
@@ -268,7 +291,8 @@ void test_vulkan_sdpa(
268291 const int num_kv_heads,
269292 const int batch_size,
270293 vkcompute::utils::StorageType storage_type,
271- at::ScalarType dtype = at::kFloat ) {
294+ at::ScalarType dtype = at::kFloat ,
295+ SDPAMode mode = SDPAMode::DECOMPOSED) {
272296 // compute the max sequence length
273297 int max_seq_len = start_input_pos;
274298 for (int i = 0 ; i < sequence_lens.size (); ++i) {
@@ -296,6 +320,9 @@ void test_vulkan_sdpa(
296320
297321 // Get reference output
298322 at::Tensor out = at::empty_like (q);
323+ if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
324+ out = at::empty ({batch_size, num_heads, init_seq_len, init_seq_len});
325+ }
299326
300327 // Build Vulkan SDPA graph
301328 using namespace vkcompute ;
@@ -330,22 +357,87 @@ void test_vulkan_sdpa(
330357 const ValueRef r_out = graph.add_tensor (
331358 out.sizes ().vec (), from_at_scalartype (out.scalar_type ()), storage_type);
332359
333- VK_GET_OP_FN (" sdpa_with_kv_cache.default" )
334- (graph,
335- {
336- r_q.value ,
337- r_k.value ,
338- r_v.value ,
339- r_k_cache_data,
340- r_v_cache_data,
341- r_input_pos_symint,
342- kDummyValueRef , // sequence_len
343- kDummyValueRef , // attn_mask
344- kDummyValueRef , // dropout_p
345- kDummyValueRef , // is_causal
346- kDummyValueRef , // scale
347- r_out,
348- });
360+ switch (mode) {
361+ case SDPAMode::DECOMPOSED: {
362+ const ValueRef r_k_cache = graph.add_tensor (
363+ k_cache_data.sizes ().vec (),
364+ from_at_scalartype (k_cache_data.scalar_type ()),
365+ storage_type);
366+ const ValueRef r_v_cache = graph.add_tensor (
367+ v_cache_data.sizes ().vec (),
368+ from_at_scalartype (v_cache_data.scalar_type ()),
369+ storage_type);
370+ const ValueRef r_dummy_out = graph.add_tensor (
371+ {1 }, from_at_scalartype (out.scalar_type ()), utils::kBuffer );
372+ VK_GET_OP_FN (" update_cache.default" )
373+ (graph,
374+ {
375+ r_k.value ,
376+ r_k_cache,
377+ r_input_pos_symint,
378+ r_dummy_out,
379+ });
380+ VK_GET_OP_FN (" update_cache.default" )
381+ (graph,
382+ {
383+ r_v.value ,
384+ r_v_cache,
385+ r_input_pos_symint,
386+ r_dummy_out,
387+ });
388+ VK_GET_OP_FN (" llama.custom_sdpa.default" )
389+ (graph,
390+ {
391+ r_q.value ,
392+ r_k_cache,
393+ r_v_cache,
394+ r_input_pos_symint,
395+ kDummyValueRef , // attn_mask
396+ kDummyValueRef , // dropout_p
397+ kDummyValueRef , // is_causal
398+ kDummyValueRef , // scale
399+ r_out,
400+ });
401+ } break ;
402+ case SDPAMode::FUSED:
403+ VK_GET_OP_FN (" sdpa_with_kv_cache.default" )
404+ (graph,
405+ {
406+ r_q.value ,
407+ r_k.value ,
408+ r_v.value ,
409+ r_k_cache_data,
410+ r_v_cache_data,
411+ r_input_pos_symint,
412+ kDummyValueRef , // sequence_len
413+ kDummyValueRef , // attn_mask
414+ kDummyValueRef , // dropout_p
415+ kDummyValueRef , // is_causal
416+ kDummyValueRef , // scale
417+ r_out,
418+ });
419+ break ;
420+ case SDPAMode::ATTN_WEIGHT_ONLY:
421+ VK_GET_OP_FN (" testing.compute_attn_weight_with_kv_cache.default" )
422+ (graph,
423+ {
424+ r_q.value ,
425+ r_k.value ,
426+ r_v.value ,
427+ r_k_cache_data,
428+ r_v_cache_data,
429+ r_input_pos_symint,
430+ kDummyValueRef , // sequence_len
431+ kDummyValueRef , // attn_mask
432+ kDummyValueRef , // dropout_p
433+ kDummyValueRef , // is_causal
434+ kDummyValueRef , // scale
435+ r_out,
436+ });
437+ break ;
438+ default :
439+ VK_THROW (" Unsupported SDPA mode" );
440+ }
349441
350442 ValueRef staging_out = graph.set_output_tensor (r_out);
351443
@@ -378,7 +470,7 @@ void test_vulkan_sdpa(
378470 v = at::rand_like (k);
379471
380472 at::Tensor reference_out = sdpa_reference_impl (
381- q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0 , true , {});
473+ q, k, v, k_cache, v_cache, input_pos, seq_len, {}, 0.0 , true , {}, mode );
382474
383475 graph.set_symint (r_input_pos_symint, input_pos);
384476 graph.resize_input (0 , q.sizes ().vec ());
@@ -393,15 +485,38 @@ void test_vulkan_sdpa(
393485
394486 graph.execute ();
395487
396- out = at::empty_like (q);
488+ if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
489+ const int context_len = input_pos + seq_len;
490+ const int context_len_align_up4 = (context_len + 3 ) & ~3 ;
491+ const int seq_len_align_up4 = (seq_len + 3 ) & ~3 ;
492+
493+ out = at::empty (
494+ {batch_size, num_heads, seq_len_align_up4, context_len_align_up4},
495+ q.options ());
496+ } else {
497+ out = at::empty_like (q);
498+ }
397499 EXTRACT_TENSOR (out);
398500
501+ if (mode == SDPAMode::ATTN_WEIGHT_ONLY) {
502+ // Index vk_out to only include the relevant seq_len and context_len
503+ // dimensions
504+ int context_len = input_pos + seq_len;
505+ vk_out = vk_out.index (
506+ {at::indexing::Slice (),
507+ at::indexing::Slice (),
508+ at::indexing::Slice (0 , seq_len),
509+ at::indexing::Slice (0 , context_len)});
510+ }
511+
399512 const bool output_correct = at::allclose (reference_out, vk_out);
400513 if (!output_correct) {
401514 // Print only differing tensor elements side by side for easier comparison
402515 auto ref_flat = reference_out.flatten ();
403516 auto vk_flat = vk_out.flatten ();
404517 auto numel = ref_flat.numel ();
518+ std::cout << " While testing " << mode << " mode with " << storage_type
519+ << " storage" << std::endl;
405520 std::cout << " reference_out\t vk_out\t index" << std::endl;
406521 int first_diff_idx = -1 ;
407522 auto sizes = reference_out.sizes ();
@@ -466,27 +581,32 @@ void test_vulkan_sdpa(
466581 const int num_kv_heads,
467582 const int batch_size,
468583 at::ScalarType dtype = at::kFloat ) {
469- // Test texture
470- test_vulkan_sdpa (
471- start_input_pos,
472- sequence_lens,
473- head_dim,
474- num_heads,
475- num_kv_heads,
476- batch_size,
477- vkcompute::utils::kTexture3D ,
478- dtype);
479-
480- // Test buffer
481- test_vulkan_sdpa (
482- start_input_pos,
483- sequence_lens,
484- head_dim,
485- num_heads,
486- num_kv_heads,
487- batch_size,
488- vkcompute::utils::kBuffer ,
489- dtype);
584+ for (SDPAMode mode :
585+ {SDPAMode::ATTN_WEIGHT_ONLY, SDPAMode::DECOMPOSED, SDPAMode::FUSED}) {
586+ // Test texture
587+ test_vulkan_sdpa (
588+ start_input_pos,
589+ sequence_lens,
590+ head_dim,
591+ num_heads,
592+ num_kv_heads,
593+ batch_size,
594+ vkcompute::utils::kTexture3D ,
595+ dtype,
596+ mode);
597+
598+ // Test buffer
599+ test_vulkan_sdpa (
600+ start_input_pos,
601+ sequence_lens,
602+ head_dim,
603+ num_heads,
604+ num_kv_heads,
605+ batch_size,
606+ vkcompute::utils::kBuffer ,
607+ dtype,
608+ mode);
609+ }
490610}
491611
492612TEST (VulkanSDPATest, test_sdpa_op_small_params) {
0 commit comments