diff --git a/crates/cubecl-linalg/src/matmul/cmma/base.rs b/crates/cubecl-linalg/src/matmul/cmma/base.rs index 711a4e69..06f3205d 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/base.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/base.rs @@ -72,7 +72,7 @@ pub(crate) struct Offsets { } #[derive(CubeType)] -pub(crate) struct CmmaMatrices { +pub(crate) struct Fragments { pub accumulators: Sequence>, pub lhs: cmma::Matrix, pub rhs: cmma::Matrix, @@ -127,9 +127,9 @@ fn calculate_offsets( #[cube] pub(crate) fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) { - if comptime_info.cube_dispatch == 0 { + if comptime_info.cube_dispatch_strategy == 0 { RowMajorCubeDispatch::get_row_col(comptime_info) - } else if comptime_info.cube_dispatch == 1 { + } else if comptime_info.cube_dispatch_strategy == 1 { ColMajorCubeDispatch::get_row_col(comptime_info) } else { SwizzleCubeDispatch::get_row_col(comptime_info) @@ -151,7 +151,7 @@ fn make_shared_memories(#[comptime] config: ComptimeCmmaInfo) -> Shar #[cube] pub(crate) fn make_cmma_matrices( #[comptime] config: ComptimeCmmaInfo, -) -> CmmaMatrices { +) -> Fragments { let num_accumulators = config.num_accumulators; let mut accumulators = Sequence::>::new(); @@ -186,7 +186,7 @@ pub(crate) fn make_cmma_matrices( cmma::MatrixLayout::RowMajor, ); - CmmaMatrices:: { + Fragments:: { accumulators, lhs, rhs, diff --git a/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs b/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs index cfdd8c15..50249db0 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/block_loop.rs @@ -2,8 +2,8 @@ use cubecl_core as cubecl; use cubecl_core::prelude::*; use super::{ - base::{CmmaMatrices, RuntimeCmmaInfo, SharedMemories}, - compute_loop::compute_loop, + base::{Fragments, RuntimeCmmaInfo, SharedMemories}, + compute_loop::base::compute_loop, config::ComptimeCmmaInfo, load_shared_memory::load_to_shared_memories, write_output::{base::OutputWriter, large_smem::LargeSmemWriter, reuse_smem::ReuseSmemWriter}, @@ -15,12 +15,12 @@ pub(crate) fn block_loop( rhs: &Tensor, out: &mut Tensor, shared_memories: SharedMemories, - mut cmma_matrices: CmmaMatrices, + mut fragments: Fragments, runtime_info: RuntimeCmmaInfo, #[comptime] comptime_info: ComptimeCmmaInfo, ) { let block_size_k = comptime_info.block_size_k; - let write_out_reuse_smem = comptime_info.write_out_reuse_smem; + let write_out_reuse_smem = comptime_info.write_out_strategy; // Equals ceil(dims.k / block_size_k) let dims = runtime_info.dims; @@ -42,7 +42,7 @@ pub(crate) fn block_loop( compute_loop::( shared_memories, - &mut cmma_matrices, + &mut fragments, runtime_info.ids, comptime_info, ); @@ -50,19 +50,9 @@ pub(crate) fn block_loop( sync_units(); } - if write_out_reuse_smem { - ReuseSmemWriter::write_to_output( - out, - cmma_matrices.accumulators, - runtime_info, - comptime_info, - ); + if write_out_reuse_smem == 0 { + LargeSmemWriter::write_to_output(out, fragments.accumulators, runtime_info, comptime_info); } else { - LargeSmemWriter::write_to_output( - out, - cmma_matrices.accumulators, - runtime_info, - comptime_info, - ); + ReuseSmemWriter::write_to_output(out, fragments.accumulators, runtime_info, comptime_info); } } diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs deleted file mode 100644 index 861de5ee..00000000 --- a/crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs +++ /dev/null @@ -1,77 +0,0 @@ -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use super::base::{CmmaMatrices, Ids, SharedMemories}; -use super::config::ComptimeCmmaInfo; - -#[cube] -#[allow(unused_mut)] -pub(crate) fn compute_loop( - shared_memories: SharedMemories, - cmma_matrices: &mut CmmaMatrices, - ids: Ids, - #[comptime] comptime_info: ComptimeCmmaInfo, -) { - let block_size_n = comptime_info.block_size_n; - let tile_size = comptime_info.tile_size; - let num_accumulators = comptime_info.num_accumulators; - let num_coop_per_row = (block_size_n / tile_size) / num_accumulators; - - let tile_row = ids.coop / num_coop_per_row; - let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators; - - let lhs = &cmma_matrices.lhs; - let rhs = &cmma_matrices.rhs; - let accumulators = &cmma_matrices.accumulators; - - #[unroll] - for n in 0..num_accumulators { - compute_tile::( - tile_row, - tile_col_base + n, - shared_memories, - lhs, - rhs, - accumulators.index(n), - comptime_info, - ); - } -} - -#[cube] -fn compute_tile( - tile_row: u32, - tile_col: u32, - shared_memories: SharedMemories, - lhs: &cmma::Matrix, - rhs: &cmma::Matrix, - accumulator: &cmma::Matrix, - #[comptime] comptime_info: ComptimeCmmaInfo, -) { - let block_size_k = comptime_info.block_size_k; - let tile_size = comptime_info.tile_size; - let unroll = comptime_info.unroll; - - let smem_stride = tile_size * tile_size; - let num_tiles_in_k = block_size_k / tile_size; - - #[unroll(unroll)] - for k_iter in 0..num_tiles_in_k { - let shared_lhs_tile = tile_row * num_tiles_in_k + k_iter; - let shared_rhs_tile = tile_col * num_tiles_in_k + k_iter; - let shared_lhs_pos = shared_lhs_tile * smem_stride; - let shared_rhs_pos = shared_rhs_tile * smem_stride; - - let lhs_slice = shared_memories - .lhs - .slice(shared_lhs_pos, shared_lhs_pos + smem_stride); - let rhs_slice = shared_memories - .rhs - .slice(shared_rhs_pos, shared_rhs_pos + smem_stride); - - cmma::load::(lhs, lhs_slice, 16); - cmma::load::(rhs, rhs_slice, 16); - - cmma::execute::(lhs, rhs, accumulator, accumulator); - } -} diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs new file mode 100644 index 00000000..101b098b --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/accumulators_first.rs @@ -0,0 +1,63 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::matmul::cmma::{ + base::{Fragments, Ids, SharedMemories}, + compute_loop::base::load_into_fragment, + config::ComptimeCmmaInfo, +}; + +use super::base::ComputeLoop; + +pub(crate) struct AllAccumulatorsFirstComputeLoop {} + +#[cube] +impl ComputeLoop for AllAccumulatorsFirstComputeLoop { + fn compute_loop( + shared_memories: SharedMemories, + fragments: &mut Fragments, + ids: Ids, + #[comptime] comptime_info: ComptimeCmmaInfo, + ) { + // Comptime values + let block_size_k = comptime_info.block_size_k; + let block_size_n = comptime_info.block_size_n; + let tile_size = comptime_info.tile_size; + let unroll = comptime_info.unroll; + let num_accumulators = comptime_info.num_accumulators; + let num_buffers = block_size_k / tile_size; + let num_coop_per_row = (block_size_n / tile_size) / num_accumulators; + + // Runtime values + let tile_row = ids.coop / num_coop_per_row; + let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators; + + #[unroll(unroll)] + for buffer_iter in 0..num_buffers { + #[unroll] + for accumulator_iter in 0..num_accumulators { + load_into_fragment( + tile_row * num_buffers + buffer_iter, + shared_memories.lhs, + &fragments.lhs, + comptime_info, + ); + + load_into_fragment( + (tile_col_base + accumulator_iter) * num_buffers + buffer_iter, + shared_memories.rhs, + &fragments.rhs, + comptime_info, + ); + + let accumulator = &fragments.accumulators.index(accumulator_iter); + cmma::execute::( + &fragments.lhs, + &fragments.rhs, + accumulator, + accumulator, + ); + } + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/base.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/base.rs new file mode 100644 index 00000000..4bd7b2ca --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/base.rs @@ -0,0 +1,55 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::matmul::cmma::{ + base::{Fragments, Ids, SharedMemories}, + compute_loop::{ + accumulators_first::AllAccumulatorsFirstComputeLoop, + buffers_first::AllBuffersFirstComputeLoop, + }, + config::ComptimeCmmaInfo, +}; + +#[cube] +pub(crate) fn compute_loop( + shared_memories: SharedMemories, + fragments: &mut Fragments, + ids: Ids, + #[comptime] comptime_info: ComptimeCmmaInfo, +) { + if comptime_info.compute_loop_order_strategy == 0 { + AllBuffersFirstComputeLoop::compute_loop(shared_memories, fragments, ids, comptime_info); + } else { + AllAccumulatorsFirstComputeLoop::compute_loop( + shared_memories, + fragments, + ids, + comptime_info, + ); + } +} + +#[cube] +pub(crate) trait ComputeLoop { + fn compute_loop( + shared_memories: SharedMemories, + fragments: &mut Fragments, + ids: Ids, + #[comptime] comptime_info: ComptimeCmmaInfo, + ); +} + +#[cube] +pub(crate) fn load_into_fragment( + tile: u32, + smem: SharedMemory, + fragment: &cmma::Matrix, + #[comptime] comptime_info: ComptimeCmmaInfo, +) { + let tile_size = comptime_info.tile_size; + let smem_stride = tile_size * tile_size; + + let smem_pos = tile * smem_stride; + let slice = smem.slice(smem_pos, smem_pos + smem_stride); + cmma::load::(fragment, slice, 16); +} diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/buffers_first.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/buffers_first.rs new file mode 100644 index 00000000..bd712a08 --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/buffers_first.rs @@ -0,0 +1,63 @@ +use cubecl_core as cubecl; +use cubecl_core::prelude::*; + +use crate::matmul::cmma::{ + base::{Fragments, Ids, SharedMemories}, + compute_loop::base::load_into_fragment, + config::ComptimeCmmaInfo, +}; + +use super::base::ComputeLoop; + +pub(crate) struct AllBuffersFirstComputeLoop {} + +#[cube] +impl ComputeLoop for AllBuffersFirstComputeLoop { + fn compute_loop( + shared_memories: SharedMemories, + fragments: &mut Fragments, + ids: Ids, + #[comptime] comptime_info: ComptimeCmmaInfo, + ) { + // Comptime values + let block_size_k = comptime_info.block_size_k; + let block_size_n = comptime_info.block_size_n; + let tile_size = comptime_info.tile_size; + let unroll = comptime_info.unroll; + let num_accumulators = comptime_info.num_accumulators; + let num_buffers = block_size_k / tile_size; + let num_coop_per_row = (block_size_n / tile_size) / num_accumulators; + + // Runtime values + let tile_row = ids.coop / num_coop_per_row; + let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators; + + #[unroll] + for accumulator_iter in 0..num_accumulators { + #[unroll(unroll)] + for buffer_iter in 0..num_buffers { + load_into_fragment( + tile_row * num_buffers + buffer_iter, + shared_memories.lhs, + &fragments.lhs, + comptime_info, + ); + + load_into_fragment( + (tile_col_base + accumulator_iter) * num_buffers + buffer_iter, + shared_memories.rhs, + &fragments.rhs, + comptime_info, + ); + + let accumulator = &fragments.accumulators.index(accumulator_iter); + cmma::execute::( + &fragments.lhs, + &fragments.rhs, + accumulator, + accumulator, + ); + } + } + } +} diff --git a/crates/cubecl-linalg/src/matmul/cmma/compute_loop/mod.rs b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/mod.rs new file mode 100644 index 00000000..fd98a8ac --- /dev/null +++ b/crates/cubecl-linalg/src/matmul/cmma/compute_loop/mod.rs @@ -0,0 +1,3 @@ +mod accumulators_first; +pub(crate) mod base; +mod buffers_first; diff --git a/crates/cubecl-linalg/src/matmul/cmma/config.rs b/crates/cubecl-linalg/src/matmul/cmma/config.rs index 2a49b20e..9af5cb4e 100644 --- a/crates/cubecl-linalg/src/matmul/cmma/config.rs +++ b/crates/cubecl-linalg/src/matmul/cmma/config.rs @@ -6,7 +6,7 @@ use cubecl_core::prelude::*; pub(crate) const CMMA_COOP_DIM: usize = 32; pub(crate) const CMMA_TILE_SIZE: usize = 16; -#[derive(PartialEq, Eq)] +#[derive(Clone, Copy)] /// Defines how data travels from accumulators to global output pub enum WriteOutStrategy { /// Accumulators for one warp are put concurrently in a shared memory large enough to contain them all @@ -15,6 +15,15 @@ pub enum WriteOutStrategy { ReuseSmem, } +impl From for u32 { + fn from(value: WriteOutStrategy) -> Self { + match value { + WriteOutStrategy::LargeSmem => 0, + WriteOutStrategy::ReuseSmem => 1, + } + } +} + /// How cubes are dispatched in the hypercube /// Should impact L2 cache reuse #[derive(Clone, Copy)] @@ -37,6 +46,24 @@ impl From for u32 { } } +#[derive(Clone, Copy)] +/// Defines how data travels from accumulators to global output +pub enum ComputeLoopOrderStrategy { + /// Accumulators for one warp are put concurrently in a shared memory large enough to contain them all + AllBuffersFirst, + /// Accumulators for one warp are put sequentially in a shared memory with only one reusable spot + AllAccumulatorsFirst, +} + +impl From for u32 { + fn from(value: ComputeLoopOrderStrategy) -> Self { + match value { + ComputeLoopOrderStrategy::AllBuffersFirst => 0, + ComputeLoopOrderStrategy::AllAccumulatorsFirst => 1, + } + } +} + pub struct CmmaConfig { /// Corresponds to the number of tiles in the m and n dimensions for a block pub b_mn: usize, @@ -47,7 +74,9 @@ pub struct CmmaConfig { /// Whether to write all accumulators in different spots of a large shared memory or reuse the space pub write_out_strategy: WriteOutStrategy, /// Order in which to dispatch cubes - pub cube_dispatch: CubeDispatchStrategy, + pub cube_dispatch_strategy: CubeDispatchStrategy, + /// Whether to iterate on buffers or accumulators first + pub compute_loop_order_strategy: ComputeLoopOrderStrategy, } impl Default for CmmaConfig { @@ -58,6 +87,7 @@ impl Default for CmmaConfig { false, WriteOutStrategy::ReuseSmem, CubeDispatchStrategy::ColMajor, + ComputeLoopOrderStrategy::AllBuffersFirst, ) } } @@ -68,7 +98,8 @@ impl CmmaConfig { b_k: usize, unroll: bool, write_out_strategy: WriteOutStrategy, - cube_dispatch: CubeDispatchStrategy, + cube_dispatch_strategy: CubeDispatchStrategy, + compute_loop_order_strategy: ComputeLoopOrderStrategy, ) -> CmmaConfig { assert!(b_mn % CMMA_TILE_SIZE == 0); assert!(b_k % CMMA_TILE_SIZE == 0); @@ -78,7 +109,8 @@ impl CmmaConfig { b_k, unroll, write_out_strategy, - cube_dispatch, + cube_dispatch_strategy, + compute_loop_order_strategy, } } @@ -97,8 +129,9 @@ impl CmmaConfig { coop_dim: CMMA_COOP_DIM as u32, num_coops: num_coops as u32, num_accumulators: (self.b_mn / self.b_k) as u32, - write_out_reuse_smem: self.write_out_strategy == WriteOutStrategy::ReuseSmem, - cube_dispatch: self.cube_dispatch.into(), + write_out_strategy: self.write_out_strategy.into(), + cube_dispatch_strategy: self.cube_dispatch_strategy.into(), + compute_loop_order_strategy: self.compute_loop_order_strategy.into(), } } @@ -168,8 +201,10 @@ pub struct ComptimeCmmaInfo { pub num_coops: u32, /// Number of cmma per subcube performed in one pass pub num_accumulators: u32, - /// Write out strategy: false = large, true = reuse - pub write_out_reuse_smem: bool, + /// 0 = large, 1 = reuse + pub write_out_strategy: u32, /// 0 = RowMajor, 1 = ColMajor, 2 = Swizzle - pub cube_dispatch: u32, + pub cube_dispatch_strategy: u32, + /// 0 = buffer inner, 1 = buffer outer + pub compute_loop_order_strategy: u32, } diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/base.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/base.rs deleted file mode 100644 index f2632331..00000000 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/base.rs +++ /dev/null @@ -1,8 +0,0 @@ -pub(crate) const B_MN: usize = 32; -pub(crate) const B_K: usize = 16; - -pub(crate) struct DimsTestCase { - pub m: usize, - pub k: usize, - pub n: usize, -} diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs deleted file mode 100644 index 949dfcf8..00000000 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/compute_loop.rs +++ /dev/null @@ -1,437 +0,0 @@ -use cubecl::prelude::*; -use cubecl_core as cubecl; - -use crate::matmul::cmma::{ - base::{make_cmma_matrices, Ids, SharedMemories}, - compute_loop::compute_loop, - config::{CmmaConfig, ComptimeCmmaInfo}, -}; -use crate::matmul::tests::test_utils::{ - assert_equals, cmma_available, create_empty, range_tensor_f16, -}; -use half::f16; - -#[cube(launch_unchecked)] -fn compute_loop_test( - lhs_tensor: &Tensor, - rhs_tensor: &Tensor, - accumulate_array: &mut Array, - #[comptime] b_mn: u32, - #[comptime] b_k: u32, - #[comptime] comptime_info: ComptimeCmmaInfo, -) { - let mut lhs = SharedMemory::::new(b_mn * b_k); - let mut rhs = SharedMemory::::new(b_k * b_mn); - - for i in 0..b_mn * b_k { - lhs[i] = lhs_tensor[i]; - } - for i in 0..b_k * b_mn { - rhs[i] = rhs_tensor[i]; - } - for i in 0..b_mn * b_mn { - accumulate_array[i] = F::new(0.); - } - - let shared_memories = SharedMemories:: { lhs, rhs }; - let mut matrices = make_cmma_matrices::(comptime_info); - - compute_loop( - shared_memories, - &mut matrices, - Ids { - coop: UNIT_POS_Y, - lane: UNIT_POS_X, - }, - comptime_info, - ); - - let num_accumulators = comptime_info.num_accumulators; - let tile_size = comptime_info.tile_size; - let slice_offset = tile_size * tile_size; - let offset = UNIT_POS_Y * slice_offset * num_accumulators; - - let accumulators = matrices.accumulators; - #[unroll] - for n in 0..num_accumulators { - let slice = - accumulate_array.slice_mut(offset + n * slice_offset, offset + (n + 1) * slice_offset); - cmma::store::( - slice, - accumulators.index(n), - 16, - cmma::MatrixLayout::RowMajor, - ); - } -} - -fn compute_loop_test_case( - block_config: CmmaConfig, - expected: &[f32], - device: &R::Device, -) { - if !cmma_available::(device) { - // We can't execute the test, skip. - return; - } - - let client = R::client(device); - let lhs = range_tensor_f16::(&client, block_config.b_mn, block_config.b_k); - let rhs = range_tensor_f16::(&client, block_config.b_k, block_config.b_mn); - let results = create_empty::(&client, block_config.b_mn, block_config.b_mn); - let cube_dim = block_config.cube_dim(); - let cube_count = block_config.cube_count::(&[block_config.b_mn, block_config.b_mn]); - - let comptime_info = - block_config.comptime_info(block_config.b_mn, block_config.b_k, block_config.b_mn); - - unsafe { - compute_loop_test::launch_unchecked::( - &R::client(device), - cube_count, - cube_dim, - TensorArg::from_raw_parts(&lhs.handle, &lhs.strides, &lhs.shape, 1), - TensorArg::from_raw_parts(&rhs.handle, &rhs.strides, &rhs.shape, 1), - ArrayArg::from_raw_parts(&results, block_config.b_mn * block_config.b_mn, 1), - block_config.b_mn as u32, - block_config.b_k as u32, - comptime_info, - ); - }; - - assert_equals::(&client, results, expected); -} - -/// Exported test -pub fn cmma_compute_loop_block_equal_tile_test(device: &R::Device) { - compute_loop_test_case::( - CmmaConfig { - b_mn: 16, - b_k: 16, - ..Default::default() - }, - &[ - 19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0, - 20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0, - 51312.0, 51688.0, 52064.0, 52440.0, 52816.0, 53192.0, 53568.0, 53944.0, 54320.0, - 54696.0, 55072.0, 55448.0, 55824.0, 56200.0, 81280.0, 81912.0, 82544.0, 83176.0, - 83808.0, 84440.0, 85072.0, 85704.0, 86336.0, 86968.0, 87600.0, 88232.0, 88864.0, - 89496.0, 90128.0, 90760.0, 112000.0, 112888.0, 113776.0, 114664.0, 115552.0, 116440.0, - 117328.0, 118216.0, 119104.0, 119992.0, 120880.0, 121768.0, 122656.0, 123544.0, - 124432.0, 125320.0, 142720.0, 143864.0, 145008.0, 146152.0, 147296.0, 148440.0, - 149584.0, 150728.0, 151872.0, 153016.0, 154160.0, 155304.0, 156448.0, 157592.0, - 158736.0, 159880.0, 173440.0, 174840.0, 176240.0, 177640.0, 179040.0, 180440.0, - 181840.0, 183240.0, 184640.0, 186040.0, 187440.0, 188840.0, 190240.0, 191640.0, - 193040.0, 194440.0, 204160.0, 205816.0, 207472.0, 209128.0, 210784.0, 212440.0, - 214096.0, 215752.0, 217408.0, 219064.0, 220720.0, 222376.0, 224032.0, 225688.0, - 227344.0, 229000.0, 234880.0, 236792.0, 238704.0, 240616.0, 242528.0, 244440.0, - 246352.0, 248264.0, 250176.0, 252088.0, 254000.0, 255912.0, 257824.0, 259736.0, - 261648.0, 263560.0, 265600.0, 267768.0, 269936.0, 272104.0, 274272.0, 276440.0, - 278608.0, 280776.0, 282944.0, 285112.0, 287280.0, 289448.0, 291616.0, 293784.0, - 295952.0, 298120.0, 296320.0, 298744.0, 301168.0, 303592.0, 306016.0, 308440.0, - 310864.0, 313288.0, 315712.0, 318136.0, 320560.0, 322984.0, 325408.0, 327832.0, - 330256.0, 332680.0, 327040.0, 329720.0, 332400.0, 335080.0, 337760.0, 340440.0, - 343120.0, 345800.0, 348480.0, 351160.0, 353840.0, 356520.0, 359200.0, 361880.0, - 364560.0, 367240.0, 357760.0, 360696.0, 363632.0, 366568.0, 369504.0, 372440.0, - 375376.0, 378312.0, 381248.0, 384184.0, 387120.0, 390056.0, 392992.0, 395928.0, - 398864.0, 401800.0, 388480.0, 391672.0, 394864.0, 398056.0, 401248.0, 404440.0, - 407632.0, 410824.0, 414016.0, 417208.0, 420400.0, 423592.0, 426784.0, 429976.0, - 433168.0, 436360.0, 419200.0, 422648.0, 426096.0, 429544.0, 432992.0, 436440.0, - 439888.0, 443336.0, 446784.0, 450232.0, 453680.0, 457128.0, 460576.0, 464024.0, - 467472.0, 470920.0, 449920.0, 453624.0, 457328.0, 461032.0, 464736.0, 468440.0, - 472144.0, 475848.0, 479552.0, 483256.0, 486960.0, 490664.0, 494368.0, 498072.0, - 501776.0, 505480.0, 480640.0, 484600.0, 488560.0, 492520.0, 496480.0, 500440.0, - 504400.0, 508360.0, 512320.0, 516280.0, 520240.0, 524200.0, 528160.0, 532120.0, - 536080.0, 540040.0, - ], - device, - ); -} - -/// Exported test -pub fn cmma_compute_loop_block_larger_than_tile_test(device: &R::Device) { - compute_loop_test_case::( - CmmaConfig { - b_mn: 32, - b_k: 32, - ..Default::default() - }, - &[ - 1610496.0, 1614832.0, 1619168.0, 1623504.0, 1627840.0, 1632176.0, 1636512.0, 1640848.0, - 1645184.0, 1649520.0, 1653856.0, 1658192.0, 1662528.0, 1666864.0, 1671200.0, 1675536.0, - 1737472.0, 1742320.0, 1747168.0, 1752016.0, 1756864.0, 1761712.0, 1766560.0, 1771408.0, - 1776256.0, 1781104.0, 1785952.0, 1790800.0, 1795648.0, 1800496.0, 1805344.0, 1810192.0, - 1864448.0, 1869808.0, 1875168.0, 1880528.0, 1885888.0, 1891248.0, 1896608.0, 1901968.0, - 1907328.0, 1912688.0, 1918048.0, 1923408.0, 1928768.0, 1934128.0, 1939488.0, 1944848.0, - 1991424.0, 1997296.0, 2003168.0, 2009040.0, 2014912.0, 2020784.0, 2026656.0, 2032528.0, - 2038400.0, 2044272.0, 2050144.0, 2056016.0, 2061888.0, 2067760.0, 2073632.0, 2079504.0, - 2118400.0, 2124784.0, 2131168.0, 2137552.0, 2143936.0, 2150320.0, 2156704.0, 2163088.0, - 2169472.0, 2175856.0, 2182240.0, 2188624.0, 2195008.0, 2201392.0, 2207776.0, 2214160.0, - 2245376.0, 2252272.0, 2259168.0, 2266064.0, 2272960.0, 2279856.0, 2286752.0, 2293648.0, - 2300544.0, 2307440.0, 2314336.0, 2321232.0, 2328128.0, 2335024.0, 2341920.0, 2348816.0, - 2372352.0, 2379760.0, 2387168.0, 2394576.0, 2401984.0, 2409392.0, 2416800.0, 2424208.0, - 2431616.0, 2439024.0, 2446432.0, 2453840.0, 2461248.0, 2468656.0, 2476064.0, 2483472.0, - 2499328.0, 2507248.0, 2515168.0, 2523088.0, 2531008.0, 2538928.0, 2546848.0, 2554768.0, - 2562688.0, 2570608.0, 2578528.0, 2586448.0, 2594368.0, 2602288.0, 2610208.0, 2618128.0, - 2626304.0, 2634736.0, 2643168.0, 2651600.0, 2660032.0, 2668464.0, 2676896.0, 2685328.0, - 2693760.0, 2702192.0, 2710624.0, 2719056.0, 2727488.0, 2735920.0, 2744352.0, 2752784.0, - 2753280.0, 2762224.0, 2771168.0, 2780112.0, 2789056.0, 2798000.0, 2806944.0, 2815888.0, - 2824832.0, 2833776.0, 2842720.0, 2851664.0, 2860608.0, 2869552.0, 2878496.0, 2887440.0, - 2880256.0, 2889712.0, 2899168.0, 2908624.0, 2918080.0, 2927536.0, 2936992.0, 2946448.0, - 2955904.0, 2965360.0, 2974816.0, 2984272.0, 2993728.0, 3003184.0, 3012640.0, 3022096.0, - 3007232.0, 3017200.0, 3027168.0, 3037136.0, 3047104.0, 3057072.0, 3067040.0, 3077008.0, - 3086976.0, 3096944.0, 3106912.0, 3116880.0, 3126848.0, 3136816.0, 3146784.0, 3156752.0, - 3134208.0, 3144688.0, 3155168.0, 3165648.0, 3176128.0, 3186608.0, 3197088.0, 3207568.0, - 3218048.0, 3228528.0, 3239008.0, 3249488.0, 3259968.0, 3270448.0, 3280928.0, 3291408.0, - 3261184.0, 3272176.0, 3283168.0, 3294160.0, 3305152.0, 3316144.0, 3327136.0, 3338128.0, - 3349120.0, 3360112.0, 3371104.0, 3382096.0, 3393088.0, 3404080.0, 3415072.0, 3426064.0, - 3388160.0, 3399664.0, 3411168.0, 3422672.0, 3434176.0, 3445680.0, 3457184.0, 3468688.0, - 3480192.0, 3491696.0, 3503200.0, 3514704.0, 3526208.0, 3537712.0, 3549216.0, 3560720.0, - 3515136.0, 3527152.0, 3539168.0, 3551184.0, 3563200.0, 3575216.0, 3587232.0, 3599248.0, - 3611264.0, 3623280.0, 3635296.0, 3647312.0, 3659328.0, 3671344.0, 3683360.0, 3695376.0, - 3830528.0, 3834864.0, 3839200.0, 3843536.0, 3847872.0, 3852208.0, 3856544.0, 3860880.0, - 3865216.0, 3869552.0, 3873888.0, 3878224.0, 3882560.0, 3886896.0, 3891232.0, 3895568.0, - 4219648.0, 4224496.0, 4229344.0, 4234192.0, 4239040.0, 4243888.0, 4248736.0, 4253584.0, - 4258432.0, 4263280.0, 4268128.0, 4272976.0, 4277824.0, 4282672.0, 4287520.0, 4292368.0, - 4608768.0, 4614128.0, 4619488.0, 4624848.0, 4630208.0, 4635568.0, 4640928.0, 4646288.0, - 4651648.0, 4657008.0, 4662368.0, 4667728.0, 4673088.0, 4678448.0, 4683808.0, 4689168.0, - 4997888.0, 5003760.0, 5009632.0, 5015504.0, 5021376.0, 5027248.0, 5033120.0, 5038992.0, - 5044864.0, 5050736.0, 5056608.0, 5062480.0, 5068352.0, 5074224.0, 5080096.0, 5085968.0, - 5387008.0, 5393392.0, 5399776.0, 5406160.0, 5412544.0, 5418928.0, 5425312.0, 5431696.0, - 5438080.0, 5444464.0, 5450848.0, 5457232.0, 5463616.0, 5470000.0, 5476384.0, 5482768.0, - 5776128.0, 5783024.0, 5789920.0, 5796816.0, 5803712.0, 5810608.0, 5817504.0, 5824400.0, - 5831296.0, 5838192.0, 5845088.0, 5851984.0, 5858880.0, 5865776.0, 5872672.0, 5879568.0, - 6165248.0, 6172656.0, 6180064.0, 6187472.0, 6194880.0, 6202288.0, 6209696.0, 6217104.0, - 6224512.0, 6231920.0, 6239328.0, 6246736.0, 6254144.0, 6261552.0, 6268960.0, 6276368.0, - 6554368.0, 6562288.0, 6570208.0, 6578128.0, 6586048.0, 6593968.0, 6601888.0, 6609808.0, - 6617728.0, 6625648.0, 6633568.0, 6641488.0, 6649408.0, 6657328.0, 6665248.0, 6673168.0, - 6943488.0, 6951920.0, 6960352.0, 6968784.0, 6977216.0, 6985648.0, 6994080.0, 7002512.0, - 7010944.0, 7019376.0, 7027808.0, 7036240.0, 7044672.0, 7053104.0, 7061536.0, 7069968.0, - 7332608.0, 7341552.0, 7350496.0, 7359440.0, 7368384.0, 7377328.0, 7386272.0, 7395216.0, - 7404160.0, 7413104.0, 7422048.0, 7430992.0, 7439936.0, 7448880.0, 7457824.0, 7466768.0, - 7721728.0, 7731184.0, 7740640.0, 7750096.0, 7759552.0, 7769008.0, 7778464.0, 7787920.0, - 7797376.0, 7806832.0, 7816288.0, 7825744.0, 7835200.0, 7844656.0, 7854112.0, 7863568.0, - 8110848.0, 8120816.0, 8130784.0, 8140752.0, 8150720.0, 8160688.0, 8170656.0, 8180624.0, - 8190592.0, 8200560.0, 8210528.0, 8220496.0, 8230464.0, 8240432.0, 8250400.0, 8260368.0, - 8499968.0, 8510448.0, 8520928.0, 8531408.0, 8541888.0, 8552368.0, 8562848.0, 8573328.0, - 8583808.0, 8594288.0, 8604768.0, 8615248.0, 8625728.0, 8636208.0, 8646688.0, 8657168.0, - 8889088.0, 8900080.0, 8911072.0, 8922064.0, 8933056.0, 8944048.0, 8955040.0, 8966032.0, - 8977024.0, 8988016.0, 8999008.0, 9010000.0, 9020992.0, 9031984.0, 9042976.0, 9053968.0, - 9278208.0, 9289712.0, 9301216.0, 9312720.0, 9324224.0, 9335728.0, 9347232.0, 9358736.0, - 9370240.0, 9381744.0, 9393248.0, 9404752.0, 9416256.0, 9427760.0, 9439264.0, 9450768.0, - 9667328.0, 9679344.0, 9691360.0, 9703376.0, 9715392.0, 9727408.0, 9739424.0, 9751440.0, - 9763456.0, 9775472.0, 9787488.0, 9799504.0, 9811520.0, 9823536.0, 9835552.0, 9847568.0, - 5673728.0, 5694448.0, 5715168.0, 5735888.0, 5756608.0, 5777328.0, 5798048.0, 5818768.0, - 5839488.0, 5860208.0, 5880928.0, 5901648.0, 5922368.0, 5943088.0, 5963808.0, 5984528.0, - 5800704.0, 5821936.0, 5843168.0, 5864400.0, 5885632.0, 5906864.0, 5928096.0, 5949328.0, - 5970560.0, 5991792.0, 6013024.0, 6034256.0, 6055488.0, 6076720.0, 6097952.0, 6119184.0, - 5927680.0, 5949424.0, 5971168.0, 5992912.0, 6014656.0, 6036400.0, 6058144.0, 6079888.0, - 6101632.0, 6123376.0, 6145120.0, 6166864.0, 6188608.0, 6210352.0, 6232096.0, 6253840.0, - 6054656.0, 6076912.0, 6099168.0, 6121424.0, 6143680.0, 6165936.0, 6188192.0, 6210448.0, - 6232704.0, 6254960.0, 6277216.0, 6299472.0, 6321728.0, 6343984.0, 6366240.0, 6388496.0, - 6181632.0, 6204400.0, 6227168.0, 6249936.0, 6272704.0, 6295472.0, 6318240.0, 6341008.0, - 6363776.0, 6386544.0, 6409312.0, 6432080.0, 6454848.0, 6477616.0, 6500384.0, 6523152.0, - 6308608.0, 6331888.0, 6355168.0, 6378448.0, 6401728.0, 6425008.0, 6448288.0, 6471568.0, - 6494848.0, 6518128.0, 6541408.0, 6564688.0, 6587968.0, 6611248.0, 6634528.0, 6657808.0, - 6435584.0, 6459376.0, 6483168.0, 6506960.0, 6530752.0, 6554544.0, 6578336.0, 6602128.0, - 6625920.0, 6649712.0, 6673504.0, 6697296.0, 6721088.0, 6744880.0, 6768672.0, 6792464.0, - 6562560.0, 6586864.0, 6611168.0, 6635472.0, 6659776.0, 6684080.0, 6708384.0, 6732688.0, - 6756992.0, 6781296.0, 6805600.0, 6829904.0, 6854208.0, 6878512.0, 6902816.0, 6927120.0, - 6689536.0, 6714352.0, 6739168.0, 6763984.0, 6788800.0, 6813616.0, 6838432.0, 6863248.0, - 6888064.0, 6912880.0, 6937696.0, 6962512.0, 6987328.0, 7012144.0, 7036960.0, 7061776.0, - 6816512.0, 6841840.0, 6867168.0, 6892496.0, 6917824.0, 6943152.0, 6968480.0, 6993808.0, - 7019136.0, 7044464.0, 7069792.0, 7095120.0, 7120448.0, 7145776.0, 7171104.0, 7196432.0, - 6943488.0, 6969328.0, 6995168.0, 7021008.0, 7046848.0, 7072688.0, 7098528.0, 7124368.0, - 7150208.0, 7176048.0, 7201888.0, 7227728.0, 7253568.0, 7279408.0, 7305248.0, 7331088.0, - 7070464.0, 7096816.0, 7123168.0, 7149520.0, 7175872.0, 7202224.0, 7228576.0, 7254928.0, - 7281280.0, 7307632.0, 7333984.0, 7360336.0, 7386688.0, 7413040.0, 7439392.0, 7465744.0, - 7197440.0, 7224304.0, 7251168.0, 7278032.0, 7304896.0, 7331760.0, 7358624.0, 7385488.0, - 7412352.0, 7439216.0, 7466080.0, 7492944.0, 7519808.0, 7546672.0, 7573536.0, 7600400.0, - 7324416.0, 7351792.0, 7379168.0, 7406544.0, 7433920.0, 7461296.0, 7488672.0, 7516048.0, - 7543424.0, 7570800.0, 7598176.0, 7625552.0, 7652928.0, 7680304.0, 7707680.0, 7735056.0, - 7451392.0, 7479280.0, 7507168.0, 7535056.0, 7562944.0, 7590832.0, 7618720.0, 7646608.0, - 7674496.0, 7702384.0, 7730272.0, 7758160.0, 7786048.0, 7813936.0, 7841824.0, 7869712.0, - 7578368.0, 7606768.0, 7635168.0, 7663568.0, 7691968.0, 7720368.0, 7748768.0, 7777168.0, - 7805568.0, 7833968.0, 7862368.0, 7890768.0, 7919168.0, 7947568.0, 7975968.0, 8004368.0, - 16282368.0, 16303088.0, 16323808.0, 16344528.0, 16365248.0, 16385968.0, 16406688.0, - 16427408.0, 16448128.0, 16468848.0, 16489568.0, 16510288.0, 16531008.0, 16551728.0, - 16572448.0, 16593168.0, 16671488.0, 16692720.0, 16713952.0, 16735184.0, 16756416.0, - 16777648.0, 16798880.0, 16820112.0, 16841344.0, 16862576.0, 16883808.0, 16905040.0, - 16926272.0, 16947504.0, 16968736.0, 16989968.0, 17060608.0, 17082352.0, 17104096.0, - 17125840.0, 17147584.0, 17169328.0, 17191072.0, 17212816.0, 17234560.0, 17256304.0, - 17278048.0, 17299792.0, 17321536.0, 17343280.0, 17365024.0, 17386768.0, 17449728.0, - 17471984.0, 17494240.0, 17516496.0, 17538752.0, 17561008.0, 17583264.0, 17605520.0, - 17627776.0, 17650032.0, 17672288.0, 17694544.0, 17716800.0, 17739056.0, 17761312.0, - 17783568.0, 17838848.0, 17861616.0, 17884384.0, 17907152.0, 17929920.0, 17952688.0, - 17975456.0, 17998224.0, 18020992.0, 18043760.0, 18066528.0, 18089296.0, 18112064.0, - 18134832.0, 18157600.0, 18180368.0, 18227968.0, 18251248.0, 18274528.0, 18297808.0, - 18321088.0, 18344368.0, 18367648.0, 18390928.0, 18414208.0, 18437488.0, 18460768.0, - 18484048.0, 18507328.0, 18530608.0, 18553888.0, 18577168.0, 18617088.0, 18640880.0, - 18664672.0, 18688464.0, 18712256.0, 18736048.0, 18759840.0, 18783632.0, 18807424.0, - 18831216.0, 18855008.0, 18878800.0, 18902592.0, 18926384.0, 18950176.0, 18973968.0, - 19006208.0, 19030512.0, 19054816.0, 19079120.0, 19103424.0, 19127728.0, 19152032.0, - 19176336.0, 19200640.0, 19224944.0, 19249248.0, 19273552.0, 19297856.0, 19322160.0, - 19346464.0, 19370768.0, 19395328.0, 19420144.0, 19444960.0, 19469776.0, 19494592.0, - 19519408.0, 19544224.0, 19569040.0, 19593856.0, 19618672.0, 19643488.0, 19668304.0, - 19693120.0, 19717936.0, 19742752.0, 19767568.0, 19784448.0, 19809776.0, 19835104.0, - 19860432.0, 19885760.0, 19911088.0, 19936416.0, 19961744.0, 19987072.0, 20012400.0, - 20037728.0, 20063056.0, 20088384.0, 20113712.0, 20139040.0, 20164368.0, 20173568.0, - 20199408.0, 20225248.0, 20251088.0, 20276928.0, 20302768.0, 20328608.0, 20354448.0, - 20380288.0, 20406128.0, 20431968.0, 20457808.0, 20483648.0, 20509488.0, 20535328.0, - 20561168.0, 20562688.0, 20589040.0, 20615392.0, 20641744.0, 20668096.0, 20694448.0, - 20720800.0, 20747152.0, 20773504.0, 20799856.0, 20826208.0, 20852560.0, 20878912.0, - 20905264.0, 20931616.0, 20957968.0, 20951808.0, 20978672.0, 21005536.0, 21032400.0, - 21059264.0, 21086128.0, 21112992.0, 21139856.0, 21166720.0, 21193584.0, 21220448.0, - 21247312.0, 21274176.0, 21301040.0, 21327904.0, 21354768.0, 21340928.0, 21368304.0, - 21395680.0, 21423056.0, 21450432.0, 21477808.0, 21505184.0, 21532560.0, 21559936.0, - 21587312.0, 21614688.0, 21642064.0, 21669440.0, 21696816.0, 21724192.0, 21751568.0, - 21730048.0, 21757936.0, 21785824.0, 21813712.0, 21841600.0, 21869488.0, 21897376.0, - 21925264.0, 21953152.0, 21981040.0, 22008928.0, 22036816.0, 22064704.0, 22092592.0, - 22120480.0, 22148368.0, 22119168.0, 22147568.0, 22175968.0, 22204368.0, 22232768.0, - 22261168.0, 22289568.0, 22317968.0, 22346368.0, 22374768.0, 22403168.0, 22431568.0, - 22459968.0, 22488368.0, 22516768.0, 22545168.0, - ], - device, - ); -} - -/// Exported test -pub fn cmma_compute_loop_b_mn_larger_than_b_k_test(device: &R::Device) { - compute_loop_test_case::( - CmmaConfig { - b_mn: 32, - b_k: 16, - ..Default::default() - }, - &[ - 19840.0, 19960.0, 20080.0, 20200.0, 20320.0, 20440.0, 20560.0, 20680.0, 20800.0, - 20920.0, 21040.0, 21160.0, 21280.0, 21400.0, 21520.0, 21640.0, 50560.0, 50936.0, - 51312.0, 51688.0, 52064.0, 52440.0, 52816.0, 53192.0, 53568.0, 53944.0, 54320.0, - 54696.0, 55072.0, 55448.0, 55824.0, 56200.0, 81280.0, 81912.0, 82544.0, 83176.0, - 83808.0, 84440.0, 85072.0, 85704.0, 86336.0, 86968.0, 87600.0, 88232.0, 88864.0, - 89496.0, 90128.0, 90760.0, 112000.0, 112888.0, 113776.0, 114664.0, 115552.0, 116440.0, - 117328.0, 118216.0, 119104.0, 119992.0, 120880.0, 121768.0, 122656.0, 123544.0, - 124432.0, 125320.0, 142720.0, 143864.0, 145008.0, 146152.0, 147296.0, 148440.0, - 149584.0, 150728.0, 151872.0, 153016.0, 154160.0, 155304.0, 156448.0, 157592.0, - 158736.0, 159880.0, 173440.0, 174840.0, 176240.0, 177640.0, 179040.0, 180440.0, - 181840.0, 183240.0, 184640.0, 186040.0, 187440.0, 188840.0, 190240.0, 191640.0, - 193040.0, 194440.0, 204160.0, 205816.0, 207472.0, 209128.0, 210784.0, 212440.0, - 214096.0, 215752.0, 217408.0, 219064.0, 220720.0, 222376.0, 224032.0, 225688.0, - 227344.0, 229000.0, 234880.0, 236792.0, 238704.0, 240616.0, 242528.0, 244440.0, - 246352.0, 248264.0, 250176.0, 252088.0, 254000.0, 255912.0, 257824.0, 259736.0, - 261648.0, 263560.0, 265600.0, 267768.0, 269936.0, 272104.0, 274272.0, 276440.0, - 278608.0, 280776.0, 282944.0, 285112.0, 287280.0, 289448.0, 291616.0, 293784.0, - 295952.0, 298120.0, 296320.0, 298744.0, 301168.0, 303592.0, 306016.0, 308440.0, - 310864.0, 313288.0, 315712.0, 318136.0, 320560.0, 322984.0, 325408.0, 327832.0, - 330256.0, 332680.0, 327040.0, 329720.0, 332400.0, 335080.0, 337760.0, 340440.0, - 343120.0, 345800.0, 348480.0, 351160.0, 353840.0, 356520.0, 359200.0, 361880.0, - 364560.0, 367240.0, 357760.0, 360696.0, 363632.0, 366568.0, 369504.0, 372440.0, - 375376.0, 378312.0, 381248.0, 384184.0, 387120.0, 390056.0, 392992.0, 395928.0, - 398864.0, 401800.0, 388480.0, 391672.0, 394864.0, 398056.0, 401248.0, 404440.0, - 407632.0, 410824.0, 414016.0, 417208.0, 420400.0, 423592.0, 426784.0, 429976.0, - 433168.0, 436360.0, 419200.0, 422648.0, 426096.0, 429544.0, 432992.0, 436440.0, - 439888.0, 443336.0, 446784.0, 450232.0, 453680.0, 457128.0, 460576.0, 464024.0, - 467472.0, 470920.0, 449920.0, 453624.0, 457328.0, 461032.0, 464736.0, 468440.0, - 472144.0, 475848.0, 479552.0, 483256.0, 486960.0, 490664.0, 494368.0, 498072.0, - 501776.0, 505480.0, 480640.0, 484600.0, 488560.0, 492520.0, 496480.0, 500440.0, - 504400.0, 508360.0, 512320.0, 516280.0, 520240.0, 524200.0, 528160.0, 532120.0, - 536080.0, 540040.0, 50560.0, 50680.0, 50800.0, 50920.0, 51040.0, 51160.0, 51280.0, - 51400.0, 51520.0, 51640.0, 51760.0, 51880.0, 52000.0, 52120.0, 52240.0, 52360.0, - 146816.0, 147192.0, 147568.0, 147944.0, 148320.0, 148696.0, 149072.0, 149448.0, - 149824.0, 150200.0, 150576.0, 150952.0, 151328.0, 151704.0, 152080.0, 152456.0, - 243072.0, 243704.0, 244336.0, 244968.0, 245600.0, 246232.0, 246864.0, 247496.0, - 248128.0, 248760.0, 249392.0, 250024.0, 250656.0, 251288.0, 251920.0, 252552.0, - 339328.0, 340216.0, 341104.0, 341992.0, 342880.0, 343768.0, 344656.0, 345544.0, - 346432.0, 347320.0, 348208.0, 349096.0, 349984.0, 350872.0, 351760.0, 352648.0, - 435584.0, 436728.0, 437872.0, 439016.0, 440160.0, 441304.0, 442448.0, 443592.0, - 444736.0, 445880.0, 447024.0, 448168.0, 449312.0, 450456.0, 451600.0, 452744.0, - 531840.0, 533240.0, 534640.0, 536040.0, 537440.0, 538840.0, 540240.0, 541640.0, - 543040.0, 544440.0, 545840.0, 547240.0, 548640.0, 550040.0, 551440.0, 552840.0, - 628096.0, 629752.0, 631408.0, 633064.0, 634720.0, 636376.0, 638032.0, 639688.0, - 641344.0, 643000.0, 644656.0, 646312.0, 647968.0, 649624.0, 651280.0, 652936.0, - 724352.0, 726264.0, 728176.0, 730088.0, 732000.0, 733912.0, 735824.0, 737736.0, - 739648.0, 741560.0, 743472.0, 745384.0, 747296.0, 749208.0, 751120.0, 753032.0, - 820608.0, 822776.0, 824944.0, 827112.0, 829280.0, 831448.0, 833616.0, 835784.0, - 837952.0, 840120.0, 842288.0, 844456.0, 846624.0, 848792.0, 850960.0, 853128.0, - 916864.0, 919288.0, 921712.0, 924136.0, 926560.0, 928984.0, 931408.0, 933832.0, - 936256.0, 938680.0, 941104.0, 943528.0, 945952.0, 948376.0, 950800.0, 953224.0, - 1013120.0, 1015800.0, 1018480.0, 1021160.0, 1023840.0, 1026520.0, 1029200.0, 1031880.0, - 1034560.0, 1037240.0, 1039920.0, 1042600.0, 1045280.0, 1047960.0, 1050640.0, 1053320.0, - 1109376.0, 1112312.0, 1115248.0, 1118184.0, 1121120.0, 1124056.0, 1126992.0, 1129928.0, - 1132864.0, 1135800.0, 1138736.0, 1141672.0, 1144608.0, 1147544.0, 1150480.0, 1153416.0, - 1205632.0, 1208824.0, 1212016.0, 1215208.0, 1218400.0, 1221592.0, 1224784.0, 1227976.0, - 1231168.0, 1234360.0, 1237552.0, 1240744.0, 1243936.0, 1247128.0, 1250320.0, 1253512.0, - 1301888.0, 1305336.0, 1308784.0, 1312232.0, 1315680.0, 1319128.0, 1322576.0, 1326024.0, - 1329472.0, 1332920.0, 1336368.0, 1339816.0, 1343264.0, 1346712.0, 1350160.0, 1353608.0, - 1398144.0, 1401848.0, 1405552.0, 1409256.0, 1412960.0, 1416664.0, 1420368.0, 1424072.0, - 1427776.0, 1431480.0, 1435184.0, 1438888.0, 1442592.0, 1446296.0, 1450000.0, 1453704.0, - 1494400.0, 1498360.0, 1502320.0, 1506280.0, 1510240.0, 1514200.0, 1518160.0, 1522120.0, - 1526080.0, 1530040.0, 1534000.0, 1537960.0, 1541920.0, 1545880.0, 1549840.0, 1553800.0, - 511360.0, 515576.0, 519792.0, 524008.0, 528224.0, 532440.0, 536656.0, 540872.0, - 545088.0, 549304.0, 553520.0, 557736.0, 561952.0, 566168.0, 570384.0, 574600.0, - 542080.0, 546552.0, 551024.0, 555496.0, 559968.0, 564440.0, 568912.0, 573384.0, - 577856.0, 582328.0, 586800.0, 591272.0, 595744.0, 600216.0, 604688.0, 609160.0, - 572800.0, 577528.0, 582256.0, 586984.0, 591712.0, 596440.0, 601168.0, 605896.0, - 610624.0, 615352.0, 620080.0, 624808.0, 629536.0, 634264.0, 638992.0, 643720.0, - 603520.0, 608504.0, 613488.0, 618472.0, 623456.0, 628440.0, 633424.0, 638408.0, - 643392.0, 648376.0, 653360.0, 658344.0, 663328.0, 668312.0, 673296.0, 678280.0, - 634240.0, 639480.0, 644720.0, 649960.0, 655200.0, 660440.0, 665680.0, 670920.0, - 676160.0, 681400.0, 686640.0, 691880.0, 697120.0, 702360.0, 707600.0, 712840.0, - 664960.0, 670456.0, 675952.0, 681448.0, 686944.0, 692440.0, 697936.0, 703432.0, - 708928.0, 714424.0, 719920.0, 725416.0, 730912.0, 736408.0, 741904.0, 747400.0, - 695680.0, 701432.0, 707184.0, 712936.0, 718688.0, 724440.0, 730192.0, 735944.0, - 741696.0, 747448.0, 753200.0, 758952.0, 764704.0, 770456.0, 776208.0, 781960.0, - 726400.0, 732408.0, 738416.0, 744424.0, 750432.0, 756440.0, 762448.0, 768456.0, - 774464.0, 780472.0, 786480.0, 792488.0, 798496.0, 804504.0, 810512.0, 816520.0, - 757120.0, 763384.0, 769648.0, 775912.0, 782176.0, 788440.0, 794704.0, 800968.0, - 807232.0, 813496.0, 819760.0, 826024.0, 832288.0, 838552.0, 844816.0, 851080.0, - 787840.0, 794360.0, 800880.0, 807400.0, 813920.0, 820440.0, 826960.0, 833480.0, - 840000.0, 846520.0, 853040.0, 859560.0, 866080.0, 872600.0, 879120.0, 885640.0, - 818560.0, 825336.0, 832112.0, 838888.0, 845664.0, 852440.0, 859216.0, 865992.0, - 872768.0, 879544.0, 886320.0, 893096.0, 899872.0, 906648.0, 913424.0, 920200.0, - 849280.0, 856312.0, 863344.0, 870376.0, 877408.0, 884440.0, 891472.0, 898504.0, - 905536.0, 912568.0, 919600.0, 926632.0, 933664.0, 940696.0, 947728.0, 954760.0, - 880000.0, 887288.0, 894576.0, 901864.0, 909152.0, 916440.0, 923728.0, 931016.0, - 938304.0, 945592.0, 952880.0, 960168.0, 967456.0, 974744.0, 982032.0, 989320.0, - 910720.0, 918264.0, 925808.0, 933352.0, 940896.0, 948440.0, 955984.0, 963528.0, - 971072.0, 978616.0, 986160.0, 993704.0, 1001248.0, 1008792.0, 1016336.0, 1023880.0, - 941440.0, 949240.0, 957040.0, 964840.0, 972640.0, 980440.0, 988240.0, 996040.0, - 1003840.0, 1011640.0, 1019440.0, 1027240.0, 1035040.0, 1042840.0, 1050640.0, 1058440.0, - 972160.0, 980216.0, 988272.0, 996328.0, 1004384.0, 1012440.0, 1020496.0, 1028552.0, - 1036608.0, 1044664.0, 1052720.0, 1060776.0, 1068832.0, 1076888.0, 1084944.0, 1093000.0, - 1590656.0, 1594872.0, 1599088.0, 1603304.0, 1607520.0, 1611736.0, 1615952.0, 1620168.0, - 1624384.0, 1628600.0, 1632816.0, 1637032.0, 1641248.0, 1645464.0, 1649680.0, 1653896.0, - 1686912.0, 1691384.0, 1695856.0, 1700328.0, 1704800.0, 1709272.0, 1713744.0, 1718216.0, - 1722688.0, 1727160.0, 1731632.0, 1736104.0, 1740576.0, 1745048.0, 1749520.0, 1753992.0, - 1783168.0, 1787896.0, 1792624.0, 1797352.0, 1802080.0, 1806808.0, 1811536.0, 1816264.0, - 1820992.0, 1825720.0, 1830448.0, 1835176.0, 1839904.0, 1844632.0, 1849360.0, 1854088.0, - 1879424.0, 1884408.0, 1889392.0, 1894376.0, 1899360.0, 1904344.0, 1909328.0, 1914312.0, - 1919296.0, 1924280.0, 1929264.0, 1934248.0, 1939232.0, 1944216.0, 1949200.0, 1954184.0, - 1975680.0, 1980920.0, 1986160.0, 1991400.0, 1996640.0, 2001880.0, 2007120.0, 2012360.0, - 2017600.0, 2022840.0, 2028080.0, 2033320.0, 2038560.0, 2043800.0, 2049040.0, 2054280.0, - 2071936.0, 2077432.0, 2082928.0, 2088424.0, 2093920.0, 2099416.0, 2104912.0, 2110408.0, - 2115904.0, 2121400.0, 2126896.0, 2132392.0, 2137888.0, 2143384.0, 2148880.0, 2154376.0, - 2168192.0, 2173944.0, 2179696.0, 2185448.0, 2191200.0, 2196952.0, 2202704.0, 2208456.0, - 2214208.0, 2219960.0, 2225712.0, 2231464.0, 2237216.0, 2242968.0, 2248720.0, 2254472.0, - 2264448.0, 2270456.0, 2276464.0, 2282472.0, 2288480.0, 2294488.0, 2300496.0, 2306504.0, - 2312512.0, 2318520.0, 2324528.0, 2330536.0, 2336544.0, 2342552.0, 2348560.0, 2354568.0, - 2360704.0, 2366968.0, 2373232.0, 2379496.0, 2385760.0, 2392024.0, 2398288.0, 2404552.0, - 2410816.0, 2417080.0, 2423344.0, 2429608.0, 2435872.0, 2442136.0, 2448400.0, 2454664.0, - 2456960.0, 2463480.0, 2470000.0, 2476520.0, 2483040.0, 2489560.0, 2496080.0, 2502600.0, - 2509120.0, 2515640.0, 2522160.0, 2528680.0, 2535200.0, 2541720.0, 2548240.0, 2554760.0, - 2553216.0, 2559992.0, 2566768.0, 2573544.0, 2580320.0, 2587096.0, 2593872.0, 2600648.0, - 2607424.0, 2614200.0, 2620976.0, 2627752.0, 2634528.0, 2641304.0, 2648080.0, 2654856.0, - 2649472.0, 2656504.0, 2663536.0, 2670568.0, 2677600.0, 2684632.0, 2691664.0, 2698696.0, - 2705728.0, 2712760.0, 2719792.0, 2726824.0, 2733856.0, 2740888.0, 2747920.0, 2754952.0, - 2745728.0, 2753016.0, 2760304.0, 2767592.0, 2774880.0, 2782168.0, 2789456.0, 2796744.0, - 2804032.0, 2811320.0, 2818608.0, 2825896.0, 2833184.0, 2840472.0, 2847760.0, 2855048.0, - 2841984.0, 2849528.0, 2857072.0, 2864616.0, 2872160.0, 2879704.0, 2887248.0, 2894792.0, - 2902336.0, 2909880.0, 2917424.0, 2924968.0, 2932512.0, 2940056.0, 2947600.0, 2955144.0, - 2938240.0, 2946040.0, 2953840.0, 2961640.0, 2969440.0, 2977240.0, 2985040.0, 2992840.0, - 3000640.0, 3008440.0, 3016240.0, 3024040.0, 3031840.0, 3039640.0, 3047440.0, 3055240.0, - 3034496.0, 3042552.0, 3050608.0, 3058664.0, 3066720.0, 3074776.0, 3082832.0, 3090888.0, - 3098944.0, 3107000.0, 3115056.0, 3123112.0, 3131168.0, 3139224.0, 3147280.0, 3155336.0, - ], - device, - ); -} diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs deleted file mode 100644 index b00b319c..00000000 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/load_shared_memory.rs +++ /dev/null @@ -1,703 +0,0 @@ -use std::ops::Range; - -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::matmul::cmma::base::{Dimensions, Ids, Offsets, RuntimeCmmaInfo}; -use crate::matmul::cmma::config::CmmaConfig; -use crate::matmul::tests::test_utils::{assert_equals_range, create_empty}; -use crate::matmul::{ - cmma::{config::ComptimeCmmaInfo, load_shared_memory::*}, - tests::test_utils::range_tensor, -}; - -use super::base::{DimsTestCase, B_K, B_MN}; - -#[cube(launch_unchecked)] -fn load_lhs_test( - lhs_tensor: &Tensor, - lhs_sm_arr: &mut Array, - k_offset: u32, - m: u32, - k: u32, - n: u32, - #[comptime] config: ComptimeCmmaInfo, -) { - let block_size_m = config.block_size_m; - let block_size_k = config.block_size_k; - let sm_size = block_size_k * block_size_m; - - let mut lhs_sm = SharedMemory::::new(sm_size); - for i in 0..sm_size { - lhs_sm[i] = F::new(0.); - } - - sync_units(); - - let offsets = Offsets { - batch_lhs: 0, - batch_rhs: 0, - batch_out: 0, - cube_row: 0, - cube_col: 0, - }; - let dims = Dimensions { m, k, n }; - let ids = Ids { - coop: UNIT_POS_Y, - lane: UNIT_POS_X, - }; - let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - - load_lhs(lhs_tensor, &mut lhs_sm, 2, k_offset, runtime_info, config); - - sync_units(); - - for i in 0..sm_size { - lhs_sm_arr[i] = lhs_sm[i]; - } -} - -#[cube(launch_unchecked)] -fn load_rhs_test( - rhs_tensor: &Tensor, - rhs_sm_arr: &mut Array, - k_offset: u32, - m: u32, - k: u32, - n: u32, - #[comptime] config: ComptimeCmmaInfo, -) { - let block_size_k = config.block_size_k; - let block_size_n = config.block_size_n; - let sm_size = block_size_k * block_size_n; - let mut rhs_sm = SharedMemory::::new(sm_size); - - for i in 0..sm_size { - rhs_sm[i] = F::new(0.); - } - - sync_units(); - - let offsets = Offsets { - batch_lhs: 0, - batch_rhs: 0, - batch_out: 0, - cube_row: 0, - cube_col: 0, - }; - let dims = Dimensions { m, k, n }; - let ids = Ids { - coop: UNIT_POS_Y, - lane: UNIT_POS_X, - }; - let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - - load_rhs(rhs_tensor, &mut rhs_sm, 2, k_offset, runtime_info, config); - - sync_units(); - - for i in 0..sm_size { - rhs_sm_arr[i] = rhs_sm[i]; - } -} - -enum InputTensor { - Lhs, - Rhs, -} - -fn load_shared_memory_test_case( - input: InputTensor, - dims: DimsTestCase, - k_offset: usize, - config: CmmaConfig, - expected: &[f32], - device: &R::Device, - range: Range, -) { - let client = R::client(device); - - for vectorization in [1, 2, 4] { - let smem = create_empty::(&client, config.b_k, config.b_mn); - let smem_size = config.b_k * config.b_mn; - - match input { - InputTensor::Lhs => { - let tensor = range_tensor::(&client, dims.m, dims.k); - - unsafe { - load_lhs_test::launch_unchecked::( - &R::client(device), - config.cube_count::(&[dims.m, dims.n]), - config.cube_dim(), - TensorArg::from_raw_parts( - &tensor.handle, - &tensor.strides, - &tensor.shape, - vectorization, - ), - ArrayArg::from_raw_parts(&smem, smem_size, 1), - ScalarArg::new(k_offset as u32), - ScalarArg::new(dims.m as u32), - ScalarArg::new(dims.k as u32), - ScalarArg::new(dims.n as u32), - config.comptime_info(dims.m, dims.k, dims.n), - ); - }; - } - InputTensor::Rhs => { - let tensor = range_tensor::(&client, dims.k, dims.n); - - unsafe { - load_rhs_test::launch_unchecked::( - &R::client(device), - config.cube_count::(&[dims.m, dims.n]), - config.cube_dim(), - TensorArg::from_raw_parts( - &tensor.handle, - &tensor.strides, - &tensor.shape, - vectorization, - ), - ArrayArg::from_raw_parts(&smem, smem_size, 1), - ScalarArg::new(k_offset as u32), - ScalarArg::new(dims.m as u32), - ScalarArg::new(dims.k as u32), - ScalarArg::new(dims.n as u32), - config.comptime_info(dims.m, dims.k, dims.n), - ); - }; - } - } - - assert_equals_range::(&client, smem, expected, range.clone()); - } -} - -/// Exported test -pub fn load_shared_memory_lhs_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 64, - k: 32, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, - 46.0, 47.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, - 76.0, 77.0, 78.0, 79.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, - 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 128.0, 129.0, 130.0, 131.0, 132.0, - 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 160.0, - 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, - 173.0, 174.0, 175.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, - 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 224.0, 225.0, 226.0, 227.0, 228.0, - 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0, 256.0, - 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, 268.0, - 269.0, 270.0, 271.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, - 297.0, 298.0, 299.0, 300.0, 301.0, 302.0, 303.0, 320.0, 321.0, 322.0, 323.0, 324.0, - 325.0, 326.0, 327.0, 328.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 352.0, - 353.0, 354.0, 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, - 365.0, 366.0, 367.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, 392.0, - 393.0, 394.0, 395.0, 396.0, 397.0, 398.0, 399.0, 416.0, 417.0, 418.0, 419.0, 420.0, - 421.0, 422.0, 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, 448.0, - 449.0, 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, - 461.0, 462.0, 463.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0, - 489.0, 490.0, 491.0, 492.0, 493.0, 494.0, 495.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_rhs_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Rhs, - DimsTestCase { - m: 64, - k: 32, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, - 78.0, 79.0, 128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, - 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, - 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 256.0, 257.0, - 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, 268.0, 269.0, - 270.0, 271.0, 320.0, 321.0, 322.0, 323.0, 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, - 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, - 390.0, 391.0, 392.0, 393.0, 394.0, 395.0, 396.0, 397.0, 398.0, 399.0, 448.0, 449.0, - 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, - 462.0, 463.0, 512.0, 513.0, 514.0, 515.0, 516.0, 517.0, 518.0, 519.0, 520.0, 521.0, - 522.0, 523.0, 524.0, 525.0, 526.0, 527.0, 576.0, 577.0, 578.0, 579.0, 580.0, 581.0, - 582.0, 583.0, 584.0, 585.0, 586.0, 587.0, 588.0, 589.0, 590.0, 591.0, 640.0, 641.0, - 642.0, 643.0, 644.0, 645.0, 646.0, 647.0, 648.0, 649.0, 650.0, 651.0, 652.0, 653.0, - 654.0, 655.0, 704.0, 705.0, 706.0, 707.0, 708.0, 709.0, 710.0, 711.0, 712.0, 713.0, - 714.0, 715.0, 716.0, 717.0, 718.0, 719.0, 768.0, 769.0, 770.0, 771.0, 772.0, 773.0, - 774.0, 775.0, 776.0, 777.0, 778.0, 779.0, 780.0, 781.0, 782.0, 783.0, 832.0, 833.0, - 834.0, 835.0, 836.0, 837.0, 838.0, 839.0, 840.0, 841.0, 842.0, 843.0, 844.0, 845.0, - 846.0, 847.0, 896.0, 897.0, 898.0, 899.0, 900.0, 901.0, 902.0, 903.0, 904.0, 905.0, - 906.0, 907.0, 908.0, 909.0, 910.0, 911.0, 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, - 966.0, 967.0, 968.0, 969.0, 970.0, 971.0, 972.0, 973.0, 974.0, 975.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_vertical_out_of_bound_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 12, - k: 64, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, - 78.0, 79.0, 128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, - 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, - 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 256.0, 257.0, - 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, 268.0, 269.0, - 270.0, 271.0, 320.0, 321.0, 322.0, 323.0, 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, - 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, - 390.0, 391.0, 392.0, 393.0, 394.0, 395.0, 396.0, 397.0, 398.0, 399.0, 448.0, 449.0, - 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, - 462.0, 463.0, 512.0, 513.0, 514.0, 515.0, 516.0, 517.0, 518.0, 519.0, 520.0, 521.0, - 522.0, 523.0, 524.0, 525.0, 526.0, 527.0, 576.0, 577.0, 578.0, 579.0, 580.0, 581.0, - 582.0, 583.0, 584.0, 585.0, 586.0, 587.0, 588.0, 589.0, 590.0, 591.0, 640.0, 641.0, - 642.0, 643.0, 644.0, 645.0, 646.0, 647.0, 648.0, 649.0, 650.0, 651.0, 652.0, 653.0, - 654.0, 655.0, 704.0, 705.0, 706.0, 707.0, 708.0, 709.0, 710.0, 711.0, 712.0, 713.0, - 714.0, 715.0, 716.0, 717.0, 718.0, 719.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_horizontal_out_of_bound_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 64, - k: 12, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, - 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 0.0, 0.0, 0.0, - 0.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 0.0, 0.0, - 0.0, 0.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 0.0, - 0.0, 0.0, 0.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, - 0.0, 0.0, 0.0, 0.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, - 83.0, 0.0, 0.0, 0.0, 0.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, - 94.0, 95.0, 0.0, 0.0, 0.0, 0.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 104.0, 105.0, 106.0, 107.0, 0.0, 0.0, 0.0, 0.0, 108.0, 109.0, 110.0, 111.0, 112.0, - 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 0.0, 0.0, 0.0, 0.0, 120.0, 121.0, - 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 128.0, 129.0, 130.0, 131.0, 0.0, 0.0, 0.0, - 0.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, - 143.0, 0.0, 0.0, 0.0, 0.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, - 152.0, 153.0, 154.0, 155.0, 0.0, 0.0, 0.0, 0.0, 156.0, 157.0, 158.0, 159.0, 160.0, - 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 0.0, 0.0, 0.0, 0.0, 168.0, 169.0, - 170.0, 171.0, 172.0, 173.0, 174.0, 175.0, 176.0, 177.0, 178.0, 179.0, 0.0, 0.0, 0.0, - 0.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, - 191.0, 0.0, 0.0, 0.0, 0.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_whole_out_of_bound_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 12, - k: 12, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 0.0, 0.0, 0.0, 0.0, 12.0, - 13.0, 14.0, 15.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 0.0, 0.0, 0.0, 0.0, - 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 32.0, 33.0, 34.0, 35.0, 0.0, 0.0, 0.0, - 0.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 0.0, 0.0, - 0.0, 0.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 0.0, - 0.0, 0.0, 0.0, 60.0, 61.0, 62.0, 63.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, - 0.0, 0.0, 0.0, 0.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 80.0, 81.0, 82.0, - 83.0, 0.0, 0.0, 0.0, 0.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, - 94.0, 95.0, 0.0, 0.0, 0.0, 0.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, - 104.0, 105.0, 106.0, 107.0, 0.0, 0.0, 0.0, 0.0, 108.0, 109.0, 110.0, 111.0, 112.0, - 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 0.0, 0.0, 0.0, 0.0, 120.0, 121.0, - 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 128.0, 129.0, 130.0, 131.0, 0.0, 0.0, 0.0, - 0.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, - 143.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, - 0.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_second_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., - 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 144., 145., - 146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158., 159., - 208., 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221., - 222., 223., 272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283., - 284., 285., 286., 287., 336., 337., 338., 339., 340., 341., 342., 343., 344., 345., - 346., 347., 348., 349., 350., 351., 400., 401., 402., 403., 404., 405., 406., 407., - 408., 409., 410., 411., 412., 413., 414., 415., 464., 465., 466., 467., 468., 469., - 470., 471., 472., 473., 474., 475., 476., 477., 478., 479., 528., 529., 530., 531., - 532., 533., 534., 535., 536., 537., 538., 539., 540., 541., 542., 543., 592., 593., - 594., 595., 596., 597., 598., 599., 600., 601., 602., 603., 604., 605., 606., 607., - 656., 657., 658., 659., 660., 661., 662., 663., 664., 665., 666., 667., 668., 669., - 670., 671., 720., 721., 722., 723., 724., 725., 726., 727., 728., 729., 730., 731., - 732., 733., 734., 735., 784., 785., 786., 787., 788., 789., 790., 791., 792., 793., - 794., 795., 796., 797., 798., 799., 848., 849., 850., 851., 852., 853., 854., 855., - 856., 857., 858., 859., 860., 861., 862., 863., 912., 913., 914., 915., 916., 917., - 918., 919., 920., 921., 922., 923., 924., 925., 926., 927., 976., 977., 978., 979., - 980., 981., 982., 983., 984., 985., 986., 987., 988., 989., 990., 991., - ], - device, - 256..512, - ); -} - -/// Exported test -pub fn load_shared_memory_rhs_second_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Rhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 1024., 1025., 1026., 1027., 1028., 1029., 1030., 1031., 1032., 1033., 1034., 1035., - 1036., 1037., 1038., 1039., 1088., 1089., 1090., 1091., 1092., 1093., 1094., 1095., - 1096., 1097., 1098., 1099., 1100., 1101., 1102., 1103., 1152., 1153., 1154., 1155., - 1156., 1157., 1158., 1159., 1160., 1161., 1162., 1163., 1164., 1165., 1166., 1167., - 1216., 1217., 1218., 1219., 1220., 1221., 1222., 1223., 1224., 1225., 1226., 1227., - 1228., 1229., 1230., 1231., 1280., 1281., 1282., 1283., 1284., 1285., 1286., 1287., - 1288., 1289., 1290., 1291., 1292., 1293., 1294., 1295., 1344., 1345., 1346., 1347., - 1348., 1349., 1350., 1351., 1352., 1353., 1354., 1355., 1356., 1357., 1358., 1359., - 1408., 1409., 1410., 1411., 1412., 1413., 1414., 1415., 1416., 1417., 1418., 1419., - 1420., 1421., 1422., 1423., 1472., 1473., 1474., 1475., 1476., 1477., 1478., 1479., - 1480., 1481., 1482., 1483., 1484., 1485., 1486., 1487., 1536., 1537., 1538., 1539., - 1540., 1541., 1542., 1543., 1544., 1545., 1546., 1547., 1548., 1549., 1550., 1551., - 1600., 1601., 1602., 1603., 1604., 1605., 1606., 1607., 1608., 1609., 1610., 1611., - 1612., 1613., 1614., 1615., 1664., 1665., 1666., 1667., 1668., 1669., 1670., 1671., - 1672., 1673., 1674., 1675., 1676., 1677., 1678., 1679., 1728., 1729., 1730., 1731., - 1732., 1733., 1734., 1735., 1736., 1737., 1738., 1739., 1740., 1741., 1742., 1743., - 1792., 1793., 1794., 1795., 1796., 1797., 1798., 1799., 1800., 1801., 1802., 1803., - 1804., 1805., 1806., 1807., 1856., 1857., 1858., 1859., 1860., 1861., 1862., 1863., - 1864., 1865., 1866., 1867., 1868., 1869., 1870., 1871., 1920., 1921., 1922., 1923., - 1924., 1925., 1926., 1927., 1928., 1929., 1930., 1931., 1932., 1933., 1934., 1935., - 1984., 1985., 1986., 1987., 1988., 1989., 1990., 1991., 1992., 1993., 1994., 1995., - 1996., 1997., 1998., 1999., - ], - device, - 256..512, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_third_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 0, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, - 92.0, 93.0, 94.0, 95.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, - 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 208.0, 209.0, 210.0, 211.0, 212.0, - 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 272.0, - 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 281.0, 282.0, 283.0, 284.0, - 285.0, 286.0, 287.0, 336.0, 337.0, 338.0, 339.0, 340.0, 341.0, 342.0, 343.0, 344.0, - 345.0, 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 400.0, 401.0, 402.0, 403.0, 404.0, - 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 412.0, 413.0, 414.0, 415.0, 464.0, - 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0, 475.0, 476.0, - 477.0, 478.0, 479.0, 528.0, 529.0, 530.0, 531.0, 532.0, 533.0, 534.0, 535.0, 536.0, - 537.0, 538.0, 539.0, 540.0, 541.0, 542.0, 543.0, 592.0, 593.0, 594.0, 595.0, 596.0, - 597.0, 598.0, 599.0, 600.0, 601.0, 602.0, 603.0, 604.0, 605.0, 606.0, 607.0, 656.0, - 657.0, 658.0, 659.0, 660.0, 661.0, 662.0, 663.0, 664.0, 665.0, 666.0, 667.0, 668.0, - 669.0, 670.0, 671.0, 720.0, 721.0, 722.0, 723.0, 724.0, 725.0, 726.0, 727.0, 728.0, - 729.0, 730.0, 731.0, 732.0, 733.0, 734.0, 735.0, 784.0, 785.0, 786.0, 787.0, 788.0, - 789.0, 790.0, 791.0, 792.0, 793.0, 794.0, 795.0, 796.0, 797.0, 798.0, 799.0, 848.0, - 849.0, 850.0, 851.0, 852.0, 853.0, 854.0, 855.0, 856.0, 857.0, 858.0, 859.0, 860.0, - 861.0, 862.0, 863.0, 912.0, 913.0, 914.0, 915.0, 916.0, 917.0, 918.0, 919.0, 920.0, - 921.0, 922.0, 923.0, 924.0, 925.0, 926.0, 927.0, 976.0, 977.0, 978.0, 979.0, 980.0, - 981.0, 982.0, 983.0, 984.0, 985.0, 986.0, 987.0, 988.0, 989.0, 990.0, 991.0, - ], - device, - 256..512, - ); -} - -/// Exported test -pub fn load_shared_memory_rhs_third_warp_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Rhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 0, - CmmaConfig { - b_mn: 64, - b_k: 32, - ..Default::default() - }, - &[ - 16., 17., 18., 19., 20., 21., 22., 23., 24., 25., 26., 27., 28., 29., 30., 31., 80., - 81., 82., 83., 84., 85., 86., 87., 88., 89., 90., 91., 92., 93., 94., 95., 144., 145., - 146., 147., 148., 149., 150., 151., 152., 153., 154., 155., 156., 157., 158., 159., - 208., 209., 210., 211., 212., 213., 214., 215., 216., 217., 218., 219., 220., 221., - 222., 223., 272., 273., 274., 275., 276., 277., 278., 279., 280., 281., 282., 283., - 284., 285., 286., 287., 336., 337., 338., 339., 340., 341., 342., 343., 344., 345., - 346., 347., 348., 349., 350., 351., 400., 401., 402., 403., 404., 405., 406., 407., - 408., 409., 410., 411., 412., 413., 414., 415., 464., 465., 466., 467., 468., 469., - 470., 471., 472., 473., 474., 475., 476., 477., 478., 479., 528., 529., 530., 531., - 532., 533., 534., 535., 536., 537., 538., 539., 540., 541., 542., 543., 592., 593., - 594., 595., 596., 597., 598., 599., 600., 601., 602., 603., 604., 605., 606., 607., - 656., 657., 658., 659., 660., 661., 662., 663., 664., 665., 666., 667., 668., 669., - 670., 671., 720., 721., 722., 723., 724., 725., 726., 727., 728., 729., 730., 731., - 732., 733., 734., 735., 784., 785., 786., 787., 788., 789., 790., 791., 792., 793., - 794., 795., 796., 797., 798., 799., 848., 849., 850., 851., 852., 853., 854., 855., - 856., 857., 858., 859., 860., 861., 862., 863., 912., 913., 914., 915., 916., 917., - 918., 919., 920., 921., 922., 923., 924., 925., 926., 927., 976., 977., 978., 979., - 980., 981., 982., 983., 984., 985., 986., 987., 988., 989., 990., 991., - ], - device, - 512..768, - ); -} - -/// Exported test -pub fn load_shared_memory_lhs_k_offset_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Lhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 32, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, - 46.0, 47.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, - 107.0, 108.0, 109.0, 110.0, 111.0, 160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, - 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0, 224.0, 225.0, 226.0, - 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, - 239.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 297.0, 298.0, - 299.0, 300.0, 301.0, 302.0, 303.0, 352.0, 353.0, 354.0, 355.0, 356.0, 357.0, 358.0, - 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, 365.0, 366.0, 367.0, 416.0, 417.0, 418.0, - 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, - 431.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0, 489.0, 490.0, - 491.0, 492.0, 493.0, 494.0, 495.0, 544.0, 545.0, 546.0, 547.0, 548.0, 549.0, 550.0, - 551.0, 552.0, 553.0, 554.0, 555.0, 556.0, 557.0, 558.0, 559.0, 608.0, 609.0, 610.0, - 611.0, 612.0, 613.0, 614.0, 615.0, 616.0, 617.0, 618.0, 619.0, 620.0, 621.0, 622.0, - 623.0, 672.0, 673.0, 674.0, 675.0, 676.0, 677.0, 678.0, 679.0, 680.0, 681.0, 682.0, - 683.0, 684.0, 685.0, 686.0, 687.0, 736.0, 737.0, 738.0, 739.0, 740.0, 741.0, 742.0, - 743.0, 744.0, 745.0, 746.0, 747.0, 748.0, 749.0, 750.0, 751.0, 800.0, 801.0, 802.0, - 803.0, 804.0, 805.0, 806.0, 807.0, 808.0, 809.0, 810.0, 811.0, 812.0, 813.0, 814.0, - 815.0, 864.0, 865.0, 866.0, 867.0, 868.0, 869.0, 870.0, 871.0, 872.0, 873.0, 874.0, - 875.0, 876.0, 877.0, 878.0, 879.0, 928.0, 929.0, 930.0, 931.0, 932.0, 933.0, 934.0, - 935.0, 936.0, 937.0, 938.0, 939.0, 940.0, 941.0, 942.0, 943.0, 992.0, 993.0, 994.0, - 995.0, 996.0, 997.0, 998.0, 999.0, 1000.0, 1001.0, 1002.0, 1003.0, 1004.0, 1005.0, - 1006.0, 1007.0, - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_rhs_k_offset_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Rhs, - DimsTestCase { - m: 64, - k: 64, - n: 64, - }, - 32, - CmmaConfig { - b_mn: B_MN, - b_k: B_K, - ..Default::default() - }, - &[ - 2048., 2049., 2050., 2051., 2052., 2053., 2054., 2055., 2056., 2057., 2058., 2059., - 2060., 2061., 2062., 2063., 2112., 2113., 2114., 2115., 2116., 2117., 2118., 2119., - 2120., 2121., 2122., 2123., 2124., 2125., 2126., 2127., 2176., 2177., 2178., 2179., - 2180., 2181., 2182., 2183., 2184., 2185., 2186., 2187., 2188., 2189., 2190., 2191., - 2240., 2241., 2242., 2243., 2244., 2245., 2246., 2247., 2248., 2249., 2250., 2251., - 2252., 2253., 2254., 2255., 2304., 2305., 2306., 2307., 2308., 2309., 2310., 2311., - 2312., 2313., 2314., 2315., 2316., 2317., 2318., 2319., 2368., 2369., 2370., 2371., - 2372., 2373., 2374., 2375., 2376., 2377., 2378., 2379., 2380., 2381., 2382., 2383., - 2432., 2433., 2434., 2435., 2436., 2437., 2438., 2439., 2440., 2441., 2442., 2443., - 2444., 2445., 2446., 2447., 2496., 2497., 2498., 2499., 2500., 2501., 2502., 2503., - 2504., 2505., 2506., 2507., 2508., 2509., 2510., 2511., 2560., 2561., 2562., 2563., - 2564., 2565., 2566., 2567., 2568., 2569., 2570., 2571., 2572., 2573., 2574., 2575., - 2624., 2625., 2626., 2627., 2628., 2629., 2630., 2631., 2632., 2633., 2634., 2635., - 2636., 2637., 2638., 2639., 2688., 2689., 2690., 2691., 2692., 2693., 2694., 2695., - 2696., 2697., 2698., 2699., 2700., 2701., 2702., 2703., 2752., 2753., 2754., 2755., - 2756., 2757., 2758., 2759., 2760., 2761., 2762., 2763., 2764., 2765., 2766., 2767., - 2816., 2817., 2818., 2819., 2820., 2821., 2822., 2823., 2824., 2825., 2826., 2827., - 2828., 2829., 2830., 2831., 2880., 2881., 2882., 2883., 2884., 2885., 2886., 2887., - 2888., 2889., 2890., 2891., 2892., 2893., 2894., 2895., 2944., 2945., 2946., 2947., - 2948., 2949., 2950., 2951., 2952., 2953., 2954., 2955., 2956., 2957., 2958., 2959., - 3008., 3009., 3010., 3011., 3012., 3013., 3014., 3015., 3016., 3017., 3018., 3019., - 3020., 3021., 3022., 3023., - ], - device, - 0..256, - ); -} - -/// Exported test -pub fn load_shared_memory_rhs_larger_block_test(device: &R::Device) { - load_shared_memory_test_case::( - InputTensor::Rhs, - DimsTestCase { - m: 16, - k: 32, - n: 32, - }, - 0, - CmmaConfig { - b_mn: 32, - b_k: 32, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, - 46.0, 47.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, - 76.0, 77.0, 78.0, 79.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, - 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 128.0, 129.0, 130.0, 131.0, 132.0, - 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 160.0, - 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, - 173.0, 174.0, 175.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, - 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 224.0, 225.0, 226.0, 227.0, 228.0, - 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0, 256.0, - 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, 268.0, - 269.0, 270.0, 271.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, - 297.0, 298.0, 299.0, 300.0, 301.0, 302.0, 303.0, 320.0, 321.0, 322.0, 323.0, 324.0, - 325.0, 326.0, 327.0, 328.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 352.0, - 353.0, 354.0, 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, - 365.0, 366.0, 367.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, 392.0, - 393.0, 394.0, 395.0, 396.0, 397.0, 398.0, 399.0, 416.0, 417.0, 418.0, 419.0, 420.0, - 421.0, 422.0, 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, 448.0, - 449.0, 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, - 461.0, 462.0, 463.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0, - 489.0, 490.0, 491.0, 492.0, 493.0, 494.0, 495.0, 512.0, 513.0, 514.0, 515.0, 516.0, - 517.0, 518.0, 519.0, 520.0, 521.0, 522.0, 523.0, 524.0, 525.0, 526.0, 527.0, 544.0, - 545.0, 546.0, 547.0, 548.0, 549.0, 550.0, 551.0, 552.0, 553.0, 554.0, 555.0, 556.0, - 557.0, 558.0, 559.0, 576.0, 577.0, 578.0, 579.0, 580.0, 581.0, 582.0, 583.0, 584.0, - 585.0, 586.0, 587.0, 588.0, 589.0, 590.0, 591.0, 608.0, 609.0, 610.0, 611.0, 612.0, - 613.0, 614.0, 615.0, 616.0, 617.0, 618.0, 619.0, 620.0, 621.0, 622.0, 623.0, 640.0, - 641.0, 642.0, 643.0, 644.0, 645.0, 646.0, 647.0, 648.0, 649.0, 650.0, 651.0, 652.0, - 653.0, 654.0, 655.0, 672.0, 673.0, 674.0, 675.0, 676.0, 677.0, 678.0, 679.0, 680.0, - 681.0, 682.0, 683.0, 684.0, 685.0, 686.0, 687.0, 704.0, 705.0, 706.0, 707.0, 708.0, - 709.0, 710.0, 711.0, 712.0, 713.0, 714.0, 715.0, 716.0, 717.0, 718.0, 719.0, 736.0, - 737.0, 738.0, 739.0, 740.0, 741.0, 742.0, 743.0, 744.0, 745.0, 746.0, 747.0, 748.0, - 749.0, 750.0, 751.0, 768.0, 769.0, 770.0, 771.0, 772.0, 773.0, 774.0, 775.0, 776.0, - 777.0, 778.0, 779.0, 780.0, 781.0, 782.0, 783.0, 800.0, 801.0, 802.0, 803.0, 804.0, - 805.0, 806.0, 807.0, 808.0, 809.0, 810.0, 811.0, 812.0, 813.0, 814.0, 815.0, 832.0, - 833.0, 834.0, 835.0, 836.0, 837.0, 838.0, 839.0, 840.0, 841.0, 842.0, 843.0, 844.0, - 845.0, 846.0, 847.0, 864.0, 865.0, 866.0, 867.0, 868.0, 869.0, 870.0, 871.0, 872.0, - 873.0, 874.0, 875.0, 876.0, 877.0, 878.0, 879.0, 896.0, 897.0, 898.0, 899.0, 900.0, - 901.0, 902.0, 903.0, 904.0, 905.0, 906.0, 907.0, 908.0, 909.0, 910.0, 911.0, 928.0, - 929.0, 930.0, 931.0, 932.0, 933.0, 934.0, 935.0, 936.0, 937.0, 938.0, 939.0, 940.0, - 941.0, 942.0, 943.0, 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, 966.0, 967.0, 968.0, - 969.0, 970.0, 971.0, 972.0, 973.0, 974.0, 975.0, 992.0, 993.0, 994.0, 995.0, 996.0, - 997.0, 998.0, 999.0, 1000.0, 1001.0, 1002.0, 1003.0, 1004.0, 1005.0, 1006.0, 1007.0, - ], - device, - 0..512, - ); -} diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/mod.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/mod.rs deleted file mode 100644 index afff3f95..00000000 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/mod.rs +++ /dev/null @@ -1,4 +0,0 @@ -mod base; -pub mod compute_loop; -pub mod load_shared_memory; -pub mod write_output; diff --git a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs b/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs deleted file mode 100644 index 3dccea93..00000000 --- a/crates/cubecl-linalg/src/matmul/tests/cmma/write_output.rs +++ /dev/null @@ -1,551 +0,0 @@ -use std::ops::Range; - -use cubecl_core as cubecl; -use cubecl_core::prelude::*; - -use crate::matmul::cmma::base::{get_row_col, Dimensions, Ids, Offsets, RuntimeCmmaInfo}; -use crate::matmul::cmma::config::{ - CmmaConfig, ComptimeCmmaInfo, CubeDispatchStrategy, WriteOutStrategy, -}; -use crate::matmul::cmma::write_output::base::shared_memory_to_output; -use crate::matmul::tests::test_utils::{ - assert_equals, assert_equals_range, range_tensor, zeros_tensor, -}; - -use super::base::DimsTestCase; - -#[cube(launch_unchecked)] -fn write_output_test( - out: &mut Tensor, - acc_sm_arr: &mut Array, - m: u32, - k: u32, - n: u32, - #[comptime] comptime_info: ComptimeCmmaInfo, -) { - let num_accumulators = comptime_info.num_accumulators; - let tile_size = comptime_info.tile_size; - let num_coops = comptime_info.num_coops; - - let sm_stride = tile_size * tile_size; - let sm_size = num_accumulators * num_coops * sm_stride; - - let mut accumulate = SharedMemory::::new(sm_size); - for i in 0..sm_size { - accumulate[i] = acc_sm_arr[i]; - } - - let (cube_row, cube_col) = get_row_col(comptime_info); - - let offsets = Offsets { - batch_lhs: 0, - batch_rhs: 0, - batch_out: 0, - cube_row, - cube_col, - }; - let dims = Dimensions { m, k, n }; - let ids = Ids { - coop: UNIT_POS_Y, - lane: UNIT_POS_X, - }; - let runtime_info = RuntimeCmmaInfo { offsets, dims, ids }; - - let smem_position_base = num_accumulators * ids.coop; - - #[unroll] - for n_iter in 0..num_accumulators { - shared_memory_to_output( - out, - smem_position_base + n_iter, - accumulate, - n_iter, - runtime_info, - comptime_info, - ); - } -} - -fn write_output_test_case( - dims: DimsTestCase, - config: CmmaConfig, - expected: &[f32], - device: &R::Device, - range: Option>, -) { - let client = R::client(device); - - for vectorization in [1, 2, 4] { - let out = zeros_tensor::(&client, dims.m, dims.n); - - let acc_sm = range_tensor::(&client, config.b_mn, config.b_mn); - - unsafe { - write_output_test::launch_unchecked::( - &client, - config.cube_count::(&[dims.m, dims.n]), - config.cube_dim(), - TensorArg::from_raw_parts(&out.handle, &out.strides, &out.shape, vectorization), - ArrayArg::from_raw_parts(&acc_sm.handle, acc_sm.shape.len(), 1), - ScalarArg::new(dims.m as u32), - ScalarArg::new(dims.k as u32), - ScalarArg::new(dims.n as u32), - config.comptime_info(dims.m, dims.k, dims.n), - ); - }; - - match range { - Some(ref range) => { - assert_equals_range::(&client, out.handle, expected, range.clone()) - } - None => assert_equals::(&client, out.handle, expected), - } - } -} - -/// Exported test -pub fn cmma_write_output_warp_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 16, - k: 16, - n: 32, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, - 268.0, 269.0, 270.0, 271.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, - 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 272.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, - 279.0, 280.0, 281.0, 282.0, 283.0, 284.0, 285.0, 286.0, 287.0, 32.0, 33.0, 34.0, 35.0, - 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 288.0, 289.0, - 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 297.0, 298.0, 299.0, 300.0, 301.0, - 302.0, 303.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, - 60.0, 61.0, 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, - 313.0, 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, - 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, - 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, - 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, - 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, 340.0, 341.0, 342.0, 343.0, 344.0, 345.0, - 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, - 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, 354.0, - 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, 365.0, 366.0, - 367.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, - 123.0, 124.0, 125.0, 126.0, 127.0, 368.0, 369.0, 370.0, 371.0, 372.0, 373.0, 374.0, - 375.0, 376.0, 377.0, 378.0, 379.0, 380.0, 381.0, 382.0, 383.0, 128.0, 129.0, 130.0, - 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, - 143.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, 392.0, 393.0, 394.0, - 395.0, 396.0, 397.0, 398.0, 399.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, - 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, 402.0, - 403.0, 404.0, 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 412.0, 413.0, 414.0, - 415.0, 160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, - 171.0, 172.0, 173.0, 174.0, 175.0, 416.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, - 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, 176.0, 177.0, 178.0, - 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, - 191.0, 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, - 443.0, 444.0, 445.0, 446.0, 447.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, - 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, 450.0, - 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, 462.0, - 463.0, 208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, - 219.0, 220.0, 221.0, 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, - 471.0, 472.0, 473.0, 474.0, 475.0, 476.0, 477.0, 478.0, 479.0, 224.0, 225.0, 226.0, - 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, - 239.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0, 489.0, 490.0, - 491.0, 492.0, 493.0, 494.0, 495.0, 240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, - 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, 498.0, - 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, - 511.0, - ], - device, - None, - ); -} - -/// Exported test -pub fn cmma_write_output_warp_horizontal_out_of_bounds_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 16, - k: 16, - n: 28, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, - 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 272.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 281.0, - 282.0, 283.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, - 44.0, 45.0, 46.0, 47.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, - 297.0, 298.0, 299.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, - 59.0, 60.0, 61.0, 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, - 312.0, 313.0, 314.0, 315.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, - 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, 324.0, 325.0, 326.0, - 327.0, 328.0, 329.0, 330.0, 331.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, - 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, 340.0, - 341.0, 342.0, 343.0, 344.0, 345.0, 346.0, 347.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, - 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, - 354.0, 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 112.0, 113.0, - 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, - 126.0, 127.0, 368.0, 369.0, 370.0, 371.0, 372.0, 373.0, 374.0, 375.0, 376.0, 377.0, - 378.0, 379.0, 128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, - 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, - 390.0, 391.0, 392.0, 393.0, 394.0, 395.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, - 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, - 402.0, 403.0, 404.0, 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 160.0, 161.0, - 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, - 174.0, 175.0, 416.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, 425.0, - 426.0, 427.0, 176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, - 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, - 438.0, 439.0, 440.0, 441.0, 442.0, 443.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, - 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, - 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 208.0, 209.0, - 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, - 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, - 474.0, 475.0, 224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, - 234.0, 235.0, 236.0, 237.0, 238.0, 239.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, - 486.0, 487.0, 488.0, 489.0, 490.0, 491.0, 240.0, 241.0, 242.0, 243.0, 244.0, 245.0, - 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, - 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, 504.0, 505.0, 506.0, 507.0, - ], - device, - None, - ); -} - -/// Exported test -pub fn cmma_write_output_warp_vertical_out_of_bounds_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 14, - k: 16, - n: 32, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, - 268.0, 269.0, 270.0, 271.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, - 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 272.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, - 279.0, 280.0, 281.0, 282.0, 283.0, 284.0, 285.0, 286.0, 287.0, 32.0, 33.0, 34.0, 35.0, - 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 288.0, 289.0, - 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 297.0, 298.0, 299.0, 300.0, 301.0, - 302.0, 303.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, - 60.0, 61.0, 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, - 313.0, 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, - 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, - 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, - 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, - 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, 340.0, 341.0, 342.0, 343.0, 344.0, 345.0, - 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, - 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, 354.0, - 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, 365.0, 366.0, - 367.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, - 123.0, 124.0, 125.0, 126.0, 127.0, 368.0, 369.0, 370.0, 371.0, 372.0, 373.0, 374.0, - 375.0, 376.0, 377.0, 378.0, 379.0, 380.0, 381.0, 382.0, 383.0, 128.0, 129.0, 130.0, - 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, - 143.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, 392.0, 393.0, 394.0, - 395.0, 396.0, 397.0, 398.0, 399.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, - 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, 402.0, - 403.0, 404.0, 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 412.0, 413.0, 414.0, - 415.0, 160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, - 171.0, 172.0, 173.0, 174.0, 175.0, 416.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, - 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, 176.0, 177.0, 178.0, - 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, - 191.0, 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, - 443.0, 444.0, 445.0, 446.0, 447.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, - 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, 450.0, - 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, 462.0, - 463.0, 208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, - 219.0, 220.0, 221.0, 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, - 471.0, 472.0, 473.0, 474.0, 475.0, 476.0, 477.0, 478.0, 479.0, - ], - device, - None, - ); -} - -/// Exported test -pub fn cmma_write_output_warp_whole_out_of_bounds_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 14, - k: 16, - n: 28, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, - 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, - 30.0, 31.0, 272.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 281.0, - 282.0, 283.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, - 44.0, 45.0, 46.0, 47.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, - 297.0, 298.0, 299.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, - 59.0, 60.0, 61.0, 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, - 312.0, 313.0, 314.0, 315.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, - 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, 324.0, 325.0, 326.0, - 327.0, 328.0, 329.0, 330.0, 331.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, - 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, 340.0, - 341.0, 342.0, 343.0, 344.0, 345.0, 346.0, 347.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, - 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, - 354.0, 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 112.0, 113.0, - 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, - 126.0, 127.0, 368.0, 369.0, 370.0, 371.0, 372.0, 373.0, 374.0, 375.0, 376.0, 377.0, - 378.0, 379.0, 128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, - 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, - 390.0, 391.0, 392.0, 393.0, 394.0, 395.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, - 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, - 402.0, 403.0, 404.0, 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 160.0, 161.0, - 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, - 174.0, 175.0, 416.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, 425.0, - 426.0, 427.0, 176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, - 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, - 438.0, 439.0, 440.0, 441.0, 442.0, 443.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, - 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, - 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 208.0, 209.0, - 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, - 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, - 474.0, 475.0, - ], - device, - None, - ); -} - -/// Exported test -pub fn cmma_write_output_second_warp_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 16, - k: 16, - n: 64, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0, 13.0, 14.0, 15.0, - 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, 264.0, 265.0, 266.0, 267.0, - 268.0, 269.0, 270.0, 271.0, 0.0, 1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, - 11.0, 12.0, 13.0, 14.0, 15.0, 256.0, 257.0, 258.0, 259.0, 260.0, 261.0, 262.0, 263.0, - 264.0, 265.0, 266.0, 267.0, 268.0, 269.0, 270.0, 271.0, 16.0, 17.0, 18.0, 19.0, 20.0, - 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, 29.0, 30.0, 31.0, 272.0, 273.0, 274.0, - 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 281.0, 282.0, 283.0, 284.0, 285.0, 286.0, - 287.0, 16.0, 17.0, 18.0, 19.0, 20.0, 21.0, 22.0, 23.0, 24.0, 25.0, 26.0, 27.0, 28.0, - 29.0, 30.0, 31.0, 272.0, 273.0, 274.0, 275.0, 276.0, 277.0, 278.0, 279.0, 280.0, 281.0, - 282.0, 283.0, 284.0, 285.0, 286.0, 287.0, 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, - 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, 46.0, 47.0, 288.0, 289.0, 290.0, 291.0, - 292.0, 293.0, 294.0, 295.0, 296.0, 297.0, 298.0, 299.0, 300.0, 301.0, 302.0, 303.0, - 32.0, 33.0, 34.0, 35.0, 36.0, 37.0, 38.0, 39.0, 40.0, 41.0, 42.0, 43.0, 44.0, 45.0, - 46.0, 47.0, 288.0, 289.0, 290.0, 291.0, 292.0, 293.0, 294.0, 295.0, 296.0, 297.0, - 298.0, 299.0, 300.0, 301.0, 302.0, 303.0, 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, - 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, - 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, - 48.0, 49.0, 50.0, 51.0, 52.0, 53.0, 54.0, 55.0, 56.0, 57.0, 58.0, 59.0, 60.0, 61.0, - 62.0, 63.0, 304.0, 305.0, 306.0, 307.0, 308.0, 309.0, 310.0, 311.0, 312.0, 313.0, - 314.0, 315.0, 316.0, 317.0, 318.0, 319.0, 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, - 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, - 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, - 64.0, 65.0, 66.0, 67.0, 68.0, 69.0, 70.0, 71.0, 72.0, 73.0, 74.0, 75.0, 76.0, 77.0, - 78.0, 79.0, 320.0, 321.0, 322.0, 323.0, 324.0, 325.0, 326.0, 327.0, 328.0, 329.0, - 330.0, 331.0, 332.0, 333.0, 334.0, 335.0, 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, - 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, - 340.0, 341.0, 342.0, 343.0, 344.0, 345.0, 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, - 80.0, 81.0, 82.0, 83.0, 84.0, 85.0, 86.0, 87.0, 88.0, 89.0, 90.0, 91.0, 92.0, 93.0, - 94.0, 95.0, 336.0, 337.0, 338.0, 339.0, 340.0, 341.0, 342.0, 343.0, 344.0, 345.0, - 346.0, 347.0, 348.0, 349.0, 350.0, 351.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, - 103.0, 104.0, 105.0, 106.0, 107.0, 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, 354.0, - 355.0, 356.0, 357.0, 358.0, 359.0, 360.0, 361.0, 362.0, 363.0, 364.0, 365.0, 366.0, - 367.0, 96.0, 97.0, 98.0, 99.0, 100.0, 101.0, 102.0, 103.0, 104.0, 105.0, 106.0, 107.0, - 108.0, 109.0, 110.0, 111.0, 352.0, 353.0, 354.0, 355.0, 356.0, 357.0, 358.0, 359.0, - 360.0, 361.0, 362.0, 363.0, 364.0, 365.0, 366.0, 367.0, 112.0, 113.0, 114.0, 115.0, - 116.0, 117.0, 118.0, 119.0, 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, - 368.0, 369.0, 370.0, 371.0, 372.0, 373.0, 374.0, 375.0, 376.0, 377.0, 378.0, 379.0, - 380.0, 381.0, 382.0, 383.0, 112.0, 113.0, 114.0, 115.0, 116.0, 117.0, 118.0, 119.0, - 120.0, 121.0, 122.0, 123.0, 124.0, 125.0, 126.0, 127.0, 368.0, 369.0, 370.0, 371.0, - 372.0, 373.0, 374.0, 375.0, 376.0, 377.0, 378.0, 379.0, 380.0, 381.0, 382.0, 383.0, - 128.0, 129.0, 130.0, 131.0, 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, - 140.0, 141.0, 142.0, 143.0, 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, - 392.0, 393.0, 394.0, 395.0, 396.0, 397.0, 398.0, 399.0, 128.0, 129.0, 130.0, 131.0, - 132.0, 133.0, 134.0, 135.0, 136.0, 137.0, 138.0, 139.0, 140.0, 141.0, 142.0, 143.0, - 384.0, 385.0, 386.0, 387.0, 388.0, 389.0, 390.0, 391.0, 392.0, 393.0, 394.0, 395.0, - 396.0, 397.0, 398.0, 399.0, 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, - 152.0, 153.0, 154.0, 155.0, 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, 402.0, 403.0, - 404.0, 405.0, 406.0, 407.0, 408.0, 409.0, 410.0, 411.0, 412.0, 413.0, 414.0, 415.0, - 144.0, 145.0, 146.0, 147.0, 148.0, 149.0, 150.0, 151.0, 152.0, 153.0, 154.0, 155.0, - 156.0, 157.0, 158.0, 159.0, 400.0, 401.0, 402.0, 403.0, 404.0, 405.0, 406.0, 407.0, - 408.0, 409.0, 410.0, 411.0, 412.0, 413.0, 414.0, 415.0, 160.0, 161.0, 162.0, 163.0, - 164.0, 165.0, 166.0, 167.0, 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0, - 416.0, 417.0, 418.0, 419.0, 420.0, 421.0, 422.0, 423.0, 424.0, 425.0, 426.0, 427.0, - 428.0, 429.0, 430.0, 431.0, 160.0, 161.0, 162.0, 163.0, 164.0, 165.0, 166.0, 167.0, - 168.0, 169.0, 170.0, 171.0, 172.0, 173.0, 174.0, 175.0, 416.0, 417.0, 418.0, 419.0, - 420.0, 421.0, 422.0, 423.0, 424.0, 425.0, 426.0, 427.0, 428.0, 429.0, 430.0, 431.0, - 176.0, 177.0, 178.0, 179.0, 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, - 188.0, 189.0, 190.0, 191.0, 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, - 440.0, 441.0, 442.0, 443.0, 444.0, 445.0, 446.0, 447.0, 176.0, 177.0, 178.0, 179.0, - 180.0, 181.0, 182.0, 183.0, 184.0, 185.0, 186.0, 187.0, 188.0, 189.0, 190.0, 191.0, - 432.0, 433.0, 434.0, 435.0, 436.0, 437.0, 438.0, 439.0, 440.0, 441.0, 442.0, 443.0, - 444.0, 445.0, 446.0, 447.0, 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, - 200.0, 201.0, 202.0, 203.0, 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, 450.0, 451.0, - 452.0, 453.0, 454.0, 455.0, 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, 462.0, 463.0, - 192.0, 193.0, 194.0, 195.0, 196.0, 197.0, 198.0, 199.0, 200.0, 201.0, 202.0, 203.0, - 204.0, 205.0, 206.0, 207.0, 448.0, 449.0, 450.0, 451.0, 452.0, 453.0, 454.0, 455.0, - 456.0, 457.0, 458.0, 459.0, 460.0, 461.0, 462.0, 463.0, 208.0, 209.0, 210.0, 211.0, - 212.0, 213.0, 214.0, 215.0, 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, - 464.0, 465.0, 466.0, 467.0, 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0, 475.0, - 476.0, 477.0, 478.0, 479.0, 208.0, 209.0, 210.0, 211.0, 212.0, 213.0, 214.0, 215.0, - 216.0, 217.0, 218.0, 219.0, 220.0, 221.0, 222.0, 223.0, 464.0, 465.0, 466.0, 467.0, - 468.0, 469.0, 470.0, 471.0, 472.0, 473.0, 474.0, 475.0, 476.0, 477.0, 478.0, 479.0, - 224.0, 225.0, 226.0, 227.0, 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, - 236.0, 237.0, 238.0, 239.0, 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, - 488.0, 489.0, 490.0, 491.0, 492.0, 493.0, 494.0, 495.0, 224.0, 225.0, 226.0, 227.0, - 228.0, 229.0, 230.0, 231.0, 232.0, 233.0, 234.0, 235.0, 236.0, 237.0, 238.0, 239.0, - 480.0, 481.0, 482.0, 483.0, 484.0, 485.0, 486.0, 487.0, 488.0, 489.0, 490.0, 491.0, - 492.0, 493.0, 494.0, 495.0, 240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, 247.0, - 248.0, 249.0, 250.0, 251.0, 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, 498.0, 499.0, - 500.0, 501.0, 502.0, 503.0, 504.0, 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, - 240.0, 241.0, 242.0, 243.0, 244.0, 245.0, 246.0, 247.0, 248.0, 249.0, 250.0, 251.0, - 252.0, 253.0, 254.0, 255.0, 496.0, 497.0, 498.0, 499.0, 500.0, 501.0, 502.0, 503.0, - 504.0, 505.0, 506.0, 507.0, 508.0, 509.0, 510.0, 511.0, - ], - device, - None, - ); -} - -/// Exported test -pub fn cmma_write_output_third_fourth_warps_test(device: &R::Device) { - write_output_test_case::( - DimsTestCase { - m: 32, - k: 16, - n: 64, - }, - CmmaConfig { - b_mn: 32, - b_k: 16, - write_out_strategy: WriteOutStrategy::LargeSmem, - cube_dispatch: CubeDispatchStrategy::ColMajor, - ..Default::default() - }, - &[ - 512.0, 513.0, 514.0, 515.0, 516.0, 517.0, 518.0, 519.0, 520.0, 521.0, 522.0, 523.0, - 524.0, 525.0, 526.0, 527.0, 768.0, 769.0, 770.0, 771.0, 772.0, 773.0, 774.0, 775.0, - 776.0, 777.0, 778.0, 779.0, 780.0, 781.0, 782.0, 783.0, 512.0, 513.0, 514.0, 515.0, - 516.0, 517.0, 518.0, 519.0, 520.0, 521.0, 522.0, 523.0, 524.0, 525.0, 526.0, 527.0, - 768.0, 769.0, 770.0, 771.0, 772.0, 773.0, 774.0, 775.0, 776.0, 777.0, 778.0, 779.0, - 780.0, 781.0, 782.0, 783.0, 528.0, 529.0, 530.0, 531.0, 532.0, 533.0, 534.0, 535.0, - 536.0, 537.0, 538.0, 539.0, 540.0, 541.0, 542.0, 543.0, 784.0, 785.0, 786.0, 787.0, - 788.0, 789.0, 790.0, 791.0, 792.0, 793.0, 794.0, 795.0, 796.0, 797.0, 798.0, 799.0, - 528.0, 529.0, 530.0, 531.0, 532.0, 533.0, 534.0, 535.0, 536.0, 537.0, 538.0, 539.0, - 540.0, 541.0, 542.0, 543.0, 784.0, 785.0, 786.0, 787.0, 788.0, 789.0, 790.0, 791.0, - 792.0, 793.0, 794.0, 795.0, 796.0, 797.0, 798.0, 799.0, 544.0, 545.0, 546.0, 547.0, - 548.0, 549.0, 550.0, 551.0, 552.0, 553.0, 554.0, 555.0, 556.0, 557.0, 558.0, 559.0, - 800.0, 801.0, 802.0, 803.0, 804.0, 805.0, 806.0, 807.0, 808.0, 809.0, 810.0, 811.0, - 812.0, 813.0, 814.0, 815.0, 544.0, 545.0, 546.0, 547.0, 548.0, 549.0, 550.0, 551.0, - 552.0, 553.0, 554.0, 555.0, 556.0, 557.0, 558.0, 559.0, 800.0, 801.0, 802.0, 803.0, - 804.0, 805.0, 806.0, 807.0, 808.0, 809.0, 810.0, 811.0, 812.0, 813.0, 814.0, 815.0, - 560.0, 561.0, 562.0, 563.0, 564.0, 565.0, 566.0, 567.0, 568.0, 569.0, 570.0, 571.0, - 572.0, 573.0, 574.0, 575.0, 816.0, 817.0, 818.0, 819.0, 820.0, 821.0, 822.0, 823.0, - 824.0, 825.0, 826.0, 827.0, 828.0, 829.0, 830.0, 831.0, 560.0, 561.0, 562.0, 563.0, - 564.0, 565.0, 566.0, 567.0, 568.0, 569.0, 570.0, 571.0, 572.0, 573.0, 574.0, 575.0, - 816.0, 817.0, 818.0, 819.0, 820.0, 821.0, 822.0, 823.0, 824.0, 825.0, 826.0, 827.0, - 828.0, 829.0, 830.0, 831.0, 576.0, 577.0, 578.0, 579.0, 580.0, 581.0, 582.0, 583.0, - 584.0, 585.0, 586.0, 587.0, 588.0, 589.0, 590.0, 591.0, 832.0, 833.0, 834.0, 835.0, - 836.0, 837.0, 838.0, 839.0, 840.0, 841.0, 842.0, 843.0, 844.0, 845.0, 846.0, 847.0, - 576.0, 577.0, 578.0, 579.0, 580.0, 581.0, 582.0, 583.0, 584.0, 585.0, 586.0, 587.0, - 588.0, 589.0, 590.0, 591.0, 832.0, 833.0, 834.0, 835.0, 836.0, 837.0, 838.0, 839.0, - 840.0, 841.0, 842.0, 843.0, 844.0, 845.0, 846.0, 847.0, 592.0, 593.0, 594.0, 595.0, - 596.0, 597.0, 598.0, 599.0, 600.0, 601.0, 602.0, 603.0, 604.0, 605.0, 606.0, 607.0, - 848.0, 849.0, 850.0, 851.0, 852.0, 853.0, 854.0, 855.0, 856.0, 857.0, 858.0, 859.0, - 860.0, 861.0, 862.0, 863.0, 592.0, 593.0, 594.0, 595.0, 596.0, 597.0, 598.0, 599.0, - 600.0, 601.0, 602.0, 603.0, 604.0, 605.0, 606.0, 607.0, 848.0, 849.0, 850.0, 851.0, - 852.0, 853.0, 854.0, 855.0, 856.0, 857.0, 858.0, 859.0, 860.0, 861.0, 862.0, 863.0, - 608.0, 609.0, 610.0, 611.0, 612.0, 613.0, 614.0, 615.0, 616.0, 617.0, 618.0, 619.0, - 620.0, 621.0, 622.0, 623.0, 864.0, 865.0, 866.0, 867.0, 868.0, 869.0, 870.0, 871.0, - 872.0, 873.0, 874.0, 875.0, 876.0, 877.0, 878.0, 879.0, 608.0, 609.0, 610.0, 611.0, - 612.0, 613.0, 614.0, 615.0, 616.0, 617.0, 618.0, 619.0, 620.0, 621.0, 622.0, 623.0, - 864.0, 865.0, 866.0, 867.0, 868.0, 869.0, 870.0, 871.0, 872.0, 873.0, 874.0, 875.0, - 876.0, 877.0, 878.0, 879.0, 624.0, 625.0, 626.0, 627.0, 628.0, 629.0, 630.0, 631.0, - 632.0, 633.0, 634.0, 635.0, 636.0, 637.0, 638.0, 639.0, 880.0, 881.0, 882.0, 883.0, - 884.0, 885.0, 886.0, 887.0, 888.0, 889.0, 890.0, 891.0, 892.0, 893.0, 894.0, 895.0, - 624.0, 625.0, 626.0, 627.0, 628.0, 629.0, 630.0, 631.0, 632.0, 633.0, 634.0, 635.0, - 636.0, 637.0, 638.0, 639.0, 880.0, 881.0, 882.0, 883.0, 884.0, 885.0, 886.0, 887.0, - 888.0, 889.0, 890.0, 891.0, 892.0, 893.0, 894.0, 895.0, 640.0, 641.0, 642.0, 643.0, - 644.0, 645.0, 646.0, 647.0, 648.0, 649.0, 650.0, 651.0, 652.0, 653.0, 654.0, 655.0, - 896.0, 897.0, 898.0, 899.0, 900.0, 901.0, 902.0, 903.0, 904.0, 905.0, 906.0, 907.0, - 908.0, 909.0, 910.0, 911.0, 640.0, 641.0, 642.0, 643.0, 644.0, 645.0, 646.0, 647.0, - 648.0, 649.0, 650.0, 651.0, 652.0, 653.0, 654.0, 655.0, 896.0, 897.0, 898.0, 899.0, - 900.0, 901.0, 902.0, 903.0, 904.0, 905.0, 906.0, 907.0, 908.0, 909.0, 910.0, 911.0, - 656.0, 657.0, 658.0, 659.0, 660.0, 661.0, 662.0, 663.0, 664.0, 665.0, 666.0, 667.0, - 668.0, 669.0, 670.0, 671.0, 912.0, 913.0, 914.0, 915.0, 916.0, 917.0, 918.0, 919.0, - 920.0, 921.0, 922.0, 923.0, 924.0, 925.0, 926.0, 927.0, 656.0, 657.0, 658.0, 659.0, - 660.0, 661.0, 662.0, 663.0, 664.0, 665.0, 666.0, 667.0, 668.0, 669.0, 670.0, 671.0, - 912.0, 913.0, 914.0, 915.0, 916.0, 917.0, 918.0, 919.0, 920.0, 921.0, 922.0, 923.0, - 924.0, 925.0, 926.0, 927.0, 672.0, 673.0, 674.0, 675.0, 676.0, 677.0, 678.0, 679.0, - 680.0, 681.0, 682.0, 683.0, 684.0, 685.0, 686.0, 687.0, 928.0, 929.0, 930.0, 931.0, - 932.0, 933.0, 934.0, 935.0, 936.0, 937.0, 938.0, 939.0, 940.0, 941.0, 942.0, 943.0, - 672.0, 673.0, 674.0, 675.0, 676.0, 677.0, 678.0, 679.0, 680.0, 681.0, 682.0, 683.0, - 684.0, 685.0, 686.0, 687.0, 928.0, 929.0, 930.0, 931.0, 932.0, 933.0, 934.0, 935.0, - 936.0, 937.0, 938.0, 939.0, 940.0, 941.0, 942.0, 943.0, 688.0, 689.0, 690.0, 691.0, - 692.0, 693.0, 694.0, 695.0, 696.0, 697.0, 698.0, 699.0, 700.0, 701.0, 702.0, 703.0, - 944.0, 945.0, 946.0, 947.0, 948.0, 949.0, 950.0, 951.0, 952.0, 953.0, 954.0, 955.0, - 956.0, 957.0, 958.0, 959.0, 688.0, 689.0, 690.0, 691.0, 692.0, 693.0, 694.0, 695.0, - 696.0, 697.0, 698.0, 699.0, 700.0, 701.0, 702.0, 703.0, 944.0, 945.0, 946.0, 947.0, - 948.0, 949.0, 950.0, 951.0, 952.0, 953.0, 954.0, 955.0, 956.0, 957.0, 958.0, 959.0, - 704.0, 705.0, 706.0, 707.0, 708.0, 709.0, 710.0, 711.0, 712.0, 713.0, 714.0, 715.0, - 716.0, 717.0, 718.0, 719.0, 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, 966.0, 967.0, - 968.0, 969.0, 970.0, 971.0, 972.0, 973.0, 974.0, 975.0, 704.0, 705.0, 706.0, 707.0, - 708.0, 709.0, 710.0, 711.0, 712.0, 713.0, 714.0, 715.0, 716.0, 717.0, 718.0, 719.0, - 960.0, 961.0, 962.0, 963.0, 964.0, 965.0, 966.0, 967.0, 968.0, 969.0, 970.0, 971.0, - 972.0, 973.0, 974.0, 975.0, 720.0, 721.0, 722.0, 723.0, 724.0, 725.0, 726.0, 727.0, - 728.0, 729.0, 730.0, 731.0, 732.0, 733.0, 734.0, 735.0, 976.0, 977.0, 978.0, 979.0, - 980.0, 981.0, 982.0, 983.0, 984.0, 985.0, 986.0, 987.0, 988.0, 989.0, 990.0, 991.0, - 720.0, 721.0, 722.0, 723.0, 724.0, 725.0, 726.0, 727.0, 728.0, 729.0, 730.0, 731.0, - 732.0, 733.0, 734.0, 735.0, 976.0, 977.0, 978.0, 979.0, 980.0, 981.0, 982.0, 983.0, - 984.0, 985.0, 986.0, 987.0, 988.0, 989.0, 990.0, 991.0, 736.0, 737.0, 738.0, 739.0, - 740.0, 741.0, 742.0, 743.0, 744.0, 745.0, 746.0, 747.0, 748.0, 749.0, 750.0, 751.0, - 992.0, 993.0, 994.0, 995.0, 996.0, 997.0, 998.0, 999.0, 1000.0, 1001.0, 1002.0, 1003.0, - 1004.0, 1005.0, 1006.0, 1007.0, 736.0, 737.0, 738.0, 739.0, 740.0, 741.0, 742.0, 743.0, - 744.0, 745.0, 746.0, 747.0, 748.0, 749.0, 750.0, 751.0, 992.0, 993.0, 994.0, 995.0, - 996.0, 997.0, 998.0, 999.0, 1000.0, 1001.0, 1002.0, 1003.0, 1004.0, 1005.0, 1006.0, - 1007.0, 752.0, 753.0, 754.0, 755.0, 756.0, 757.0, 758.0, 759.0, 760.0, 761.0, 762.0, - 763.0, 764.0, 765.0, 766.0, 767.0, 1008.0, 1009.0, 1010.0, 1011.0, 1012.0, 1013.0, - 1014.0, 1015.0, 1016.0, 1017.0, 1018.0, 1019.0, 1020.0, 1021.0, 1022.0, 1023.0, 752.0, - 753.0, 754.0, 755.0, 756.0, 757.0, 758.0, 759.0, 760.0, 761.0, 762.0, 763.0, 764.0, - 765.0, 766.0, 767.0, 1008.0, 1009.0, 1010.0, 1011.0, 1012.0, 1013.0, 1014.0, 1015.0, - 1016.0, 1017.0, 1018.0, 1019.0, 1020.0, 1021.0, 1022.0, 1023.0, - ], - device, - Some(1024..2048), - ); -} diff --git a/crates/cubecl-linalg/src/matmul/tests/mod.rs b/crates/cubecl-linalg/src/matmul/tests/mod.rs index 09889ecb..7616e7a4 100644 --- a/crates/cubecl-linalg/src/matmul/tests/mod.rs +++ b/crates/cubecl-linalg/src/matmul/tests/mod.rs @@ -1,4 +1,3 @@ -pub mod cmma; pub mod matmul_tests; mod test_utils; pub mod tiling2d; diff --git a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs index fd6824a0..eba3d81c 100644 --- a/crates/cubecl-linalg/src/matmul/tests/test_utils.rs +++ b/crates/cubecl-linalg/src/matmul/tests/test_utils.rs @@ -5,31 +5,12 @@ use cubecl_core::{ server::Handle, CubeElement, Feature, Runtime, }; -use half::f16; -use std::ops::Range; use crate::{ matmul::tiling2d::config::{CubeTiling2dConfig, Tiling2dConfig}, tensor::TensorHandle, }; -pub(crate) fn range_tensor_f16( - client: &ComputeClient, - x: usize, - y: usize, -) -> TensorHandle { - let n_elements = x * y; - - let mut data = Vec::with_capacity(n_elements); - for i in 0..n_elements { - data.push(half::f16::from_f32(i as f32)); - } - - let handle = client.create(cast_slice(&data)); - - TensorHandle::new_contiguous(vec![x, y], handle) -} - pub(crate) fn range_tensor( client: &ComputeClient, x: usize, @@ -140,18 +121,6 @@ pub(crate) fn assert_equals_approx( } } -pub(crate) fn assert_equals_range( - client: &ComputeClient, - output: Handle<::Server>, - expected: &[f32], - range: Range, -) { - let actual = client.read(output.binding()); - let actual = f32::from_bytes(&actual); - - assert_eq!(&actual[range], expected); -} - pub(crate) fn make_tiling2d_config(m: usize, k: usize, n: usize) -> CubeTiling2dConfig { let tiling2d_config = Tiling2dConfig { block_size_m: 8, diff --git a/crates/cubecl-linalg/src/tests/matmul/cmma/matmul_internal.rs b/crates/cubecl-linalg/src/tests/matmul/cmma/matmul_internal.rs deleted file mode 100644 index 0a6ade26..00000000 --- a/crates/cubecl-linalg/src/tests/matmul/cmma/matmul_internal.rs +++ /dev/null @@ -1,142 +0,0 @@ -#[allow(missing_docs)] -#[macro_export] -macro_rules! testgen_cmma_internal { - () => { - #[test] - pub fn cmma_compute_loop_block_equal_tile_test() { - tests::cmma::compute_loop::cmma_compute_loop_block_equal_tile_test::(&Default::default()) - } - - #[test] - pub fn cmma_compute_loop_block_larger_than_tile_test() { - tests::cmma::compute_loop::cmma_compute_loop_block_larger_than_tile_test::(&Default::default()) - } - - #[test] - pub fn cmma_compute_loop_b_mn_larger_than_b_k_test() { - tests::cmma::compute_loop::cmma_compute_loop_b_mn_larger_than_b_k_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_vertical_out_of_bound_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_vertical_out_of_bound_warp_test::< - TestRuntime, - >(&Default::default()) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_horizontal_out_of_bound_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_horizontal_out_of_bound_warp_test::< - TestRuntime, - >(&Default::default()) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_whole_out_of_bound_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_whole_out_of_bound_warp_test::< - TestRuntime, - >(&Default::default()) - } - - #[test] - pub fn cmma_load_shared_memory_rhs_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_rhs_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_second_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_second_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_rhs_second_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_rhs_second_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_third_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_third_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_rhs_third_warp_test() { - tests::cmma::load_shared_memory::load_shared_memory_rhs_third_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_lhs_k_offset_test() { - tests::cmma::load_shared_memory::load_shared_memory_lhs_k_offset_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_load_shared_memory_rhs_k_offset_test() { - tests::cmma::load_shared_memory::load_shared_memory_rhs_k_offset_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_write_output_warp_test() { - tests::cmma::write_output::cmma_write_output_warp_test::(&Default::default()) - } - - #[test] - pub fn cmma_write_output_warp_horizontal_out_of_bounds_test() { - tests::cmma::write_output::cmma_write_output_warp_horizontal_out_of_bounds_test::(&Default::default()) - } - - #[test] - pub fn cmma_write_output_warp_vertical_out_of_bounds_test() { - tests::cmma::write_output::cmma_write_output_warp_vertical_out_of_bounds_test::(&Default::default()) - } - - #[test] - pub fn cmma_write_output_warp_whole_out_of_bounds_test() { - tests::cmma::write_output::cmma_write_output_warp_whole_out_of_bounds_test::(&Default::default()) - } - - - #[test] - pub fn cmma_write_output_second_warp_test() { - tests::cmma::write_output::cmma_write_output_second_warp_test::( - &Default::default(), - ) - } - - #[test] - pub fn cmma_write_output_third_fourth_warps_test() { - tests::cmma::write_output::cmma_write_output_third_fourth_warps_test::( - &Default::default(), - ) - } - - #[test] - pub fn load_shared_memory_rhs_larger_block_test() { - tests::cmma::load_shared_memory::load_shared_memory_rhs_larger_block_test::( - &Default::default(), - ) - } - - }; -} diff --git a/crates/cubecl-linalg/src/tests/matmul/cmma/mod.rs b/crates/cubecl-linalg/src/tests/matmul/cmma/mod.rs index a338fa79..7f6ea232 100644 --- a/crates/cubecl-linalg/src/tests/matmul/cmma/mod.rs +++ b/crates/cubecl-linalg/src/tests/matmul/cmma/mod.rs @@ -1,7 +1,6 @@ #![allow(missing_docs)] pub mod matmul; -pub mod matmul_internal; #[macro_export] macro_rules! testgen_cmma { @@ -9,6 +8,5 @@ macro_rules! testgen_cmma { use super::*; cubecl_linalg::testgen_cmma_matmul!(); - cubecl_linalg::testgen_cmma_internal!(); }; }