diff --git a/candle-metal-kernels/src/lib.rs b/candle-metal-kernels/src/lib.rs index e0c97962cb..743b9fe2b3 100644 --- a/candle-metal-kernels/src/lib.rs +++ b/candle-metal-kernels/src/lib.rs @@ -19,6 +19,7 @@ const CAST: &str = include_str!("cast.metal"); const CONV: &str = include_str!("conv.metal"); const REDUCE: &str = include_str!("reduce.metal"); const RANDOM: &str = include_str!("random.metal"); +// Current source: https://github.com/ivarflakstad/metal-flash-attention/tree/candle const MFA: &[u8] = include_bytes!("libMetalFlashAttention.metallib"); const QUANTIZED: &str = include_str!("quantized.metal"); const SORT: &str = include_str!("sort.metal"); @@ -1564,6 +1565,7 @@ pub fn call_gemm( let bytes = match name { "sgemm" => 4, "hgemm" => 2, + "bgemm" => 2, other => { return Err(MetalKernelError::LoadLibraryError(format!( "{other} is not a valid kernel for gemm" diff --git a/candle-metal-kernels/src/tests.rs b/candle-metal-kernels/src/tests.rs index 8c38e74ac6..f70f773a97 100644 --- a/candle-metal-kernels/src/tests.rs +++ b/candle-metal-kernels/src/tests.rs @@ -1046,6 +1046,7 @@ fn where_cond_u32_f32() { } fn run_gemm( + name: &'static str, (b, m, n, k): (usize, usize, usize, usize), lhs: &[T], lhs_stride: Vec, @@ -1076,7 +1077,7 @@ fn run_gemm( &device, command_buffer, &kernels, - "sgemm", + name, (b, m, n, k), &lhs_stride, lhs_offset, @@ -1100,7 +1101,16 @@ fn gemm() { let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); let rhs_stride = vec![n * k, n, 1]; let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + let results = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); assert_eq!( approx(results, 4), vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] @@ -1111,7 +1121,16 @@ fn gemm() { let lhs: Vec = (0..b * m * k).map(|f| f as f32).collect(); let rhs_stride = vec![n * k, n, 1]; let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); - let results = run_gemm((b, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 0); + let results = run_gemm( + "sgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); assert_eq!( approx(results, 4), vec![ @@ -1127,11 +1146,62 @@ fn gemm() { let rhs_stride = vec![n * k, n, 1]; let rhs: Vec = (0..b * n * k).map(|f| f as f32).collect(); // Manually set batch_size=1 and offset 12 elements * 4 the number of bytes for f32 - let results = run_gemm((1, m, n, k), &lhs, lhs_stride, 0, &rhs, rhs_stride, 12 * 4); + let results = run_gemm( + "sgemm", + (1, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 12 * 4, + ); assert_eq!( approx(results, 4), vec![56.0, 59.0, 62.0, 65.0, 200.0, 212.0, 224.0, 236.0] ); + + // bgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| bf16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| bf16::from_f32(f as f32)).collect(); + let results = run_gemm( + "bgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_bf16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); + + // hgemm sanity test + let (b, m, n, k) = (1, 2, 4, 3); + let lhs_stride = vec![m * k, k, 1]; + let lhs: Vec = (0..b * m * k).map(|f| f16::from_f32(f as f32)).collect(); + let rhs_stride = vec![n * k, n, 1]; + let rhs: Vec = (0..b * n * k).map(|f| f16::from_f32(f as f32)).collect(); + let results = run_gemm( + "hgemm", + (b, m, n, k), + &lhs, + lhs_stride, + 0, + &rhs, + rhs_stride, + 0, + ); + assert_eq!( + approx_f16(results, 4), + vec![20.0, 23.0, 26.0, 29.0, 56.0, 68.0, 80.0, 92.0] + ); } fn run_random(name: &'static str, seed: u32, length: usize, a: f32, b: f32) -> Vec {