Skip to content

Commit

Permalink
addressed third round of comments
Browse files (browse the repository at this point in the history)
  • Loading branch information
t4c1 committed Dec 6, 2024
1 parent dc12326 commit 51cb0b5
Show file tree
Hide file tree
Showing 5 changed files with 28 additions and 31 deletions.
26 changes: 13 additions & 13 deletions examples/35_gemm_softmax/gemm_online_softmax.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -128,7 +128,7 @@ struct Options {
/// Prints the usage statement.
std::ostream & print_usage(std::ostream &out) const {

out << "35_gemm_softmax example\n\n"
out << "35_gemm_online_softmax example\n\n"
<< " This example uses the CUTLASS Library to execute TF32 tensorop GEMM computations.\n\n"
<< "Options:\n\n"
<< " --help If specified, displays this usage statement.\n\n"
Expand Down Expand Up @@ -193,13 +193,13 @@ struct ExampleRunner {
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;
using StrideTmp = typename Gemm::CollectiveEpilogue::StrideD;
using StridePartials = typename Gemm::CollectiveEpilogue::StrideD;

using LayoutA = typename Gemm::LayoutA;
using LayoutB = typename Gemm::LayoutB;
using LayoutC = typename Gemm::LayoutC;
using LayoutD = typename Gemm::LayoutD;
using LayoutTmp = typename Gemm::LayoutTmp;
using LayoutPartials = typename Gemm::LayoutPartials;

using ElementA = typename Gemm::ElementA;
using ElementB = typename Gemm::ElementB;
Expand All @@ -223,7 +223,7 @@ struct ExampleRunner {
StrideB stride_B;
StrideC stride_C;
StrideD stride_D;
StrideTmp stride_tmp;
StridePartials stride_partials;
uint64_t seed = 0;

cutlass::DeviceAllocation<ElementA> block_A;
Expand Down Expand Up @@ -281,8 +281,8 @@ struct ExampleRunner {
LayoutA layout_A(lda);
LayoutB layout_B(ldb);
LayoutC layout_C(ldc);
LayoutTmp Layout_N(ldn);
LayoutTmp Layout_S(lds);
LayoutPartials Layout_N(ldn);
LayoutPartials Layout_S(lds);

cutlass::MatrixCoord extent_A{options.m, options.k};
cutlass::MatrixCoord extent_B{options.k, options.n};
Expand Down Expand Up @@ -365,21 +365,21 @@ struct ExampleRunner {
auto [M, N, K, L] = problem_shape_MNKL;

auto partials_N = cute::ceil_div(N, cute::shape<1>(typename Gemm::TileShape{}));
auto tmp_size = M * partials_N * L;
auto partials_size = M * partials_N * L;

stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L));
stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L));
stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L));
stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L));
stride_tmp = cutlass::make_cute_packed_stride(StrideTmp{}, cute::make_shape(M, partials_N, L));
stride_partials = cutlass::make_cute_packed_stride(StridePartials{}, cute::make_shape(M, partials_N, L));

block_A.reset(M * K * L);
block_B.reset(K * N * L);
block_C.reset(M * N * L);
block_D.reset(M * N * L);
block_ref_D.reset(M * N * L);
block_sum.reset(tmp_size);
block_max.reset(tmp_size);
block_sum.reset(partials_size);
block_max.reset(partials_size);

initialize_block(block_A, seed + 2023);
initialize_block(block_B, seed + 2022);
Expand All @@ -399,7 +399,7 @@ struct ExampleRunner {
options.beta},
block_C.get(), stride_C,
block_D.get(), stride_D,
block_max.get(), block_sum.get(), stride_tmp},
block_max.get(), block_sum.get(), stride_partials},
hw_info
};

