Skip to content

Commit

Permalink
Cmma/invert k n loops (#131)
Browse files Browse the repository at this point in the history
  • Loading branch information
louisfd authored Sep 19, 2024
1 parent a090a47 commit 2fdfb51
Show file tree
Hide file tree
Showing 17 changed files with 241 additions and 1,988 deletions.
10 changes: 5 additions & 5 deletions crates/cubecl-linalg/src/matmul/cmma/base.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,7 @@ pub(crate) struct Offsets {
}

#[derive(CubeType)]
pub(crate) struct CmmaMatrices<F: Float, FC: Float> {
pub(crate) struct Fragments<F: Float, FC: Float> {
pub accumulators: Sequence<cmma::Matrix<F>>,
pub lhs: cmma::Matrix<FC>,
pub rhs: cmma::Matrix<FC>,
Expand Down Expand Up @@ -127,9 +127,9 @@ fn calculate_offsets<F: Float>(

#[cube]
pub(crate) fn get_row_col(#[comptime] comptime_info: ComptimeCmmaInfo) -> (u32, u32) {
if comptime_info.cube_dispatch == 0 {
if comptime_info.cube_dispatch_strategy == 0 {
RowMajorCubeDispatch::get_row_col(comptime_info)
} else if comptime_info.cube_dispatch == 1 {
} else if comptime_info.cube_dispatch_strategy == 1 {
ColMajorCubeDispatch::get_row_col(comptime_info)
} else {
SwizzleCubeDispatch::get_row_col(comptime_info)
Expand All @@ -151,7 +151,7 @@ fn make_shared_memories<FC: Float>(#[comptime] config: ComptimeCmmaInfo) -> Shar
#[cube]
pub(crate) fn make_cmma_matrices<F: Float, FC: Float>(
#[comptime] config: ComptimeCmmaInfo,
) -> CmmaMatrices<F, FC> {
) -> Fragments<F, FC> {
let num_accumulators = config.num_accumulators;
let mut accumulators = Sequence::<cmma::Matrix<F>>::new();

Expand Down Expand Up @@ -186,7 +186,7 @@ pub(crate) fn make_cmma_matrices<F: Float, FC: Float>(
cmma::MatrixLayout::RowMajor,
);

CmmaMatrices::<F, FC> {
Fragments::<F, FC> {
accumulators,
lhs,
rhs,
Expand Down
26 changes: 8 additions & 18 deletions crates/cubecl-linalg/src/matmul/cmma/block_loop.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@ use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use super::{
base::{CmmaMatrices, RuntimeCmmaInfo, SharedMemories},
compute_loop::compute_loop,
base::{Fragments, RuntimeCmmaInfo, SharedMemories},
compute_loop::base::compute_loop,
config::ComptimeCmmaInfo,
load_shared_memory::load_to_shared_memories,
write_output::{base::OutputWriter, large_smem::LargeSmemWriter, reuse_smem::ReuseSmemWriter},
Expand All @@ -15,12 +15,12 @@ pub(crate) fn block_loop<F: Float, FC: Float>(
rhs: &Tensor<F>,
out: &mut Tensor<F>,
shared_memories: SharedMemories<FC>,
mut cmma_matrices: CmmaMatrices<F, FC>,
mut fragments: Fragments<F, FC>,
runtime_info: RuntimeCmmaInfo,
#[comptime] comptime_info: ComptimeCmmaInfo,
) {
let block_size_k = comptime_info.block_size_k;
let write_out_reuse_smem = comptime_info.write_out_reuse_smem;
let write_out_reuse_smem = comptime_info.write_out_strategy;

// Equals ceil(dims.k / block_size_k)
let dims = runtime_info.dims;
Expand All @@ -42,27 +42,17 @@ pub(crate) fn block_loop<F: Float, FC: Float>(

compute_loop::<F, FC>(
shared_memories,
&mut cmma_matrices,
&mut fragments,
runtime_info.ids,
comptime_info,
);

sync_units();
}

if write_out_reuse_smem {
ReuseSmemWriter::write_to_output(
out,
cmma_matrices.accumulators,
runtime_info,
comptime_info,
);
if write_out_reuse_smem == 0 {
LargeSmemWriter::write_to_output(out, fragments.accumulators, runtime_info, comptime_info);
} else {
LargeSmemWriter::write_to_output(
out,
cmma_matrices.accumulators,
runtime_info,
comptime_info,
);
ReuseSmemWriter::write_to_output(out, fragments.accumulators, runtime_info, comptime_info);
}
}
77 changes: 0 additions & 77 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop.rs

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::{
base::{Fragments, Ids, SharedMemories},
compute_loop::base::load_into_fragment,
config::ComptimeCmmaInfo,
};

use super::base::ComputeLoop;

/// Compute-loop strategy whose outer loop walks the k-buffers and whose inner loop
/// visits every accumulator before moving on to the next buffer.
pub(crate) struct AllAccumulatorsFirstComputeLoop {}

#[cube]
impl ComputeLoop for AllAccumulatorsFirstComputeLoop {
    // Accumulates, for the unit group identified by `ids.coop`, the products of the lhs and
    // rhs tiles held in shared memory into `fragments.accumulators`.
    //
    // Loop order: buffers (along k) on the outside, accumulators on the inside.
    //
    // - `shared_memories`: source of the lhs and rhs tiles.
    // - `fragments`: the lhs/rhs fragments are used as scratch; the accumulators are updated.
    // - `ids`: only `ids.coop` is read, to pick the tile row and the first tile column.
    // - `comptime_info`: block/tile sizes, number of accumulators and unroll setting.
    fn compute_loop<F: Float, FC: Float>(
        shared_memories: SharedMemories<FC>,
        fragments: &mut Fragments<F, FC>,
        ids: Ids,
        #[comptime] comptime_info: ComptimeCmmaInfo,
    ) {
        // Comptime values
        let block_size_k = comptime_info.block_size_k;
        let block_size_n = comptime_info.block_size_n;
        let tile_size = comptime_info.tile_size;
        let unroll = comptime_info.unroll;
        let num_accumulators = comptime_info.num_accumulators;
        let num_buffers = block_size_k / tile_size;
        let num_coop_per_row = (block_size_n / tile_size) / num_accumulators;

        // Runtime values
        let tile_row = ids.coop / num_coop_per_row;
        let tile_col_base = (ids.coop % num_coop_per_row) * num_accumulators;

        #[unroll(unroll)]
        for buffer_iter in 0..num_buffers {
            // The lhs tile index depends only on `buffer_iter`, so the lhs fragment is loaded
            // once per buffer and reused by every accumulator instead of being reloaded with
            // identical data on each inner iteration.
            load_into_fragment(
                tile_row * num_buffers + buffer_iter,
                shared_memories.lhs,
                &fragments.lhs,
                comptime_info,
            );

            #[unroll]
            for accumulator_iter in 0..num_accumulators {
                // The rhs tile changes with both the accumulator and the buffer.
                load_into_fragment(
                    (tile_col_base + accumulator_iter) * num_buffers + buffer_iter,
                    shared_memories.rhs,
                    &fragments.rhs,
                    comptime_info,
                );

                // The same fragment is passed as both the addend and the destination, so the
                // product is accumulated in place.
                let accumulator = &fragments.accumulators.index(accumulator_iter);
                cmma::execute::<FC, FC, F, F>(
                    &fragments.lhs,
                    &fragments.rhs,
                    accumulator,
                    accumulator,
                );
            }
        }
    }
}
55 changes: 55 additions & 0 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop/base.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::{
base::{Fragments, Ids, SharedMemories},
compute_loop::{
accumulators_first::AllAccumulatorsFirstComputeLoop,
buffers_first::AllBuffersFirstComputeLoop,
},
config::ComptimeCmmaInfo,
};

/// Runs the compute loop implementation selected at compile time by
/// `comptime_info.compute_loop_order_strategy`: `0` picks
/// [`AllBuffersFirstComputeLoop`], any other value picks
/// [`AllAccumulatorsFirstComputeLoop`].
///
/// All arguments are forwarded unchanged to the selected implementation.
#[cube]
pub(crate) fn compute_loop<F: Float, FC: Float>(
    shared_memories: SharedMemories<FC>,
    fragments: &mut Fragments<F, FC>,
    ids: Ids,
    #[comptime] comptime_info: ComptimeCmmaInfo,
) {
    let loop_order = comptime_info.compute_loop_order_strategy;

    if loop_order == 0 {
        AllBuffersFirstComputeLoop::compute_loop(shared_memories, fragments, ids, comptime_info);
    } else {
        AllAccumulatorsFirstComputeLoop::compute_loop(
            shared_memories,
            fragments,
            ids,
            comptime_info,
        );
    }
}

/// Interface shared by the compute-loop strategies.
///
/// `compute_loop` receives the shared memories holding the lhs/rhs tiles, the
/// fragments to work with (taken by mutable reference), the ids of the calling
/// unit, and the compile-time configuration. It returns nothing.
#[cube]
pub(crate) trait ComputeLoop {
    fn compute_loop<F: Float, FC: Float>(
        shared_memories: SharedMemories<FC>,
        fragments: &mut Fragments<F, FC>,
        ids: Ids,
        #[comptime] comptime_info: ComptimeCmmaInfo,
    );
}

/// Loads one tile from shared memory into a cmma fragment.
///
/// - `tile`: index of the tile inside `smem`; tile `i` occupies the
///   `tile_size * tile_size` elements starting at `i * tile_size * tile_size`.
/// - `smem`: shared memory holding the tiles.
/// - `fragment`: destination fragment.
/// - `comptime_info`: only `tile_size` is read.
#[cube]
pub(crate) fn load_into_fragment<FC: Float>(
    tile: u32,
    smem: SharedMemory<FC>,
    fragment: &cmma::Matrix<FC>,
    #[comptime] comptime_info: ComptimeCmmaInfo,
) {
    let tile_size = comptime_info.tile_size;
    let smem_stride = tile_size * tile_size;

    let smem_pos = tile * smem_stride;
    let slice = smem.slice(smem_pos, smem_pos + smem_stride);
    // The stride comes from the same `tile_size` that sizes the slice, rather than a literal
    // 16, so the two cannot drift apart. Identical to the previous behavior when tile_size
    // is 16.
    cmma::load::<FC>(fragment, slice, tile_size);
}
63 changes: 63 additions & 0 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop/buffers_first.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
use cubecl_core as cubecl;
use cubecl_core::prelude::*;

use crate::matmul::cmma::{
base::{Fragments, Ids, SharedMemories},
compute_loop::base::load_into_fragment,
config::ComptimeCmmaInfo,
};

use super::base::ComputeLoop;

/// Compute-loop strategy whose outer loop walks the accumulators and whose inner loop
/// goes through every k-buffer before moving on to the next accumulator.
pub(crate) struct AllBuffersFirstComputeLoop {}

#[cube]
impl ComputeLoop for AllBuffersFirstComputeLoop {
    // Accumulates, for the unit group identified by `ids.coop`, the products of the lhs and
    // rhs tiles held in shared memory into `fragments.accumulators`.
    //
    // Loop order: accumulators on the outside, buffers (along k) on the inside.
    fn compute_loop<F: Float, FC: Float>(
        shared_memories: SharedMemories<FC>,
        fragments: &mut Fragments<F, FC>,
        ids: Ids,
        #[comptime] comptime_info: ComptimeCmmaInfo,
    ) {
        // Compile-time geometry
        let tile_size = comptime_info.tile_size;
        let unroll_k = comptime_info.unroll;
        let accumulator_count = comptime_info.num_accumulators;
        let buffer_count = comptime_info.block_size_k / tile_size;
        let coops_per_row = (comptime_info.block_size_n / tile_size) / accumulator_count;

        // Position of this coop, known only at runtime
        let lhs_row = ids.coop / coops_per_row;
        let rhs_col_start = (ids.coop % coops_per_row) * accumulator_count;

        #[unroll]
        for acc_idx in 0..accumulator_count {
            #[unroll(unroll_k)]
            for k_idx in 0..buffer_count {
                // Both tile indices depend on `k_idx`, so both fragments are refreshed here.
                let lhs_tile = lhs_row * buffer_count + k_idx;
                load_into_fragment(lhs_tile, shared_memories.lhs, &fragments.lhs, comptime_info);

                let rhs_tile = (rhs_col_start + acc_idx) * buffer_count + k_idx;
                load_into_fragment(rhs_tile, shared_memories.rhs, &fragments.rhs, comptime_info);

                // Same fragment as addend and destination: accumulate in place.
                let acc = &fragments.accumulators.index(acc_idx);
                cmma::execute::<FC, FC, F, F>(&fragments.lhs, &fragments.rhs, acc, acc);
            }
        }
    }
}
3 changes: 3 additions & 0 deletions crates/cubecl-linalg/src/matmul/cmma/compute_loop/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
mod accumulators_first;
pub(crate) mod base;
mod buffers_first;
Loading

0 comments on commit 2fdfb51

Please sign in to comment.