From fa9cef3513275317d5f6bf46cf74f0cc1c2cc41b Mon Sep 17 00:00:00 2001
From: Eric Buehler
Date: Thu, 31 Oct 2024 09:13:09 -0400
Subject: [PATCH] Use set_params!

---
 candle-metal-kernels/src/lib.rs                  | 82 ++++++++-----
 .../src/scaled_dot_product_attention.metal       |  4 +-
 2 files changed, 36 insertions(+), 50 deletions(-)

diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs
index 28cbc5499d..9ff86e6475 100644
--- a/candle-metal-kernels/src/lib.rs
+++ b/candle-metal-kernels/src/lib.rs
@@ -8,7 +8,7 @@ use std::sync::RwLock;
 
 pub mod utils;
 pub use utils::BufferOffset;
-use utils::{get_block_dims, linear_split, EncoderProvider};
+use utils::{get_block_dims, linear_split, EncoderParam, EncoderProvider};
 
 const AFFINE: &str = include_str!("affine.metal");
 const BINARY: &str = include_str!("binary.metal");
@@ -1809,25 +1809,27 @@ pub fn call_sdpa_full(
     };
     let batch_strides = [b_stride_q, b_stride_k, b_stride_v, b_stride_o];
 
-    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
-    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
-    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
-    encoder.set_buffer(3, Some(&output), 0);
+    impl EncoderParam for MLXFastAttentionParams {
+        fn set_param(encoder: &ComputeCommandEncoderRef, position: u64, data: Self) {
+            encoder.set_bytes(
+                position,
+                core::mem::size_of::<MLXFastAttentionParams>() as u64,
+                &data as *const MLXFastAttentionParams as *const c_void,
+            );
+        }
+    }
 
-    encoder.set_bytes(
-        4,
-        std::mem::size_of::<MLXFastAttentionParams>() as u64,
-        &params as *const MLXFastAttentionParams as *const c_void,
-    );
-    encoder.set_bytes(
-        6,
-        (std::mem::size_of::<i32>() * batch_shape.len()) as u64,
-        batch_shape.as_ptr() as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        7,
-        (std::mem::size_of::<usize>() * batch_strides.len()) as u64,
-        batch_strides.as_ptr() as *const c_void,
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            params,
+            &batch_shape[..],
+            &batch_strides[..]
+        )
     );
 
     let grid_dims = MTLSize {
@@ -1917,35 +1919,19 @@ pub fn call_sdpa_vector(
     // q = (bs, qhead, seq, hidden)
     // k/v = (bs, kv_head, kv_seq, hidden)
 
-    encoder.set_buffer(0, Some(&q_buffer), q_offset as NSUInteger);
-    encoder.set_buffer(1, Some(&k_buffer), k_offset as NSUInteger);
-    encoder.set_buffer(2, Some(&v_buffer), v_offset as NSUInteger);
-    encoder.set_buffer(3, Some(&output), 0);
-
-    encoder.set_bytes(
-        4,
-        std::mem::size_of::<i32>() as u64,
-        &gqa_factor as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        5,
-        std::mem::size_of::<i32>() as u64,
-        &n as *const i32 as *const c_void,
-    );
-    encoder.set_bytes(
-        6,
-        std::mem::size_of::<usize>() as u64,
-        &stride as *const usize as *const c_void,
-    );
-    encoder.set_bytes(
-        7,
-        std::mem::size_of::<f32>() as u64,
-        &alpha as *const f32 as *const c_void,
-    );
-    encoder.set_bytes(
-        8,
-        std::mem::size_of::<f32>() as u64,
-        &softcapping as *const f32 as *const c_void,
+    set_params!(
+        encoder,
+        (
+            (q_buffer, q_offset),
+            (k_buffer, k_offset),
+            (v_buffer, v_offset),
+            output,
+            gqa_factor,
+            n,
+            stride,
+            alpha,
+            softcapping
+        )
     );
 
     let grid_dims = MTLSize {
diff --git a/candle-metal-kernels/src/scaled_dot_product_attention.metal b/candle-metal-kernels/src/scaled_dot_product_attention.metal
index d65cc621ac..69815abd0d 100644
--- a/candle-metal-kernels/src/scaled_dot_product_attention.metal
+++ b/candle-metal-kernels/src/scaled_dot_product_attention.metal
@@ -1170,8 +1170,8 @@ template <
       const device itype* V [[buffer(2)]], \
       device otype* O [[buffer(3)]], \
       const constant MLXFastAttentionParams* params [[buffer(4)]], \
-      const constant int* batch_shape [[buffer(6)]], \
-      const constant size_t* batch_strides [[buffer(7)]], \
+      const constant int* batch_shape [[buffer(5)]], \
+      const constant size_t* batch_strides [[buffer(6)]], \
       uint simd_lane_id [[thread_index_in_simdgroup]], \
       uint simd_group_id [[simdgroup_index_in_threadgroup]], \
       uint3 tid [[threadgroup_position_in_grid]], \
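
Note on the new call shape: `set_params!` leans on the `EncoderParam` trait from candle-metal-kernels/src/utils.rs, with one `set_param` impl per argument type and consecutive buffer indices assigned in argument order starting at 0. With `params` landing at index 4, `batch_shape` and `batch_strides` now land at 5 and 6, which is why the kernel's `[[buffer(...)]]` bindings shift from 6/7 to 5/6 in the Metal file. The sketch below illustrates the pattern with a mock encoder in place of `metal::ComputeCommandEncoderRef`; the mock type, the impls shown, and the macro expansion are assumptions modeled on the call sites in this patch, not the exact candle implementation.

// Self-contained sketch (assumed names): MockEncoder stands in for
// metal::ComputeCommandEncoderRef so the example runs without the `metal` crate.
#[derive(Default)]
struct MockEncoder {
    calls: Vec<String>,
}

impl MockEncoder {
    fn set_buffer(&mut self, index: u64, name: &str, offset: usize) {
        self.calls
            .push(format!("set_buffer(index={index}, buf={name}, offset={offset})"));
    }
    fn set_bytes(&mut self, index: u64, len: u64) {
        self.calls.push(format!("set_bytes(index={index}, len={len})"));
    }
}

// One impl per argument type; set_params! picks the impl from the argument's type.
trait EncoderParam {
    fn set_param(encoder: &mut MockEncoder, index: u64, data: Self);
}

impl EncoderParam for i32 {
    fn set_param(encoder: &mut MockEncoder, index: u64, _data: Self) {
        encoder.set_bytes(index, core::mem::size_of::<i32>() as u64);
    }
}

impl EncoderParam for f32 {
    fn set_param(encoder: &mut MockEncoder, index: u64, _data: Self) {
        encoder.set_bytes(index, core::mem::size_of::<f32>() as u64);
    }
}

// A (name, offset) pair stands in for the (&Buffer, offset) tuples in the patch.
impl<'a> EncoderParam for (&'a str, usize) {
    fn set_param(encoder: &mut MockEncoder, index: u64, data: Self) {
        encoder.set_buffer(index, data.0, data.1);
    }
}

// Expands each argument into a set_param call with a consecutive buffer index,
// replacing the hand-numbered set_buffer/set_bytes sequence.
macro_rules! set_params {
    ($encoder:expr, ($($param:expr),+ $(,)?)) => {{
        let mut _index: u64 = 0;
        $(
            EncoderParam::set_param($encoder, _index, $param);
            _index += 1;
        )+
    }};
}

fn main() {
    let mut encoder = MockEncoder::default();
    let (gqa_factor, alpha): (i32, f32) = (4, 0.125);

    // Indices 0.. line up with the kernel's [[buffer(N)]] bindings.
    set_params!(
        &mut encoder,
        (("q", 0usize), ("k", 0usize), ("v", 0usize), gqa_factor, alpha)
    );

    for call in &encoder.calls {
        println!("{call}");
    }
}

One design consequence worth noting for review: because the macro numbers arguments implicitly, argument order in the Rust call site and binding order in the Metal signature must stay in lockstep, which is exactly what the buffer(5)/buffer(6) change above restores.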