Expand Down Expand Up @@ -513,7 +513,7 @@ int main(int argc, char const **args) {
using LayoutB = cutlass::layout::ColumnMajor;
using LayoutC = cutlass::layout::ColumnMajor;
using LayoutD = cutlass::layout::ColumnMajor;
using LayoutTmp = cutlass::layout::ColumnMajor;
using LayoutPartials = cutlass::layout::ColumnMajor;

// Tiling configuration selection
using TileShape = Shape<_128,_128,_32>;
Expand Down Expand Up @@ -580,7 +580,7 @@ int main(int argc, char const **args) {
using CollectiveEpilogue = cutlass::epilogue::collective::SoftmaxEpilogue<
cutlass::detail::TagToStrideC_t<LayoutC>,
cutlass::detail::TagToStrideC_t<LayoutD>,
cutlass::detail::TagToStrideC_t<LayoutTmp>,
cutlass::detail::TagToStrideC_t<LayoutPartials>,
TileShape,
EpilogueOp,
cutlass::gemm::EpilogueDefault>;
Expand Down
7 changes: 3 additions & 4 deletions examples/35_gemm_softmax/gemm_softmax_adapter.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
Expand Down Expand Up @@ -86,15 +85,15 @@ class GemmSoftmaxAdapter

using SoftmaxFinalizeKernel = reduction::kernel::SoftmaxFinalize<
ElementD, typename GemmKernel::StrideD,
ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StrideTmp,
ElementAccumulator, typename GemmKernel::CollectiveEpilogue::StridePartials,
ElementD, typename GemmKernel::StrideD>;

// Map back to 2.x type as best as possible
using LayoutA = gemm::detail::StrideToLayoutTagA_t<typename GemmKernel::StrideA>;
using LayoutB = gemm::detail::StrideToLayoutTagB_t<typename GemmKernel::StrideB>;
using LayoutC = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideC>;
using LayoutD = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;
using LayoutTmp = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;
using LayoutPartials = gemm::detail::StrideToLayoutTagC_t<typename GemmKernel::StrideD>;

static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER;

Expand Down Expand Up @@ -270,7 +269,7 @@ class GemmSoftmaxAdapter
softmax_args.partialN = cute::ceil_div(get<1>(args.problem_shape), cute::shape<1>(TileShape{}));
softmax_args.batch_count = get<3>(args.problem_shape);
softmax_args.dInput = args.epilogue.dD;
softmax_args.dPartial = args.epilogue.dTmp;
softmax_args.dPartial = args.epilogue.dPartials;
softmax_args.dOutput = args.epilogue.dD;
softmax_args.ptr_in = args.epilogue.ptr_D;
softmax_args.ptr_partial_max = args.epilogue.ptr_max;
Expand Down
13 changes: 6 additions & 7 deletions examples/35_gemm_softmax/softmax_epilogue.hpp
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
/***************************************************************************************************
* Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
Expand Down Expand Up @@ -53,7 +52,7 @@ namespace collective {
template <
class StrideC_,
class StrideD_,
class StrideTmp_,
class StridePartials_,
class BlockShapeMNK,
class ThreadEpilogueOp_,
class EpilogueSchedule_
Expand All @@ -76,7 +75,7 @@ class SoftmaxEpilogue {
using StrideC = StrideC_;
using ElementD = typename ThreadEpilogueOp::ElementD;
using StrideD = StrideD_;
using StrideTmp = StrideTmp_;
using StridePartials = StridePartials_;

using GmemTiledCopyC = void;
using GmemTiledCopyD = void;
Expand All @@ -102,7 +101,7 @@ class SoftmaxEpilogue {
StrideD dD{};
ElementAccumulator* ptr_max;
ElementAccumulator* ptr_sum;
StrideTmp dTmp{};
StridePartials dPartials{};
};

// Device side epilogue params
Expand Down Expand Up @@ -187,7 +186,7 @@ class SoftmaxEpilogue {
auto N_tile = get<1>(blk_shape_MNK);
auto K_tile = get<2>(blk_shape_MNK);

auto N_tmp = cute::ceil_div(N, N_tile);
auto N_partials = cute::ceil_div(N, N_tile);

cute::packed_tuple partial_block(M_tile, K_tile);

Expand All @@ -197,8 +196,8 @@ class SoftmaxEpilogue {
// Represent the full output tensors
Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l)
Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l)
Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_tmp,L), params.dTmp);
Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_tmp,L), params.dTmp);
Tensor mMax_mnl = make_tensor(make_gmem_ptr(params.ptr_max), make_shape(M,N_partials,L), params.dPartials);
Tensor mSum_mnl = make_tensor(make_gmem_ptr(params.ptr_sum), make_shape(M,N_partials,L), params.dPartials);
Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l)
Tensor gMax_mnl = local_tile(mMax_mnl, partial_block, make_coord(_,_), Step<_1, X>{});
Expand Down
1 change: 0 additions & 1 deletion include/cutlass/fast_math.h
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
/***************************************************************************************************
* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* Copyright (c) 2024 - 2024 Codeplay Software Ltd. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause
*
* Redistribution and use in source and binary forms, with or without
Expand Down
12 changes: 6 additions & 6 deletions include/cutlass/gpu_generics.h
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,12 @@

////////////////////////////////////////////////////////////////////////////////////////////////////

static constexpr int NumThreadsPerWarp = 32;
static constexpr int NumThreadsPerWarpGroup = 128;
static constexpr int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
static constexpr int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2;
static constexpr int NumThreadsPerQuad = 4;
static constexpr int NumThreadsPerQuadPair = NumThreadsPerQuad * 2;
static const int NumThreadsPerWarp = 32;
static const int NumThreadsPerWarpGroup = 128;
static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp;
static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2;
static const int NumThreadsPerQuad = 4;
static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2;
static constexpr int MaxNumThreadsPerBlock = 1024;

////////////////////////////////////////////////////////////////////////////////////////////////////
Expand Down

0 comments on commit 51cb0b5

Please sign in to comment.