
Commit

removing macros
jayhshah committed Jul 25, 2024
1 parent d762274 commit cd25ee5
Showing 2 changed files with 86 additions and 150 deletions.
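Note: the commit replaces the build-time RELEASE_PATTERN preprocessor switch in the FP8 mainloop with a bool template parameter, Delay_V_release, selected per kernel instantiation. A minimal runnable sketch of the rewrite pattern (ToyPipe and release_v are illustrative stand-ins, not code from this repository):

    #include <cstdio>

    struct ToyPipe { void consumer_release(int s) { std::printf("release stage %d\n", s); } };

    // Before: the release schedule was fixed per build with -DRELEASE_PATTERN:
    //   #ifdef RELEASE_PATTERN
    //       pipeline_vt.consumer_release(smem_pipe_release);
    //   #else
    //       pipeline_vt.consumer_release(smem_pipe_read);
    //   #endif
    // After: a template parameter selects the schedule per instantiation, and
    // if constexpr discards the untaken branch at compile time.
    template <bool Delay_V_release = false>
    void release_v(ToyPipe& pipeline_vt, int smem_pipe_read, int smem_pipe_release) {
        if constexpr (Delay_V_release) {
            pipeline_vt.consumer_release(smem_pipe_release);
        } else {
            pipeline_vt.consumer_release(smem_pipe_read);
        }
    }

    int main() {
        ToyPipe vt;
        release_v</*Delay_V_release=*/true>(vt, /*read=*/1, /*release=*/0);
        release_v</*Delay_V_release=*/false>(vt, /*read=*/1, /*release=*/0);
        return 0;
    }

Both variants now compile into the same binary, so the schedule can depend on template parameters such as kHeadDim instead of on build flags.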
5 changes: 3 additions & 2 deletions hopper/flash_fwd_kernel.h
@@ -194,7 +194,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
     static constexpr int kBlockM = Ktraits::kBlockM;
     // static constexpr int kBlockN = Ktraits::kBlockN;
-    // constexpr int kHeadDim = Ktraits::kHeadDim;
+    static constexpr int kHeadDim = Ktraits::kHeadDim;
+    static constexpr bool Delay_V_release = Is_causal && kHeadDim == 128;
 
     using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>;
     using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>;
@@ -362,7 +363,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         //     smem_pipe_release, tOrO, softmax, n_block_max,
         //     threadIdx.x - NumCopyThreads, work_idx, m_block,
         //     shared_storage);
-        collective_mainloop.mma_fp8_ver2(
+        collective_mainloop.mma_fp8_ver2<Delay_V_release>(
             mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
             tOrO, softmax, n_block_max,
             threadIdx.x - NumCopyThreads, work_idx, m_block,
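The kernel hunk above derives Delay_V_release from the tile configuration and forwards it as a non-type template argument to the mainloop. A standalone model of that dispatch (the *_model names are illustrative; the constexpr logic mirrors the diff):

    #include <cstdio>

    template <bool Delay_V_release = false>
    void mma_fp8_ver2_model() {
        if constexpr (Delay_V_release) { std::printf("delayed V release schedule\n"); }
        else { std::printf("immediate V release schedule\n"); }
    }

    template <int kHeadDim, bool Is_causal>
    void kernel_model() {
        // Mirrors the new flash_fwd_kernel.h logic: derived, never user-specified.
        constexpr bool Delay_V_release = Is_causal && kHeadDim == 128;
        mma_fp8_ver2_model<Delay_V_release>();
    }

    int main() {
        kernel_model</*kHeadDim=*/128, /*Is_causal=*/true>();   // delayed
        kernel_model</*kHeadDim=*/256, /*Is_causal=*/true>();   // immediate
        return 0;
    }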
231 changes: 83 additions & 148 deletions hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp
@@ -1392,7 +1392,7 @@ struct CollectiveMainloopFwd {
         return;
     }
 
-    template <typename SharedStorage, typename FrgTensorO, typename Softmax>
+    template <bool Delay_V_release = false, typename SharedStorage, typename FrgTensorO, typename Softmax>
     CUTLASS_DEVICE void
     mma_fp8_ver2(Params const& mainloop_params,
                  MainloopPipeline pipeline_k,
@@ -1425,7 +1425,6 @@ struct CollectiveMainloopFwd {
         Tensor tSrQ = threadMma0.partition_fragment_A(sQ);
         Tensor tSrK = threadMma0.partition_fragment_B(sK);
         // Allocate "fragments/descriptors" for second matmul.
-        // Note: S becomes P.
         Tensor tOrV = threadMma1.partition_fragment_B(sVt);
 
         auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
@@ -1457,10 +1456,8 @@ struct CollectiveMainloopFwd {
                 }
             }
         }
         warpgroup_wait<0>();
-        // warp_scheduler_barrier_arrive();
-        pipeline_k.consumer_release(smem_pipe_read); // DEFAULT
-        // pipeline_k.consumer_release(smem_pipe_release);
+        pipeline_k.consumer_release(smem_pipe_read);
 
         auto col_limit_causal = [&](int row, int n_block) {
             return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
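For intuition about the lambda above: row r of query block m_block may attend to key column c of key block n_block only while the absolute key index stays at or below the absolute query index shifted by seqlen_k - seqlen_q, which rearranges to c < col_limit_causal(r, n_block). A hypothetical standalone restatement with compile-time checks (the captured variables are made explicit parameters; the sizes are illustrative):

    constexpr int col_limit_causal(int row, int n_block, int m_block,
                                   int seqlen_q, int seqlen_k,
                                   int kBlockM, int kBlockN) {
        return row + 1 + seqlen_k - n_block * kBlockN - seqlen_q + m_block * kBlockM;
    }

    // seqlen_q == seqlen_k == 256, kBlockM == kBlockN == 128, block (0, 0):
    // row 0 may see columns [0, 1), i.e. exactly the diagonal element,
    static_assert(col_limit_causal(0, 0, 0, 256, 256, 128, 128) == 1, "");
    // and row 127 may see columns [0, 128).
    static_assert(col_limit_causal(127, 0, 0, 256, 256, 128, 128) == 128, "");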
@@ -1473,22 +1470,15 @@ struct CollectiveMainloopFwd {
             if constexpr (!Is_causal) { // Just masking based on col
                 if (int(get<1>(tScS(i))) >= int(seqlen_k - n_block * kBlockN)) { tSrS(i) = -INFINITY; }
             } else { // mask based on both row and col
-                // using std::min is faster than doing col >= limit0 or col >= limit1
-                // Need to cast get<1>(tScS(i)) to (signed) int since by default it's unsigned, and the
-                // right hand side can be negative and might be converted to a very large unsigned integer.
                 if (int(get<1>(tScS(i))) >= std::min(seqlen_k - n_block * kBlockN,
                                                      col_limit_causal(int(get<0>(tScS(i))), n_block))) {
                     tSrS(i) = -INFINITY;
                 }
             }
         }
 
-        // warp_scheduler_barrier_arrive();
-        // pipeline_k.consumer_release(smem_pipe_read);
-
         softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);
 
         Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
         permute_regs_A_to_C(tOrP);

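The max / rescale_o / online_softmax calls in these hunks implement the standard online softmax: each new block of scores updates a running row maximum, rescales the partial results by exp(old_max - new_max), and accumulates exponentials against the new maximum. A scalar sketch of one step (illustrative only; the real Softmax class works on register fragments row-wise and in log2 space via softmax_scale_log2):

    #include <cmath>

    void online_softmax_step(const float* block_scores, int n,
                             float& row_max, float& row_sum, float& out_acc) {
        float new_max = row_max;
        for (int i = 0; i < n; ++i) new_max = std::fmax(new_max, block_scores[i]);
        const float correction = std::exp(row_max - new_max);
        row_sum *= correction;   // rescale the running denominator
        out_acc *= correction;   // corresponds to softmax.rescale_o on tOrO
        for (int i = 0; i < n; ++i) row_sum += std::exp(block_scores[i] - new_max);
        row_max = new_max;
        // The caller then accumulates exp(scores - new_max) * V into out_acc
        // (the second GEMM); finalize() divides by row_sum once at the end.
    }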
@@ -1497,32 +1487,19 @@ struct CollectiveMainloopFwd {
 
         consumer_wait(pipeline_vt, smem_pipe_read);
         flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
-#ifndef RELEASE_PATTERN
-        pipeline_vt.consumer_release(smem_pipe_read); // DEFAULT
-#endif
-        // pipeline_vt.consumer_release(smem_pipe_release);
+        if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
 
         ++smem_pipe_read;
-        // ++smem_pipe_release;
         --n_block;
-
-
         constexpr int extra_iterations = !Is_causal ? kStages - 1 : cute::ceil_div(kBlockM, kBlockN);
-        // constexpr int extra_iterations = kStages - 1;
 
         if constexpr(Is_causal) {
-            // if constexpr (kHeadDim == 128)
-            //     warp_scheduler_barrier_sync();
             CUTLASS_PRAGMA_UNROLL
             for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
                 Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
 
                 consumer_wait(pipeline_k, smem_pipe_read);
-                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
-                warpgroup_wait<0>();
-                // pipeline_k.consumer_release(smem_pipe_read);
+                warp_scheduler_barrier_sync();
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
 
                 Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
                 Tensor tScS = threadMma0.partition_C(cS);
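On extra_iterations in the hunk above: my reading is that for causal attention only about ceil_div(kBlockM, kBlockN) key blocks can straddle the diagonal of one query block, so only that many peeled iterations need the combined row-and-column mask, while the non-causal path peels kStages - 1 iterations to fill the software pipeline. A quick check of the causal count (tile sizes illustrative):

    constexpr int ceil_div(int a, int b) { return (a + b - 1) / b; }

    // A 128-row query block spans 128 key columns around the diagonal, which
    // touches ceil(128 / 64) = 2 key blocks of width 64.
    static_assert(ceil_div(128, 64) == 2, "");
    static_assert(ceil_div(128, 128) == 1, "");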
@@ -1534,152 +1511,110 @@ struct CollectiveMainloopFwd {
                 }
 
                 warp_scheduler_barrier_arrive();
-                pipeline_k.consumer_release(smem_pipe_read); // DEFAULT
+                pipeline_k.consumer_release(smem_pipe_read);
                 consumer_wait(pipeline_vt, smem_pipe_read);
 
                 cute::copy(softmax.template max</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
                 softmax.template online_softmax</*Is_first=*/false, /*Check_inf=*/true>(tSrS, mainloop_params.softmax_scale_log2);
-
-                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(),
-                                          convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
 
-                // consumer_wait(pipeline_vt, smem_pipe_read);
-#ifdef RELEASE_PATTERN
-                pipeline_vt.consumer_release(smem_pipe_release);
-                ++smem_pipe_release;
-#endif
+                if constexpr(Delay_V_release) {
+                    pipeline_vt.consumer_release(smem_pipe_release);
+                    ++smem_pipe_release;
+                }
                 flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
-#ifndef RELEASE_PATTERN
-                pipeline_vt.consumer_release(smem_pipe_read);
-#endif
+                if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                 ++smem_pipe_read;
             }
         }
-#if 1
         else {
             CUTLASS_PRAGMA_UNROLL
             for (int iter = 0; iter < extra_iterations && n_block >= 0; ++iter, --n_block) {
                 Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
 
                 consumer_wait(pipeline_k, smem_pipe_read);
-#ifdef RELEASE_PATTERN
-                pipeline_vt.consumer_release(smem_pipe_release);
-                ++smem_pipe_release;
-#endif
+                if constexpr(Delay_V_release) {
+                    pipeline_vt.consumer_release(smem_pipe_release);
+                    ++smem_pipe_release;
+                }
                 warp_scheduler_barrier_sync();
                 flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
                 warp_scheduler_barrier_arrive();
-#ifndef RELEASE_PATTERN
-                pipeline_k.consumer_release(smem_pipe_read);
-#else
-                consumer_wait(pipeline_vt, smem_pipe_read);
-#endif
+                if constexpr(!Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
+                else { consumer_wait(pipeline_vt, smem_pipe_read); }
 
                 cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
                 softmax.rescale_o(tOrO, scores_scale);
                 softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
-
-                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(),
-                                          convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
                 permute_regs_A_to_C(tOrP);
 
-                // consumer_wait(pipeline_vt, smem_pipe_read);
-
-                // warp_scheduler_barrier_sync();
-#ifdef RELEASE_PATTERN
-                pipeline_k.consumer_release(smem_pipe_read);
-#else
-                consumer_wait(pipeline_vt, smem_pipe_read);
-#endif
+                if constexpr (Delay_V_release) { pipeline_k.consumer_release(smem_pipe_read); }
+                else { consumer_wait(pipeline_vt, smem_pipe_read); }
                 flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
-                // warp_scheduler_barrier_arrive();
-#ifndef RELEASE_PATTERN
-                pipeline_vt.consumer_release(smem_pipe_read);
-#endif
+                if constexpr(!Delay_V_release) { pipeline_vt.consumer_release(smem_pipe_read); }
                 ++smem_pipe_read;
             }
         }
-#endif
 
-#ifdef RELEASE_PATTERN
-        warp_scheduler_barrier_sync();
-#else
-        if constexpr (kHeadDim == 128)
+        if constexpr(Delay_V_release) {
             warp_scheduler_barrier_sync();
-#endif
-
-        CUTLASS_PRAGMA_NO_UNROLL
-        for (; n_block >= 0; --n_block) {
-            Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
-            consumer_wait(pipeline_k, smem_pipe_read); // wait K
-
-#ifdef RELEASE_PATTERN
-            pipeline_vt.consumer_release(smem_pipe_release);
-            ++smem_pipe_release;
-#else
-            if constexpr (kHeadDim == 256)
-                warp_scheduler_barrier_sync();
-#endif
-            flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
-#ifdef RELEASE_PATTERN
-            warp_scheduler_barrier_arrive();
-#endif
-            warpgroup_wait<0>();
-#ifndef RELEASE_PATTERN
-            warp_scheduler_barrier_arrive();
-            pipeline_k.consumer_release(smem_pipe_read); // release current K
-#else
-            consumer_wait(pipeline_vt, smem_pipe_read);
-#endif
-
-            cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
-            softmax.rescale_o(tOrO, scores_scale);
-            softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
-
-            Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(),
-                                      convert_layout_acc_Aregs_fp8(tSrS.layout()));
-            permute_regs_A_to_C(tOrP);
-
-            // consumer_wait(pipeline_vt, smem_pipe_read);
-
-#ifdef RELEASE_PATTERN
-            pipeline_k.consumer_release(smem_pipe_read); // release current K
-#else
-            consumer_wait(pipeline_vt, smem_pipe_read);
-            if constexpr (kHeadDim == 128)
-                warp_scheduler_barrier_sync();
-#endif
-            flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
-            // warp_scheduler_barrier_arrive();
-#ifdef RELEASE_PATTERN
-            warp_scheduler_barrier_sync();
-#endif
-            warpgroup_wait<0>();
-            // warp_scheduler_barrier_sync();
-#ifndef RELEASE_PATTERN
-            pipeline_vt.consumer_release(smem_pipe_read);
-#endif
-
-            ++smem_pipe_read;
-        }
-
-#ifdef RELEASE_PATTERN
-        warp_scheduler_barrier_arrive();
-        pipeline_vt.consumer_release(smem_pipe_release);
-        ++smem_pipe_release;
-#else
-        if constexpr (kHeadDim == 128)
-            warp_scheduler_barrier_arrive();
-#endif
+            CUTLASS_PRAGMA_NO_UNROLL
+            for (; n_block >= 0; --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                pipeline_vt.consumer_release(smem_pipe_release);
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+                warp_scheduler_barrier_arrive();
+                warpgroup_wait<0>();
+                consumer_wait(pipeline_vt, smem_pipe_read);
+
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                pipeline_k.consumer_release(smem_pipe_read);
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                warp_scheduler_barrier_sync();
+                warpgroup_wait<0>();
+                ++smem_pipe_read;
+                ++smem_pipe_release;
+            }
+            warp_scheduler_barrier_arrive();
+            pipeline_vt.consumer_release(smem_pipe_release);
+            ++smem_pipe_release;
+        } else {
+            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
+            CUTLASS_PRAGMA_NO_UNROLL
+            for (; n_block >= 0; --n_block) {
+                Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
+                consumer_wait(pipeline_k, smem_pipe_read);
+                if constexpr (kHeadDim == 256) { warp_scheduler_barrier_sync(); }
+                flash::gemm</*zero_init=*/true, /*wg_wait=*/0>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read.index()), tSrS);
+                warp_scheduler_barrier_arrive();
+                pipeline_k.consumer_release(smem_pipe_read);
+
+                cute::copy(softmax.template max</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
+                softmax.rescale_o(tOrO, scores_scale);
+                softmax.template online_softmax</*Is_first=*/false>(tSrS, mainloop_params.softmax_scale_log2);
+                Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs_fp8(tSrS.layout()));
+                permute_regs_A_to_C(tOrP);
+
+                consumer_wait(pipeline_vt, smem_pipe_read);
+                if constexpr (kHeadDim == 128) { warp_scheduler_barrier_sync(); }
+                flash::gemm</*zero_init=*/false, /*wg_wait=*/0>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read.index()), tOrO);
+                pipeline_vt.consumer_release(smem_pipe_read);
+                ++smem_pipe_read;
+            }
+            if constexpr (kHeadDim == 128) { warp_scheduler_barrier_arrive(); }
+        }
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarpGroup, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
 
         cute::copy(softmax.template finalize</*Check_inf=*/Is_causal>(tSrS, mainloop_params.softmax_scale_log2), scores_scale);
         softmax.rescale_o(tOrO, scores_scale);
         return;
     }
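Taken together, the Delay_V_release branch only changes when each V^T shared-memory stage is handed back to the producer: the immediate schedule releases a stage right after the P*V GEMM consumes it, while the delayed schedule frees the previous stage at the top of the next iteration, overlapping the release with the Q*K^T GEMM. A single-threaded toy model of the two schedules (ToyPipeline just logs; the real MainloopPipeline is an asynchronous barrier over kStages shared-memory buffers, with indices cycling modulo kStages):

    #include <cstdio>

    struct ToyPipeline {
        const char* name;
        void consumer_wait(int stage)    { std::printf("wait    %s[%d]\n", name, stage); }
        void consumer_release(int stage) { std::printf("release %s[%d]\n", name, stage); }
    };

    template <bool Delay_V_release>
    void consumer_loop(int n_iters) {
        ToyPipeline k{"K"}, vt{"Vt"};
        int read = 0, release = 0;
        for (int iter = 0; iter < n_iters; ++iter, ++read) {
            k.consumer_wait(read);
            // Delayed schedule: free the previous Vt stage here, while the
            // Q*K^T GEMM would be running on the tensor cores.
            if constexpr (Delay_V_release) {
                if (iter > 0) { vt.consumer_release(release++); }
            }
            std::printf("gemm0 (Q*K^T)\n");
            k.consumer_release(read);
            vt.consumer_wait(read);
            std::printf("gemm1 (P*V)\n");
            // Immediate schedule: free Vt as soon as gemm1 has consumed it.
            if constexpr (!Delay_V_release) { vt.consumer_release(read); }
        }
        if constexpr (Delay_V_release) { vt.consumer_release(release); }  // drain
    }

    int main() {
        consumer_loop</*Delay_V_release=*/true>(3);
        consumer_loop</*Delay_V_release=*/false>(3);
        return 0;
    }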

