Skip to content

Commit

Permalink
clean up unneeded methods and variants
Browse files (browse the repository at this point in the history)
  • Loading branch information
jayhshah committed Jul 26, 2024
1 parent cb8b453 commit a00492e
Show file tree
Hide file tree
Showing 7 changed files with 45 additions and 718 deletions.
2 changes: 0 additions & 2 deletions hopper/flash_api.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,6 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
// run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
// });
if (!params.is_e4m3) {
#if 0
if (params.is_bf16) {
if (params.d == 64) {
run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
Expand All @@ -241,7 +240,6 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
run_mha_fwd_<cutlass::half_t, 256>(params, stream);
}
}
#endif
} else {
if (params.d == 64) {
run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
Expand Down
70 changes: 15 additions & 55 deletions hopper/flash_fwd_kernel.h
Original file line number Diff line number Diff line change
Expand Up @@ -194,8 +194,8 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
static constexpr int NumCopyThreads = !Is_WS ? 0 : cutlass::NumThreadsPerWarpGroup;
static constexpr int kBlockM = Ktraits::kBlockM;
// static constexpr int kBlockN = Ktraits::kBlockN;
static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr bool Delay_V_release = Is_causal && kHeadDim == 128;
// static constexpr int kHeadDim = Ktraits::kHeadDim;
static constexpr bool Delay_V_release = Is_causal && Ktraits::kHeadDim == 128;

using CollectiveMainloop = CollectiveMainloopFwd<Ktraits, Is_causal>;
using CollectiveEpilogue = CollectiveEpilogueFwd<Ktraits>;
Expand Down Expand Up @@ -238,11 +238,7 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,

if (warp_idx == 0 && lane_predicate) {
shared_storage.barrier_Q.init(1 /*numThreads*/);
#ifndef NO_UNION
#ifndef NEW_FP8_EPI_BARRIER
shared_storage.barrier_O.init(size(ClusterShape{}) /*numThreads*/);
#endif
#endif
}
// We're counting on pipeline_k to call cutlass::arch::fence_barrier_init();
MainloopPipeline pipeline_k(shared_storage.pipeline_k, pipeline_params, ClusterShape{});
Expand All @@ -266,15 +262,9 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
if (warp_group_idx == 0) { // Producer
cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 40 : 32>();


#ifdef USE_TRI_MMA_FP8
PipelineState smem_pipe_write_k = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_write_v = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_read_v;
#else
PipelineState smem_pipe_write = cutlass::make_producer_start_state<MainloopPipeline>();
PipelineState smem_pipe_read, smem_pipe_release;
#endif


int work_idx = 0;

Expand All @@ -289,45 +279,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
if (Is_causal && n_block_max <= 0) {
scheduler.prefetch_next_work(scheduler_params, work_tile_info);
scheduler.broadcast_next_work(work_tile_info);
// TODO: remove this
// need to sync producer warpgroup
cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
continue;
}
#ifdef USE_TRI_MMA_FP8
collective_mainloop.load_fp8_ver1(
mainloop_params, pipeline_k, pipeline_v, pipeline_vt,
smem_pipe_write_k, smem_pipe_write_v, smem_pipe_read_v, shared_storage,
scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
#else

collective_mainloop.load_fp8(
mainloop_params, pipeline_k, pipeline_v, pipeline_vt,
smem_pipe_write, smem_pipe_read, shared_storage,
scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
#endif
scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
++work_idx;
// need to sync producer warpgroup
// TODO: remove this
// if (Is_causal)
// cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/);
// don't need to sync producer warpgroup here
// if constexpr (Is_causal) {
// cutlass::arch::NamedBarrier::sync(NumCopyThreads, static_cast<int>(FwdNamedBarriers::ProducerWG) /*id*/); }
}
#ifdef USE_TRI_MMA_FP8
collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
#else
collective_mainloop.load_tail_one_write(pipeline_k, pipeline_v, smem_pipe_write);
#endif


} else { // Consumer
cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 232 : 160>();

TileScheduler scheduler(&shared_storage.tile_count_semaphore);
// Initialize matmul objects.
typename Ktraits::TiledMma1 tiled_mma1;
#ifdef USE_TRI_MMA_FP8
PipelineState smem_pipe_read_k, smem_pipe_read_vt;
#else
PipelineState smem_pipe_read;
PipelineState smem_pipe_release;
#endif

collective_mainloop.mma_init_fp8();
scheduler.init_consumer();
Expand All @@ -349,32 +325,16 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
continue;
}

#ifdef USE_TRI_MMA_FP8
collective_mainloop.mma_fp8_ver1(
mainloop_params, pipeline_k, pipeline_vt,
smem_pipe_read_k, smem_pipe_read_vt,
tOrO, softmax, n_block_max,
threadIdx.x - NumCopyThreads, work_idx, m_block,
shared_storage);
#else
// collective_mainloop.mma_fp8(
// mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read,
// smem_pipe_release, tOrO, softmax, n_block_max,
// threadIdx.x - NumCopyThreads, work_idx, m_block,
// shared_storage);
collective_mainloop.mma_fp8_ver2<Delay_V_release>(

collective_mainloop.mma_fp8<Delay_V_release>(
mainloop_params, pipeline_k, pipeline_vt, smem_pipe_read, smem_pipe_release,
tOrO, softmax, n_block_max,
threadIdx.x - NumCopyThreads, work_idx, m_block,
shared_storage);
#endif
shared_storage);

#ifdef COLUMN_PERMUTE
#ifndef NO_FP8_COLUMN_PERMUTE
collective_epilogue.store_fp8(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord);
// collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
// threadIdx.x - NumCopyThreads, block_coord);
threadIdx.x - NumCopyThreads, block_coord);
#else
collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
threadIdx.x - NumCopyThreads, block_coord);
Expand Down
26 changes: 4 additions & 22 deletions hopper/kernel_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,33 +39,15 @@ struct SharedStorageQKVOVt {
struct {
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutQ>> smem_q;
cute::array_aligned<Gemm1Type, cute::cosize_v<SmemLayoutK>> smem_k;
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
#ifdef NO_UNION
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
#else
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
union {
cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
};
#endif
// union {
// struct {
// cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v;
// cute::array_aligned<Gemm2Type, cute::cosize_v<SmemLayoutV>> smem_v_out;
// };
// struct {
// cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
// };
// };
};
struct {
cutlass::arch::ClusterTransactionBarrier barrier_Q;
#ifndef NO_UNION
#ifndef NEW_FP8_EPI_BARRIER
cutlass::arch::ClusterTransactionBarrier barrier_Q;
cutlass::arch::ClusterBarrier barrier_O;
#endif
#endif
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_v;
typename cutlass::PipelineAsync<kStages>::SharedStorage pipeline_vt;
Expand Down Expand Up @@ -155,7 +137,7 @@ struct Flash_fwd_kernel_traits {

};

// Traits struct for fp8 kernel
// Traits struct for fp8 kernel with in-kernel transpose
template<int kHeadDim_, int kBlockM_, int kBlockN_, int kNWarps_, int kStages_, bool Is_Q_in_regs_=false,
int kClusterM_ = 1, typename elem_type=cutlass::float_e4m3_t>
struct Flash_fwd_kernel_traits_fp8 {
Expand Down Expand Up @@ -230,7 +212,7 @@ struct Flash_fwd_kernel_traits_fp8 {
decltype(composition(SmemLayoutVt{},
make_ordered_layout(product_each(shape(SmemLayoutV{})), Step<_2, _1, _3>{})));
using SmemLayoutDivideVt = decltype(tiled_divide(SmemLayoutVtTrans{}, TransposeShapeAtomV{}));
#ifdef COLUMN_PERMUTE
#ifndef NO_FP8_COLUMN_PERMUTE
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_8, _8>>;
#else
using SmemShapeSTSM = Shape<Shape<_16, _4>, Shape<_16, _4>>;
Expand Down
Loading

0 comments on commit a00492e

Please sign in to comment.