From 1926f4398d5276a40a5101e3fa788bf720c26fbf Mon Sep 17 00:00:00 2001 From: wangyibo1005 <74347676+wangyibo1005@users.noreply.github.com> Date: Thu, 25 Sep 2025 20:05:40 +0800 Subject: [PATCH 1/6] Add files via upload --- act/act.hpp | 38 + act/arch/arch.hpp | 54 + act/arch/cross_core_sync.hpp | 107 ++ act/arch/local_tensor_buffer.hpp | 232 +++++ act/arch/resource.hpp | 42 + act/coord.hpp | 271 +++++ act/detail/alignment.hpp | 51 + act/detail/callback.hpp | 55 + act/detail/dependent_false.hpp | 22 + act/detail/macros.hpp | 20 + act/detail/tag_to_layout.hpp | 90 ++ act/epilogue/block/block_epilogue.hpp | 28 + .../block_epilogue_per_token_dequant.hpp | 867 ++++++++++++++++ act/epilogue/dispatch_policy.hpp | 75 ++ act/epilogue/tile/copy_gm_to_ub.hpp | 174 ++++ act/epilogue/tile/copy_ub_to_gm.hpp | 131 +++ .../tile/tile_broadcast_inplace_by_column.hpp | 71 ++ .../tile/tile_broadcast_inplace_by_row.hpp | 58 ++ act/epilogue/tile/tile_broadcast_mul.hpp | 131 +++ act/epilogue/tile/tile_broadcast_one_blk.hpp | 56 + act/epilogue/tile/tile_cast.hpp | 45 + act/epilogue/tile/tile_copy.hpp | 111 ++ act/epilogue/tile/tile_elemwise_add.hpp | 47 + act/epilogue/tile/tile_elemwise_mul.hpp | 46 + act/epilogue/tile/tile_elemwise_muls.hpp | 38 + act/epilogue/tile/tile_swizzle.hpp | 82 ++ act/gemm/block/block_mmad.hpp | 67 ++ ...block_mmad_preload_async_with_callback.hpp | 455 ++++++++ act/gemm/block/block_swizzle.hpp | 234 +++++ act/gemm/dispatch_policy.hpp | 91 ++ act/gemm/gemm_type.hpp | 30 + act/gemm/helper.hpp | 274 +++++ ...per_token_dequant_multistage_workspace.hpp | 388 +++++++ act/gemm/tile/copy_gm_to_l1.hpp | 856 +++++++++++++++ act/gemm/tile/copy_gm_to_ub.hpp | 64 ++ act/gemm/tile/copy_l0c_to_gm.hpp | 257 +++++ act/gemm/tile/copy_l1_to_l0a.hpp | 438 ++++++++ act/gemm/tile/copy_l1_to_l0b.hpp | 598 +++++++++++ act/gemm/tile/copy_ub_to_gm.hpp | 99 ++ act/gemm/tile/tile_copy.hpp | 214 ++++ act/gemm/tile/tile_mmad.hpp | 114 ++ act/gemm_coord.hpp | 120 +++ act/gemv_coord.hpp | 89 ++ act/layout/layout.hpp | 20 + act/layout/matrix.hpp | 982 ++++++++++++++++++ act/layout/vector.hpp | 108 ++ act/matrix_coord.hpp | 98 ++ 47 files changed, 8538 insertions(+) create mode 100644 act/act.hpp create mode 100644 act/arch/arch.hpp create mode 100644 act/arch/cross_core_sync.hpp create mode 100644 act/arch/local_tensor_buffer.hpp create mode 100644 act/arch/resource.hpp create mode 100644 act/coord.hpp create mode 100644 act/detail/alignment.hpp create mode 100644 act/detail/callback.hpp create mode 100644 act/detail/dependent_false.hpp create mode 100644 act/detail/macros.hpp create mode 100644 act/detail/tag_to_layout.hpp create mode 100644 act/epilogue/block/block_epilogue.hpp create mode 100644 act/epilogue/block/block_epilogue_per_token_dequant.hpp create mode 100644 act/epilogue/dispatch_policy.hpp create mode 100644 act/epilogue/tile/copy_gm_to_ub.hpp create mode 100644 act/epilogue/tile/copy_ub_to_gm.hpp create mode 100644 act/epilogue/tile/tile_broadcast_inplace_by_column.hpp create mode 100644 act/epilogue/tile/tile_broadcast_inplace_by_row.hpp create mode 100644 act/epilogue/tile/tile_broadcast_mul.hpp create mode 100644 act/epilogue/tile/tile_broadcast_one_blk.hpp create mode 100644 act/epilogue/tile/tile_cast.hpp create mode 100644 act/epilogue/tile/tile_copy.hpp create mode 100644 act/epilogue/tile/tile_elemwise_add.hpp create mode 100644 act/epilogue/tile/tile_elemwise_mul.hpp create mode 100644 act/epilogue/tile/tile_elemwise_muls.hpp create mode 100644 act/epilogue/tile/tile_swizzle.hpp create mode 
100644 act/gemm/block/block_mmad.hpp
 create mode 100644 act/gemm/block/block_mmad_preload_async_with_callback.hpp
 create mode 100644 act/gemm/block/block_swizzle.hpp
 create mode 100644 act/gemm/dispatch_policy.hpp
 create mode 100644 act/gemm/gemm_type.hpp
 create mode 100644 act/gemm/helper.hpp
 create mode 100644 act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp
 create mode 100644 act/gemm/tile/copy_gm_to_l1.hpp
 create mode 100644 act/gemm/tile/copy_gm_to_ub.hpp
 create mode 100644 act/gemm/tile/copy_l0c_to_gm.hpp
 create mode 100644 act/gemm/tile/copy_l1_to_l0a.hpp
 create mode 100644 act/gemm/tile/copy_l1_to_l0b.hpp
 create mode 100644 act/gemm/tile/copy_ub_to_gm.hpp
 create mode 100644 act/gemm/tile/tile_copy.hpp
 create mode 100644 act/gemm/tile/tile_mmad.hpp
 create mode 100644 act/gemm_coord.hpp
 create mode 100644 act/gemv_coord.hpp
 create mode 100644 act/layout/layout.hpp
 create mode 100644 act/layout/matrix.hpp
 create mode 100644 act/layout/vector.hpp
 create mode 100644 act/matrix_coord.hpp
diff --git a/act/act.hpp b/act/act.hpp
new file mode 100644
index 00000000..0fc19b54
--- /dev/null
+++ b/act/act.hpp
@@ -0,0 +1,38 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_ACT_HPP
+#define ACT_ACT_HPP
+
+#include <cstdint>
+
+#include "../act/detail/alignment.hpp"
+#include "../act/detail/dependent_false.hpp"
+#include "../act/detail/macros.hpp"
+
+namespace Act {
+
+constexpr uint32_t BYTE_PER_C0 = 32;
+constexpr uint32_t C0_NUM_PER_FRACTAL = 16;
+constexpr uint32_t BYTE_PER_FRACTAL = BYTE_PER_C0 * C0_NUM_PER_FRACTAL;
+
+constexpr uint32_t BYTE_PER_BLK = 32;
+constexpr uint32_t BLK_NUM_PER_VECTOR_FRACTAL = 8;
+constexpr uint32_t BYTE_PER_VECTOR_FRACTAL =
+    BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL;
+
+constexpr uint64_t L2_OFFSET = 0;
+constexpr uint32_t STRIDE_LIMIT = 65536;
+
+} // namespace Act
+
+#endif // ACT_ACT_HPP
diff --git a/act/arch/arch.hpp b/act/arch/arch.hpp
new file mode 100644
index 00000000..bb0a2b4d
--- /dev/null
+++ b/act/arch/arch.hpp
@@ -0,0 +1,54 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_ARCH_ARCH_HPP +#define ACT_ARCH_ARCH_HPP + +namespace Act::Arch { + +struct AtlasA2 { + static constexpr uint32_t BIAS_SIZE = 1024; + static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; + static constexpr uint32_t UB_SIZE = 192 * 1024; + static constexpr uint32_t L1_SIZE = 512 * 1024; + static constexpr uint32_t L0A_SIZE = 64 * 1024; + static constexpr uint32_t L0B_SIZE = 64 * 1024; + static constexpr uint32_t L0C_SIZE = 128 * 1024; +}; + +struct PositionGM { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::GM; +}; + +struct PositionL1 { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A1; +}; + +struct PositionL0A { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A2; +}; + +struct PositionL0B { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::B2; +}; + +struct PositionL0C { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::CO1; +}; + +struct PositionUB { + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::VECCALC; +}; + +} // namespace Act::Arch + +#endif // ACT_ARCH_ARCH_HPP diff --git a/act/arch/cross_core_sync.hpp b/act/arch/cross_core_sync.hpp new file mode 100644 index 00000000..1617304f --- /dev/null +++ b/act/arch/cross_core_sync.hpp @@ -0,0 +1,107 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */
+
+#ifndef ACT_ARCH_CROSS_CORE_SYNC_HPP
+#define ACT_ARCH_CROSS_CORE_SYNC_HPP
+
+#include "../../act/act.hpp"
+
+namespace Act::Arch {
+
+constexpr uint32_t MAX_REVERSE_DEPTH = 16;
+
+using FlagID = uint16_t;
+constexpr FlagID AIV_INTER_BLOCK_BARRIER = 8;
+constexpr FlagID AIC_INTER_BLOCK_BARRIER = 9;
+constexpr FlagID AIV_INTER_SUBBLOCK_BARRIER = 10;
+constexpr FlagID FFTS_MAX_FLAG = 7;
+
+struct CrossCoreFlag {
+  ACT_DEVICE
+  CrossCoreFlag() : id(0) {}
+
+  ACT_DEVICE
+  CrossCoreFlag(FlagID id) : id(id) {}
+
+  FlagID id;
+};
+
+template <uint32_t REVERSE_DEPTH = MAX_REVERSE_DEPTH>
+struct CrossCoreFlagWithReverse {
+  ACT_DEVICE
+  CrossCoreFlagWithReverse() : id(0), reverseId(0) {}
+
+  ACT_DEVICE
+  CrossCoreFlagWithReverse(FlagID id, FlagID reverseId)
+      : id(id), reverseId(reverseId) {}
+
+  FlagID id;
+  FlagID reverseId;
+  uint32_t count{0};
+};
+
+template <uint8_t MODE, typename CoreType> struct BarrierFlag {
+  static_assert(MODE != MODE, "Unsupported cross core barrier flag, cannot "
+                              "find the specialization.");
+};
+
+template <> struct BarrierFlag<0x0, AscendC::AIV> {
+  static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER;
+};
+
+template <> struct BarrierFlag<0x0, AscendC::AIC> {
+  static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER;
+};
+
+template <> struct BarrierFlag<0x1, AscendC::AIV> {
+  static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER;
+};
+
+template <uint8_t MODE, pipe_t PIPE, typename CoreType>
+ACT_DEVICE void CrossCoreBarrier() {
+  constexpr FlagID flagId = BarrierFlag<MODE, CoreType>::ID;
+  AscendC::CrossCoreSetFlag<MODE, PIPE>(flagId);
+  AscendC::CrossCoreWaitFlag(flagId);
+}
+
+template <uint8_t MODE, pipe_t PIPE>
+ACT_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) {
+  AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.id);
+}
+
+ACT_DEVICE
+void CrossCoreWaitFlag(CrossCoreFlag &flag) {
+  AscendC::CrossCoreWaitFlag(flag.id);
+}
+
+template <uint8_t MODE, pipe_t PIPE, uint32_t REVERSE_DEPTH>
+ACT_DEVICE void
+CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse<REVERSE_DEPTH> &flag) {
+  AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.id);
+  if (++flag.count >= REVERSE_DEPTH) {
+    AscendC::CrossCoreWaitFlag(flag.reverseId);
+    flag.count = 0;
+  }
+}
+
+template <uint8_t MODE, pipe_t PIPE, uint32_t REVERSE_DEPTH>
+ACT_DEVICE void
+CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse<REVERSE_DEPTH> &flag) {
+  AscendC::CrossCoreWaitFlag(flag.id);
+  if (++flag.count >= REVERSE_DEPTH) {
+    AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.reverseId);
+    flag.count = 0;
+  }
+}
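+
+// Illustrative usage sketch (not part of the original patch; the mode value
+// 0x2 and the pipe PIPE_MTE3 are assumptions): a producer signals a consumer
+// once per block, and after every REVERSE_DEPTH signals it waits for a
+// reverse acknowledgement so the hardware flag queue cannot overflow.
+//
+//   CrossCoreFlagWithReverse<> flag{0, 1};
+//   for (uint32_t i = 0; i < loops; ++i) {
+//     // ... produce one block ...
+//     CrossCoreSetFlagWithReverse<0x2, PIPE_MTE3>(flag);
+//   }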
+
+} // namespace Act::Arch
+
+#endif // ACT_ARCH_CROSS_CORE_SYNC_HPP
diff --git a/act/arch/local_tensor_buffer.hpp b/act/arch/local_tensor_buffer.hpp
new file mode 100644
index 00000000..c94841a4
--- /dev/null
+++ b/act/arch/local_tensor_buffer.hpp
@@ -0,0 +1,232 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef INCLUDE_ACT_ARCH_MEMORY_H
+#define INCLUDE_ACT_ARCH_MEMORY_H
+
+#include "../../act/act.hpp"
+#include "../../act/arch/arch.hpp"
+
+namespace Act::Arch {
+
+struct LocalTensorBufferBase {
+public:
+  template <class T>
+  ACT_DEVICE AscendC::LocalTensor<T>
+  GetBufferByByte(const uint32_t offset) const {
+    return tensor[offset].template ReinterpretCast<T>();
+  }
+
+protected:
+  ACT_DEVICE
+  LocalTensorBufferBase() = default;
+
+  AscendC::LocalTensor<uint8_t> tensor;
+};
+
+template <class ArchTag, AscendC::TPosition Position>
+struct LocalTensorBuffer {
+  static_assert(
+      DEPENDENT_FALSE<ArchTag>,
+      "Unsupported local tensor buffer, cannot find the specialization.");
+};
+
+/// Partial specialization for TPosition::A1
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::A1>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::A1;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::A1> tbufA1;
+    GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE);
+    tensor = tbufA1.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::A2
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::A2>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::A2;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::A2> tbufA2;
+    GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE);
+    tensor = tbufA2.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::B1
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::B1>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::B1;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::B1> tbufB1;
+    GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE);
+    tensor = tbufB1.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::B2
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::B2>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::B2;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::B2> tbufB2;
+    GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE);
+    tensor = tbufB2.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for AtlasA2, TPosition::C1
+template <>
+struct LocalTensorBuffer<Arch::AtlasA2, AscendC::TPosition::C1>
+    : LocalTensorBufferBase {
+public:
+  using ArchTag = Arch::AtlasA2;
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::C1;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::C1> tbufC1;
+    GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE);
+    tensor = tbufC1.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for AtlasA2, TPosition::C2
+template <>
+struct LocalTensorBuffer<Arch::AtlasA2, AscendC::TPosition::C2>
+    : LocalTensorBufferBase {
+public:
+  using ArchTag = Arch::AtlasA2;
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::C2;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::C2> tbufC2;
+    GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE);
+    tensor = tbufC2.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::CO1
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::CO1>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::CO1> tbufCO1;
+    GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE);
+    tensor = tbufCO1.Get<uint8_t>();
+  }
+};
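+
+// Illustrative usage sketch (not part of the original patch; the byte offsets
+// are assumptions): each specialization maps one whole hardware buffer, and
+// callers carve typed tiles out of it by byte offset.
+//
+//   LocalTensorBuffer<AtlasA2, AscendC::TPosition::CO1> l0cBuf;
+//   auto accTile = l0cBuf.GetBufferByByte<float>(0);
+//   auto tmpTile = l0cBuf.GetBufferByByte<float>(64 * 1024);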
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for AtlasA2, TPosition::C2PIPE2GM
+template <>
+struct LocalTensorBuffer<Arch::AtlasA2, AscendC::TPosition::C2PIPE2GM>
+    : LocalTensorBufferBase {
+public:
+  using ArchTag = Arch::AtlasA2;
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::C2PIPE2GM> tbufC2PIPE2GM;
+    GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE);
+    tensor = tbufC2PIPE2GM.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::VECIN
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::VECIN>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::VECIN> tbufVECIN;
+    GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE);
+    tensor = tbufVECIN.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::VECOUT
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::VECOUT>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::VECOUT> tbufVECOUT;
+    GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE);
+    tensor = tbufVECOUT.Get<uint8_t>();
+  }
+};
+
+///////////////////////////////////////////////////////////
+
+/// Partial specialization for TPosition::VECCALC
+template <class ArchTag>
+struct LocalTensorBuffer<ArchTag, AscendC::TPosition::VECCALC>
+    : LocalTensorBufferBase {
+public:
+  static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC;
+
+  ACT_DEVICE
+  LocalTensorBuffer() {
+    AscendC::TBuf<AscendC::TPosition::VECCALC> tbufVECCALC;
+    GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE);
+    tensor = tbufVECCALC.Get<uint8_t>();
+  }
+};
+
+} // namespace Act::Arch
+
+#endif // INCLUDE_ACT_ARCH_MEMORY_H
diff --git a/act/arch/resource.hpp b/act/arch/resource.hpp
new file mode 100644
index 00000000..d5c8531b
--- /dev/null
+++ b/act/arch/resource.hpp
@@ -0,0 +1,42 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef INCLUDE_ACT_ARCH_RESOURCE_HPP
+#define INCLUDE_ACT_ARCH_RESOURCE_HPP
+
+#include "../../act/act.hpp"
+#include "../../act/arch/local_tensor_buffer.hpp"
+
+namespace Act::Arch {
+
+template <class ArchTag> struct Resource {
+public:
+  AscendC::TPipe pipe;
+
+  LocalTensorBuffer<ArchTag, AscendC::TPosition::A1> l1Buf;
+  LocalTensorBuffer<ArchTag, AscendC::TPosition::A2> l0ABuf;
+  LocalTensorBuffer<ArchTag, AscendC::TPosition::B2> l0BBuf;
+  LocalTensorBuffer<ArchTag, AscendC::TPosition::CO1> l0CBuf;
+  LocalTensorBuffer<ArchTag, AscendC::TPosition::VECCALC> ubBuf;
+
+  ACT_DEVICE
+  Resource() {
+    // Constructing AscendC::TPipe inserts synchronization primitives that can
+    // conflict with synchronization managed by the user, so the Destroy()
+    // interface is called here to release them.
+    pipe.Destroy();
+  }
+};
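+
+// Illustrative usage sketch (not part of the original patch): a kernel
+// typically constructs one Resource per core and passes it down, so that
+// block-level components share the same physical buffers.
+//
+//   Act::Arch::Resource<Act::Arch::AtlasA2> resource;
+//   auto ubTile = resource.ubBuf.template GetBufferByByte<half>(0);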
+
+} // namespace Act::Arch
+
+#endif // INCLUDE_ACT_ARCH_RESOURCE_HPP
diff --git a/act/coord.hpp b/act/coord.hpp
new file mode 100644
index 00000000..f2e065e8
--- /dev/null
+++ b/act/coord.hpp
@@ -0,0 +1,271 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_COORD_HPP
+#define ACT_COORD_HPP
+
+#include "../act/act.hpp"
+
+namespace Act {
+
+/// Statically-sized array specifying Coords within a tensor
+template <int RANK_, class Index_ = int, class LongIndex_ = int64_t>
+struct Coord {
+public:
+  // Number of elements in Coord
+  static const int RANK = RANK_;
+
+  // Index type used to store elements
+  using Index = Index_;
+
+  // Type used to represent linear offsets
+  using LongIndex = LongIndex_;
+
+  // Default ctor initializes uniformly
+  ACT_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) {
+    for (int i = 0; i < RANK; ++i) {
+      idx[i] = value;
+    }
+  }
+
+  // Constructs from an array of integers
+  ACT_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) {
+    for (int i = 0; i < RANK; ++i) {
+      idx[i] = idx_[i];
+    }
+  }
+
+  // Returns the index of the dimension with least value
+  ACT_HOST_DEVICE
+  int Argmin() const {
+    int i = 0;
+    for (int j = 1; j < RANK; ++j) {
+      if (idx[j] < idx[i]) {
+        i = j;
+      }
+    }
+    return i;
+  }
+
+  // Returns the index of the dimension with greatest value
+  ACT_HOST_DEVICE
+  int Argmax() const {
+    int i = 0;
+    for (int j = 1; j < RANK; ++j) {
+      if (idx[j] > idx[i]) {
+        i = j;
+      }
+    }
+    return i;
+  }
+
+  // Returns true if Coord is non-zero
+  ACT_HOST_DEVICE
+  explicit operator bool() const {
+    for (int i = 0; i < RANK; ++i) {
+      if (idx[i]) {
+        return true;
+      }
+    }
+    return false;
+  }
+
+  // Returns true if Coord is uniformly zero
+  ACT_HOST_DEVICE
+  bool operator!() const {
+    for (int i = 0; i < RANK; ++i) {
+      if (idx[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Element-wise addition
+  ACT_HOST_DEVICE
+  Coord operator+(Coord const &b) const {
+    Coord c;
+    for (int i = 0; i < RANK; ++i) {
+      c.idx[i] = idx[i] + b.idx[i];
+    }
+    return c;
+  }
+
+  // Add a scalar to each element
+  ACT_HOST_DEVICE
+  Coord operator+(const Index val) const {
+    Coord c;
+    for (int i = 0; i < RANK; ++i) {
+      c.idx[i] = idx[i] + val;
+    }
+    return c;
+  }
+
+  // Element-wise subtraction
+  ACT_HOST_DEVICE
+  Coord operator-(Coord const &b) const {
+    Coord c;
+    for (int i = 0; i < RANK; i++) {
+      c.idx[i] = idx[i] - b.idx[i];
+    }
+    return c;
+  }
+
+  // Subtract a scalar from each element
+  ACT_HOST_DEVICE
+  Coord operator-(Index const val) const {
+    Coord c;
+    for (int i = 0; i < RANK; ++i) {
+      c.idx[i] = idx[i] - val;
+    }
+    return c;
+  }
+
+  // Element-wise multiply
+  ACT_HOST_DEVICE
+  Coord operator*(Coord const &b) const {
+    Coord c;
+    for (int i = 0; i < RANK; i++) {
+      c.idx[i] = idx[i] * b.idx[i];
+    }
+    return c;
+  }
+
+  // Element-wise division
+  ACT_HOST_DEVICE
+  Coord operator/(Coord const &b) const {
+    Coord c;
+    for (int i = 0; i < RANK; i++) {
+      c.idx[i] = idx[i] / b.idx[i];
+    }
+    return c;
+  }
+
+  // Element-wise mod
+  ACT_HOST_DEVICE
+  Coord operator%(Coord const &b) const {
+    Coord c;
+    for (int i = 0; i < RANK; i++) {
+      c.idx[i] = idx[i] % b.idx[i];
+    }
+    return c;
+  }
+
+  // In-place addition
+  ACT_HOST_DEVICE
+  Coord &operator+=(Coord const &b) {
+    for (int i = 0; i < RANK; ++i) {
+      idx[i] += b.idx[i];
+    }
+    return *this;
+  }
+
+  // Equality comparison
+  ACT_HOST_DEVICE
+  bool operator==(Coord const &b) const {
+    for (int i = 0; i < RANK; ++i) {
+      if (idx[i] != b.idx[i]) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Equality comparison against a scalar
+  ACT_HOST_DEVICE
+  bool operator==(Index const val) const {
+    for (int i = 0; i < RANK; ++i) {
+      if (idx[i] != val) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  // Member access operator
+  ACT_HOST_DEVICE
+  Index &operator[](int dim) { return idx[dim]; }
+
+  // Member access operator
+  ACT_HOST_DEVICE
+  Index const &operator[](int dim) const { return idx[dim]; }
+
+  // Gets the index of a given Coord element
+  template <int DIM> ACT_HOST_DEVICE Index &At() { return idx[DIM]; }
+
+  // Access via index; may limit unrolling potential
+  ACT_HOST_DEVICE
+  Index &At(int dim) { return idx[dim]; }
+
+  // Gets the index of a given Coord element
+  template <int DIM> ACT_HOST_DEVICE Index const &At() const {
+    return idx[DIM];
+  }
+
+  // Access via index; may limit unrolling potential
+  ACT_HOST_DEVICE
+  Index const &At(int dim) const { return idx[dim]; }
+
+  template <int... Is> ACT_HOST_DEVICE auto GetCoordByAxis() const {
+    Index idx_[sizeof...(Is)]{idx[Is]...};
+    return Coord<sizeof...(Is), Index>{idx_};
+  }
+
+  ACT_HOST_DEVICE
+  static Coord Min(Coord const &a, Coord const &b) {
+    Coord res;
+    for (int i = 0; i < RANK; ++i) {
+      res[i] = a[i] < b[i] ? a[i] : b[i];
+    }
+    return res;
+  }
+
+private:
+  // Indices
+  Index idx[RANK];
+};
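+
+// Illustrative usage sketch (not part of the original patch), using the
+// MakeCoord helpers defined below:
+//
+//   auto a = MakeCoord(2, 3);             // Coord<2, int>
+//   auto b = MakeCoord(4, 5);
+//   auto c = a * b;                       // {8, 15}
+//   auto col = c.GetCoordByAxis<1>();     // Coord<1, int>{15}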
+
+/// Helper to make a 1-element coordinate
+template <typename T> ACT_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) {
+  T values[1] = {dim0};
+  return Coord<1, T>(values);
+}
+
+/// Helper to make a 2-element coordinate
+template <typename T>
+ACT_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) {
+  T values[2] = {dim0, dim1};
+  return Coord<2, T>(values);
+}
+
+/// Helper to make a 3-element coordinate
+template <typename T>
+ACT_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) {
+  T values[3] = {dim0, dim1, dim2};
+  return Coord<3, T>(values);
+}
+
+/// Helper to make a 4-element coordinate
+template <typename T>
+ACT_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2,
+                                                T dim3) {
+  T values[4] = {dim0, dim1, dim2, dim3};
+  return Coord<4, T>(values);
+}
+
+} // namespace Act
+
+#endif // ACT_COORD_HPP
diff --git a/act/detail/alignment.hpp b/act/detail/alignment.hpp
new file mode 100644
index 00000000..fe9e3e1e
--- /dev/null
+++ b/act/detail/alignment.hpp
@@ -0,0 +1,51 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_ALIGNMENT_HPP
+#define ACT_ALIGNMENT_HPP
+
+#include "../../act/detail/macros.hpp"
+
+template <uint32_t ALIGN, class T>
+ACT_HOST_DEVICE constexpr T RoundUp(const T &val) {
+  static_assert(ALIGN != 0, "ALIGN must not be 0");
+  return (val + ALIGN - 1) / ALIGN * ALIGN;
+}
+
+template <class T>
+ACT_HOST_DEVICE constexpr T RoundUp(const T &val, const T align) {
+  return (val + align - 1) / align * align;
+}
+
+template <uint32_t ALIGN, class T>
+ACT_HOST_DEVICE constexpr T RoundDown(const T val) {
+  static_assert(ALIGN != 0, "ALIGN must not be 0");
+  return val / ALIGN * ALIGN;
+}
+
+template <class T>
+ACT_HOST_DEVICE constexpr T RoundDown(const T val, const T align) {
+  return val / align * align;
+}
+
+template <uint32_t DIVISOR, class T>
+ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend) {
+  static_assert(DIVISOR != 0, "DIVISOR must not be 0");
+  return (dividend + DIVISOR - 1) / DIVISOR;
+}
+
+template <class T>
+ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend, const T divisor) {
+  return (dividend + divisor - 1) / divisor;
+}
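+
+// Worked examples (illustrative, not part of the original patch):
+//   RoundUp<32>(70u)  == 96    RoundDown<32>(70u) == 64
+//   RoundUp(70u, 32u) == 96    CeilDiv<32>(70u)   == 3
+//   CeilDiv(70u, 32u) == 3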
+
+#endif // ACT_ALIGNMENT_HPP
diff --git a/act/detail/callback.hpp b/act/detail/callback.hpp
new file mode 100644
index 00000000..5c47c6f8
--- /dev/null
+++ b/act/detail/callback.hpp
@@ -0,0 +1,55 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_DETAIL_CALLBACK_HPP
+#define ACT_DETAIL_CALLBACK_HPP
+
+#include "../../act/detail/macros.hpp"
+
+/// @brief Callback is an alternative to std::function: a general-purpose
+/// carrier for a callable with no parameters and no return value. Unlike a
+/// plain function pointer of type void (*)(), a Callback can carry a lambda
+/// expression with captures, without needing to know what is captured. Note
+/// that, unlike std::function, a Callback does not store the callable it
+/// carries, so it must only be used within the lifetime of that callable.
+struct Callback {
+  void const *func{nullptr};
+  void (*caller)(void const *){nullptr};
+
+  Callback() = default;
+
+  ACT_DEVICE
+  void operator()() const {
+    if (func) {
+      caller(func);
+    }
+  }
+
+  ACT_DEVICE
+  operator bool() const { return func != nullptr; }
+};
+
+template <class Func> ACT_DEVICE static void FuncWrapper(void const *func) {
+  (*static_cast<Func const *>(func))();
+}
+
+// Use this to make a callback
+template <class Func> ACT_DEVICE Callback MakeCallback(Func *func) {
+  Callback callback;
+  callback.func = func;
+  callback.caller = &FuncWrapper<Func>;
+  return callback;
+}
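+
+// Illustrative usage sketch (not part of the original patch): the callable
+// must outlive the Callback, which only points at it.
+//
+//   auto onCopyDone = [&]() { /* e.g. release a cross-core flag */ };
+//   Callback cb = MakeCallback(&onCopyDone);
+//   if (cb) {
+//     cb();  // invokes the lambda
+//   }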
+
+#endif // ACT_DETAIL_CALLBACK_HPP
diff --git a/act/detail/dependent_false.hpp b/act/detail/dependent_false.hpp
new file mode 100644
index 00000000..9a76dd52
--- /dev/null
+++ b/act/detail/dependent_false.hpp
@@ -0,0 +1,22 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_DETAIL_DEPENDENT_FALSE_HPP
+#define ACT_DETAIL_DEPENDENT_FALSE_HPP
+
+template <bool VALUE, class... Args>
+constexpr bool DEPENDENT_BOOL_VALUE = VALUE;
+
+template <class... Args>
+constexpr bool DEPENDENT_FALSE = DEPENDENT_BOOL_VALUE<false, Args...>;
+
+#endif // ACT_DETAIL_DEPENDENT_FALSE_HPP
diff --git a/act/detail/macros.hpp b/act/detail/macros.hpp
new file mode 100644
index 00000000..fa31d68a
--- /dev/null
+++ b/act/detail/macros.hpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_DETAIL_MACROS_HPP
+#define ACT_DETAIL_MACROS_HPP
+
+#define ACT_DEVICE __forceinline__[aicore]
+#define ACT_HOST_DEVICE __forceinline__[host, aicore]
+#define ACT_GLOBAL __global__[aicore]
+
+#endif // ACT_DETAIL_MACROS_HPP
diff --git a/act/detail/tag_to_layout.hpp b/act/detail/tag_to_layout.hpp
new file mode 100644
index 00000000..ec649e4e
--- /dev/null
+++ b/act/detail/tag_to_layout.hpp
@@ -0,0 +1,90 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_DETAIL_TAG_TO_LAYOUT_HPP
+#define ACT_DETAIL_TAG_TO_LAYOUT_HPP
+
+#include "../../act/layout/layout.hpp"
+#include "../../tla/layout.hpp"
+
+using namespace tla;
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace Act::detail {
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// For each Act::layout, provides its corresponding tla layout types
+template struct TagToLayout {
+  using type = LayoutTag;
+};
+
+template struct TagToLayout {
+  using type = Layout, Stride>,
+                      Shape>;
+};
+
+template struct TagToLayout {
+  using type = Layout, Stride, int64_t>,
+                      Shape>;
+};
+
+template struct TagToLayout {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t ELE_NUM_PER_FRACTAL =
+      BYTE_PER_FRACTAL / sizeof(Element);
+  using type =
+      Layout, uint32_t>,
+             Shape, uint32_t>>,
+             Stride, Int>,
+             Stride, int64_t>>,
+             Shape>;
+};
+
+template struct TagToLayout {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t ELE_NUM_PER_FRACTAL =
+      BYTE_PER_FRACTAL / sizeof(Element);
+  using type = Layout, uint32_t>,
+                      Shape, uint32_t>>,
+                      Stride, int64_t>,
+                      Stride, Int>>,
+                      Shape>;
+};
+
+template struct TagToLayout {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t ELE_NUM_PER_FRACTAL =
+      BYTE_PER_FRACTAL / sizeof(Element);
+  using type =
+      Layout, uint32_t>,
+             Shape, uint32_t>>,
+             Stride, int64_t>,
+             Stride, Int>>,
+             Shape>;
+};
+
+// Convenience aliases
+template
+using TagToLayout_t = typename TagToLayout::type;
+
+constexpr uint32_t ELE_NUM_PER_FRACTAL_L0C = 256;
+using LayoutL0C =
+    Layout, uint32_t>,
+           Shape, uint32_t>>,
+           Stride, Int>,
+           Stride, int64_t>>,
+           Shape>;
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace Act::detail
+
+#endif // ACT_DETAIL_TAG_TO_LAYOUT_HPP
diff --git a/act/epilogue/block/block_epilogue.hpp b/act/epilogue/block/block_epilogue.hpp
new file mode 100644
index 00000000..f7057680
--- /dev/null
+++ b/act/epilogue/block/block_epilogue.hpp
@@ -0,0 +1,28 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details.
You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP +#define ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Block { + +template class BlockEpilogue { + static_assert(DEPENDENT_FALSE, + "Could not find an epilogue specialization"); +}; + +} // namespace Act::Epilogue::Block + +#include "../../../act/epilogue/block/block_epilogue_per_token_dequant.hpp" +#endif // ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP diff --git a/act/epilogue/block/block_epilogue_per_token_dequant.hpp b/act/epilogue/block/block_epilogue_per_token_dequant.hpp new file mode 100644 index 00000000..ac21c634 --- /dev/null +++ b/act/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -0,0 +1,867 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP +#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP + +#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" +#include "../../../act/act.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/epilogue/dispatch_policy.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../act/matrix_coord.hpp" + +#define ENABLE_EP_SEND_COUNT_HASH 0 + +namespace Act::Epilogue::Block { + +template +class BlockEpilogue< + EpilogueAtlasA2PerTokenDequant, CType_, ScaleType_, + PerTokenScaleType_, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = typename ScaleType_::Element; + using LayoutScale = typename ScaleType_::Layout; + using ElementPerTokenScale = typename PerTokenScaleType_::Element; + using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + std::is_same_v && + (std::is_same_v || + std::is_same_v) && + std::is_same_v && + std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using 
TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert( + TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + + TileShape::ROW) * + sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), ptrD(ptrD_), + layoutD(layoutD_) {} + }; + + ACT_DEVICE + BlockEpilogue(Arch::Resource const &resource, + Params const ¶ms = Params{}) + : params(params) { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = + resource.ubBuf.template GetBufferByByte( + ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag( + eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleFp32 = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += 
TileShape::ROW * sizeof(float); + ubPerTokenScaleFp32Brcb = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubMul; + } + + ACT_DEVICE + ~BlockEpilogue() { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag( + eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + ACT_DEVICE + void UpdateParams(Params const ¶ms_) { params = params_; } + + ACT_DEVICE + void operator()(GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; + loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag( + eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = + LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag( + eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag( + eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = + actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = + gmPerTokenScale[params.layoutPerTokenScale.GetOffset( + perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = + params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb( + 
perTokenScaleTileShape); + + AscendC::WaitFlag( + eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, + layoutUbPerTokenScale, layoutGmTilePerTokenScale); + AscendC::SetFlag( + eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag( + eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, + TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag( + eventUbScaleMTE2VList[ubListId]); + AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, + TileShape::COLUMN); + AscendC::SetFlag( + eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag( + eventUbPerTokenScaleMTE2VList[ubListId]); + AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, + AscendC::RoundMode::CAST_NONE, TileShape::ROW); + AscendC::SetFlag( + eventUbPerTokenScaleVMTE2List[ubListId]); + + tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); + tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, + ubPerTokenScaleFp32Brcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag( + eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, + TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag( + eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } + } + +private: + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubScaleFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleFp32; + AscendC::LocalTensor ubPerTokenScaleFp32Brcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; +}; + +template +class BlockEpilogue, + CType_, Gemm::GemmType, + Gemm::GemmType, DType_, + TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> { +public: + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = 
float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert( + std::is_same_v && + (std::is_same_v || + std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && + std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert( + TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, + LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), ptrD(ptrD_), + layoutD(layoutD_) {} + }; + + ACT_DEVICE void AlignUbOffset() { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + ACT_DEVICE + BlockEpilogue(Arch::Resource &resource, + MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = + resource.ubBuf.template GetBufferByByte( + ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + 
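+
+      // Note (illustrative, not part of the original patch): the SetFlag
+      // calls below pre-arm each V->MTE2 and MTE3->V event so that the first
+      // WaitFlag issued in operator() sees an already-set flag and the first
+      // tile does not deadlock; the destructor later consumes these flags.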
+ AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag( + eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleBrcb = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubCFp32; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = + resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isSharedExpert_ + ? calcInfo.epWorldSize_ + : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = { + 1U, static_cast(epSendCountSize * sizeof(uint32_t)), 0U, 0U, + 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, + copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); +#if ENABLE_EP_SEND_COUNT_HASH + tokenToEpRankHashLocal_ = + resource.ubBuf.template GetBufferByByte(ubOffset); + uint32_t maxGroupSendCount = 0; + uint32_t groupSendCount = 0; + for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; + ++expertIdx) { + uint32_t prevGroupSendCount = groupSendCount; + groupSendCount = epSendCountLocal_.GetValue( + (expertIdx + 1) * calcInfo.epWorldSize_ - 1); + if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { + maxGroupSendCount = groupSendCount - prevGroupSendCount; + } + } + ubOffset += maxGroupSendCount * sizeof(int32_t); + AlignUbOffset(); + // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or + // AscendC::TOTAL_VEC_LOCAL_SIZE +#endif + } + } + + ACT_DEVICE + ~BlockEpilogue() { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag( + eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } + + ACT_DEVICE + void UpdateParams(Params const ¶ms_) { params = params_; } + + ACT_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, + const uint8_t expertLocalId = 0U) { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? 
calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_ + ->remoteRes[rankId] + .nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + + expertLocalId * calcInfo.expertPerSizeOnWin_ + + rankId * OPT_RANK_OFFSET; + } +#if ENABLE_EP_SEND_COUNT_HASH + ACT_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, + uint32_t epRank, + uint32_t copyLen) { + constexpr uint32_t DUPLICATE_MASK_COUNT = 8; + uint32_t hashOffsetMask = + (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); + if (hashOffsetMask != 0) { + uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; + if (copyLen < remainMaskCount) { + remainMaskCount = copyLen; + } + uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; + AscendC::Duplicate( + tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, + ©Mask, 1, 1, DUPLICATE_MASK_COUNT); + hashOffset += remainMaskCount; + copyLen -= remainMaskCount; + } + if (copyLen > 0) { + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, + copyLen); + hashOffset += copyLen; + } + } +#endif + + ACT_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, + uint32_t &localEpRank) { + if ((calcInfo.isSharedExpert_) && + (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } + } + + ACT_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, + layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, + uint32_t expertIdx, uint32_t tileOffsetD) { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / + (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = + (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); +#if ENABLE_EP_SEND_COUNT_HASH + uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); +#else + constexpr uint32_t epRankStart = 0; +#endif + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 + ? 0 + : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; + epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = + (sendCount < endToken ? 
sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, + copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * + calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad( + rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } + } + + ACT_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, + GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) { + if (actualBlockShapeMNK.k() == 0) { + return; + } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; +#if ENABLE_EP_SEND_COUNT_HASH + if (currentExpertIdx_ != expertIdx) { + uint32_t hashOffset = 0; + uint32_t sendCount = + expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); + for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, + sendCount - prevSendCount); + } + AscendC::SetFlag(eventVS); + AscendC::WaitFlag(eventVS); + currentExpertIdx_ = expertIdx; + } +#endif + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; + loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag( + eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto 
layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = + LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag( + eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag( + eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = + actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = + gmPerTokenScale[params.layoutPerTokenScale.GetOffset( + perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = + params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb( + perTokenScaleTileShape); + + AscendC::WaitFlag( + eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, + layoutUbPerTokenScale, layoutGmTilePerTokenScale); + AscendC::SetFlag( + eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag( + eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, + TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag( + eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + AscendC::SetFlag( + eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag( + eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); + AscendC::SetFlag( + eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag( + eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, + TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag( + eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, + tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0;
+    }
+  }
+
+private:
+  Params params;
+  Arch::Resource<ArchTag> &resource;
+  MoeDistributeCombineImpl::CombineCalcInfo calcInfo;
+
+  AscendC::LocalTensor<ElementC> ubCList[UB_STAGES];
+  AscendC::LocalTensor<ElementScale> ubScaleList[UB_STAGES];
+  AscendC::LocalTensor<ElementPerTokenScale> ubPerTokenScaleList[UB_STAGES];
+  AscendC::LocalTensor<ElementD> ubDList[UB_STAGES];
+
+  int32_t eventUbCVMTE2List[UB_STAGES];
+  int32_t eventUbCMTE2VList[UB_STAGES];
+  int32_t eventUbScaleVMTE2List[UB_STAGES];
+  int32_t eventUbScaleMTE2VList[UB_STAGES];
+  int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES];
+  int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES];
+  int32_t eventUbDMTE3VList[UB_STAGES];
+  int32_t eventUbDVMTE3List[UB_STAGES];
+
+  AscendC::LocalTensor<int32_t> epSendCountLocal_;
+#if ENABLE_EP_SEND_COUNT_HASH
+  AscendC::LocalTensor<uint32_t> tokenToEpRankHashLocal_;
+  uint32_t currentExpertIdx_{static_cast<uint32_t>(-1)};
+#endif
+
+  size_t ubOffset{0};
+  int32_t eventVMTE2{0};
+  int32_t eventMTE2V{0};
+  int32_t eventMTE3V{0};
+  int32_t eventVMTE3{0};
+  int32_t eventVS{0};
+  int32_t eventMTE2S{0};
+
+  uint32_t expertOffset;
+
+  uint32_t ubListId{0};
+
+  AscendC::LocalTensor<float> ubCFp32;
+  AscendC::LocalTensor<float> ubMul;
+  AscendC::LocalTensor<float> ubPerTokenScaleBrcb;
+  AscendC::LocalTensor<float> ubPerTokenMul;
+
+  TileRowBroadcastMul tileRowBroadcastMul;
+  TileBroadcastOneBlk tileBroadcastOneBlk;
+  TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul;
+
+  CopyGmToUbC copyGmToUbC;
+  CopyGmToUbScale copyGmToUbScale;
+  CopyGmToUbPerTokenScale copyGmToUbPerTokenScale;
+  CopyUbToGmD copyUbToGmD;
+};
+
+} // namespace Act::Epilogue::Block
+
+#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP
diff --git a/act/epilogue/dispatch_policy.hpp b/act/epilogue/dispatch_policy.hpp
new file mode 100644
index 00000000..0323c274
--- /dev/null
+++ b/act/epilogue/dispatch_policy.hpp
@@ -0,0 +1,75 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_EPILOGUE_DISPATCH_POLICY_HPP
+#define ACT_EPILOGUE_DISPATCH_POLICY_HPP
+
+#include "../../act/arch/arch.hpp"
+
+namespace Act::Epilogue {
+
+// For AtlasA2, an element-wise epilogue of the form D = C + X, where X is an
+// additional source
+struct EpilogueAtlasA2ElemWiseOneSource {
+  using ArchTag = Arch::AtlasA2;
+  // Number of operands: C, X, and D (3 operands in total)
+  static constexpr uint32_t OPERANDS_NUM = 3;
+};
+
+// For AtlasA2, FA Softmax
+struct EpilogueAtlasA2FASoftmax {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, FA RescaleO
+struct EpilogueAtlasA2FARescaleO {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, MLA Softmax
+struct EpilogueAtlasA2MLASoftmax {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, MLA RescaleO
+struct EpilogueAtlasA2MLARescaleO {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, MLA FD RescaleO
+template <uint32_t COMPUTE_ELE_NUM_> struct EpilogueAtlasA2MLAFDRescaleO {
+  using ArchTag = Arch::AtlasA2;
+  static constexpr uint32_t KV_SPLIT_MAX = 64;
+  static constexpr uint32_t HEADS_PROCESS_MAX = 16;
+  static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_;
+};
+
+// For AtlasA2, MLA TP1 Softmax
+struct EpilogueAtlasA2MLATP1Softmax {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, MLA TP1 RescaleO
+struct EpilogueAtlasA2MLATP1RescaleO {
+  using ArchTag = Arch::AtlasA2;
+};
+
+// For AtlasA2, per token dequant
+template <uint32_t UB_STAGES_, uint32_t EXEC_FLAG_>
+struct EpilogueAtlasA2PerTokenDequant {
+  using ArchTag = Arch::AtlasA2;
+  static constexpr uint32_t UB_STAGES = UB_STAGES_;
+  static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_;
+};
+} // namespace Act::Epilogue
+
+#endif // ACT_EPILOGUE_DISPATCH_POLICY_HPP
diff --git a/act/epilogue/tile/copy_gm_to_ub.hpp b/act/epilogue/tile/copy_gm_to_ub.hpp
new file mode 100644
index 00000000..ede41844
--- /dev/null
+++ b/act/epilogue/tile/copy_gm_to_ub.hpp
@@ -0,0 +1,174 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP
+#define ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/gemm/gemm_type.hpp"
+#include "../../../act/layout/layout.hpp"
+
+namespace Act::Epilogue::Tile {
+
+template <class ArchTag, class GmType> struct CopyGm2Ub {
+  static_assert(
+      DEPENDENT_FALSE<ArchTag>,
+      "Unsupported copy gm to ub, can not find the specialization.");
+};
+
+template <class ArchTag, class Element>
+struct CopyGm2Ub<ArchTag, Gemm::GemmType<Element, layout::RowMajor>> {
+  using LayoutSrc = layout::RowMajor;
+  using LayoutDst = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+
+  ACT_DEVICE
+  CopyGm2Ub() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor<Element> const &dstTensor,
+                  AscendC::GlobalTensor<Element> const &srcTensor,
+                  layout::RowMajor const &layoutDst,
+                  layout::RowMajor const &layoutSrc) {
+    AscendC::DataCopyExtParams dataCopyParams(
+        layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element),
+        (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element),
+        (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0);
+    AscendC::DataCopyPadExtParams<Element> padParams(false, 0, 0, 0);
+    AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams);
+  };
+};
+
+template <class ArchTag, class Element>
+struct CopyGm2Ub<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
+  using LayoutSrc = layout::VectorLayout;
+  using LayoutDst = layout::VectorLayout;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+
+  ACT_DEVICE
+  CopyGm2Ub() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor<Element> const &dstTensor,
+                  AscendC::GlobalTensor<Element> const &srcTensor,
+                  layout::VectorLayout const &layoutDst,
+                  layout::VectorLayout const &layoutSrc) {
+    AscendC::DataCopyExtParams dataCopyParams(
+        1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0);
+    AscendC::DataCopyPadExtParams<Element> padParams(false, 0, 0, 0);
+    AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams);
+  };
+};
+
+/// @brief This copy instruction is used to copy the per-token scale from GM
+/// to UB. It copies the scale of shape (m, 1) on GM to the first column of a
+/// tile of shape (m, n) on UB, padding the first block of each row (i.e.
+/// padding to shape (m, 8) when the element type is float).
+/// @tparam ArchTag: Architecture tag.
+/// @tparam GmType: Type of data on GM.
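+///
+/// A minimal usage sketch (illustrative only; the tensor names, the value of
+/// m, and the layout constructors below are assumptions, not part of this
+/// patch):
+///   AscendC::GlobalTensor<float> gmPerTokenScale;  // (m, 1) scales in GM
+///   AscendC::LocalTensor<float> ubPerTokenScale;   // (m, 8) padded tile in UB
+///   layout::RowMajor layoutSrc{m, 1};              // one scale per token row
+///   layout::RowMajor layoutDst{m, 8};              // one 32B block per row
+///   CopyPerTokenScale2Ub<Arch::AtlasA2,
+///                        Gemm::GemmType<float, layout::RowMajor>> copyScale;
+///   copyScale(ubPerTokenScale, gmPerTokenScale, layoutDst, layoutSrc);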
+template <class ArchTag, class GmType> struct CopyPerTokenScale2Ub {
+  static_assert(std::is_same_v<typename GmType::Layout, layout::RowMajor>,
+                "Unsupported layout for CopyPerTokenScale2Ub.");
+
+  using Element = typename GmType::Element;
+  using LayoutSrc = typename GmType::Layout;
+  using LayoutDst = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+
+  ACT_DEVICE
+  CopyPerTokenScale2Ub() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor<Element> const &dstTensor,
+                  AscendC::GlobalTensor<Element> const &srcTensor,
+                  LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) {
+    AscendC::DataCopyExtParams dataCopyParams;
+    AscendC::DataCopyPadExtParams<Element> padParams;
+
+    dataCopyParams.blockCount = layoutSrc.shape(0);
+    dataCopyParams.blockLen =
+        layoutSrc.shape(1) *
+        sizeof(Element); // per token scale has only one column
+    dataCopyParams.srcStride = 0;
+    dataCopyParams.dstStride =
+        (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK;
+    // Pad the data to the complete block
+    padParams.isPad = true;
+    padParams.leftPadding = 0;
+    padParams.rightPadding = 0;
+
+    AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams);
+  }
+};
+
+template <class ArchTag, class GmType> struct CopyGm2UbAligned {
+  static_assert(
+      DEPENDENT_FALSE<ArchTag>,
+      "Unsupported copy gm to ub aligned, can not find the specialization.");
+};
+
+template <class ArchTag, class Element>
+struct CopyGm2UbAligned<ArchTag, Gemm::GemmType<Element, layout::RowMajor>> {
+  using LayoutSrc = layout::RowMajor;
+  using LayoutDst = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+  static constexpr uint32_t BLOCK_LEN_LIMIT = 65536;
+  static constexpr uint32_t MAX_REPEAT = 4095;
+  static constexpr uint32_t STRIDE_LIMIT = 65536;
+
+  ACT_DEVICE
+  CopyGm2UbAligned() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor<Element> const &dstTensor,
+                  AscendC::GlobalTensor<Element> const &srcTensor,
+                  layout::RowMajor const &layoutDst,
+                  layout::RowMajor const &layoutSrc) {
+    uint32_t rows = layoutSrc.shape(0);
+    uint32_t cols = layoutSrc.shape(1);
+    uint32_t srcStride =
+        (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK;
+    uint32_t dstStride =
+        (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK;
+
+    if ((layoutSrc.shape(1) == layoutSrc.stride(0)) &&
+        (layoutDst.shape(1) == layoutDst.stride(0))) {
+      DataCopy(dstTensor, srcTensor, rows * cols);
+    } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT &&
+               (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) {
+      uint32_t rLoops = CeilDiv(rows, MAX_REPEAT);
+      for (uint32_t i = 0; i < rLoops; ++i) {
+        uint32_t rActual =
+            (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT;
+        AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK,
+                                               srcStride, dstStride);
+        DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)],
+                 srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)],
+                 dataCopyParams);
+      }
+    } else {
+      for (uint32_t i = 0; i < rows; ++i) {
+        DataCopy(dstTensor[i * layoutDst.stride(0)],
+                 srcTensor[i * layoutSrc.stride(0)], cols);
+      }
+    }
+  };
+};
+
+} // namespace Act::Epilogue::Tile
+
+#endif
diff --git a/act/epilogue/tile/copy_ub_to_gm.hpp b/act/epilogue/tile/copy_ub_to_gm.hpp
new file mode 100644
index 00000000..2c584048
--- /dev/null
+++ b/act/epilogue/tile/copy_ub_to_gm.hpp
@@ -0,0 +1,131 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP
+#define ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/gemm/gemm_type.hpp"
+#include "../../../act/layout/layout.hpp"
+
+namespace Act::Epilogue::Tile {
+
+template <class ArchTag, class GmType> struct CopyUb2Gm {
+  static_assert(
+      DEPENDENT_FALSE<ArchTag>,
+      "Unsupported copy ub to gm, can not find the specialization.");
+};
+
+template <class ArchTag, class Element>
+struct CopyUb2Gm<ArchTag, Gemm::GemmType<Element, layout::RowMajor>> {
+  using LayoutDst = layout::RowMajor;
+  using LayoutSrc = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+
+  ACT_DEVICE
+  CopyUb2Gm() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor<Element> const &dstTensor,
+                  AscendC::LocalTensor<Element> const &srcTensor,
+                  layout::RowMajor const &layoutDst,
+                  layout::RowMajor const &layoutSrc) {
+    AscendC::DataCopyExtParams dataCopyParams(
+        layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element),
+        (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0,
+        (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0);
+    AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams);
+  }
+};
+
+// Newly added VectorLayout version
+template <class ArchTag, class Element>
+struct CopyUb2Gm<ArchTag, Gemm::GemmType<Element, layout::VectorLayout>> {
+  using LayoutSrc = layout::VectorLayout;
+  using LayoutDst = layout::VectorLayout;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+
+  ACT_DEVICE
+  CopyUb2Gm() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor<Element> const &dstTensor,
+                  AscendC::LocalTensor<Element> const &srcTensor,
+                  layout::VectorLayout const &layoutDst,
+                  layout::VectorLayout const &layoutSrc) {
+    AscendC::DataCopyExtParams dataCopyParams(
+        1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0);
+    AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams);
+  };
+};
+
+template <class ArchTag, class GmType> struct CopyUb2GmAligned {
+  static_assert(
+      DEPENDENT_FALSE<ArchTag>,
+      "Unsupported copy ub to gm aligned, can not find the specialization.");
+};
+
+template <class ArchTag, class Element>
+struct CopyUb2GmAligned<ArchTag, Gemm::GemmType<Element, layout::RowMajor>> {
+  using LayoutSrc = layout::RowMajor;
+  using LayoutDst = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element);
+  static constexpr uint32_t BLOCK_LEN_LIMIT = 65536;
+  static constexpr uint32_t MAX_REPEAT = 4095;
+  static constexpr uint32_t STRIDE_LIMIT = 65536;
+
+  ACT_DEVICE
+  CopyUb2GmAligned() = default;
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor<Element> const &dstTensor,
+                  AscendC::LocalTensor<Element> const &srcTensor,
+                  layout::RowMajor const &layoutDst,
+                  layout::RowMajor const &layoutSrc) {
+    uint32_t rows = layoutDst.shape(0);
+    uint32_t cols = layoutDst.shape(1);
+    uint32_t srcStride =
+        (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK;
+    uint32_t dstStride =
+        (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK;
+
+    if ((layoutSrc.shape(1) == layoutSrc.stride(0)) &&
+        (layoutDst.shape(1) == layoutDst.stride(0))) {
+      DataCopy(dstTensor, srcTensor, rows * cols);
+    } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT &&
+               (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) {
+      uint32_t rLoops = CeilDiv(rows, MAX_REPEAT);
+      for (uint32_t i = 0; i < rLoops; ++i) {
+        uint32_t rActual =
+            (i < rLoops - 1) ?
MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, + srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], + dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], + srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp new file mode 100644 index 00000000..da5eeaca --- /dev/null +++ b/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp @@ -0,0 +1,71 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByColumn { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileBroadcastInplaceByColumn() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) { + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; + + constexpr uint64_t defaultMask = + BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + constexpr uint64_t tailMask = + (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; + + constexpr uint8_t repeatTimes = 1; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = blkNumPerRow; + repeatParams.srcStride = blkNumPerRow; + repeatParams.dstRepeatSize = 1; + repeatParams.srcRepeatSize = 1; + + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; + rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { + uint64_t mask = + ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) + ? defaultMask + : tailMask; + for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; + colOffset += eleNumPerBlk) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], + ubInOut[rowOffset * TileShape::COLUMN], mask, 1, + repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp new file mode 100644 index 00000000..f507f94c --- /dev/null +++ b/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp @@ -0,0 +1,58 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). 
Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileBroadcastInplaceByRow { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileBroadcastInplaceByRow() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) { + constexpr uint32_t eleNumPerVectorFractal = + BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + constexpr uint64_t mask = eleNumPerVectorFractal; + constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = 1; + repeatParams.srcStride = 1; + repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + + for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, + repeatTimes, repeatParams); + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_broadcast_mul.hpp b/act/epilogue/tile/tile_broadcast_mul.hpp new file mode 100644 index 00000000..9e31d69a --- /dev/null +++ b/act/epilogue/tile/tile_broadcast_mul.hpp @@ -0,0 +1,131 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +/// BroadcastMul computes the elementwise multiplication of a tensor of shape +/// (m, n) and a tensor of shape (m, n) after broadcasting. There are two +/// broadcast modes: row-broadcast and column-broadcast. + +/// @brief Computes the elementwise multiplication of a tensor with shape (m, n) +/// and a tensor with original shape (1, n) broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). 
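+///
+/// A minimal usage sketch (illustrative only; the tile shape and the tensor
+/// names are assumptions, not part of this patch):
+///   using TileShape = MatrixShape<32, 256>;  // hypothetical (m, n) tile
+///   TileRowBroadcastMul<Arch::AtlasA2,
+///                       Gemm::GemmType<float, layout::RowMajor>,
+///                       TileShape> rowBroadcastMul;
+///   // ubOut(i, j) = ubIn0(i, j) * ubIn1(0, j) over the whole (m, n) tile
+///   rowBroadcastMul(ubOut, ubIn0, ubIn1);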
+template +struct TileRowBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileRowBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = blkNumPerColumn; + repeatParams.src0RepStride = blkNumPerColumn; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = + BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; + rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast( + (residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; + colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = + (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], + ubIn1[colOffset], mask, repeatTimes, repeatParams); + } + } + } +}; + +/// @brief Compute the elementwise multiplication of a tensor of shape (m, n) +/// and a tensor of shape (m, eleNumPerBlk), which is broadcast from a tensor of +/// shape (m, 1), broadcast to (m, n). +/// @tparam ArchTag_ is the architecture tag. +/// @tparam ComputeType_ includes the element type and layout information. +/// @tparam TileShape_ is the shape (m, n). +template +struct TileOneBlkColumnBroadcastMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileOneBlkColumnBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = blkNumPerColumn; + repeatParams.src0BlkStride = blkNumPerColumn; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 1; + repeatParams.src0RepStride = 1; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; + constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; + rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint64_t mask = + ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * + eleNumPerBlk; + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; + colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint8_t repeatTimes = static_cast( + ((residueN > colNumPerCompute) ? 
colNumPerCompute : residueN) / + eleNumPerBlk); + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], + ubIn1[rowOffset * eleNumPerBlk], mask, repeatTimes, + repeatParams); + } + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_broadcast_one_blk.hpp b/act/epilogue/tile/tile_broadcast_one_blk.hpp new file mode 100644 index 00000000..799a1bd1 --- /dev/null +++ b/act/epilogue/tile/tile_broadcast_one_blk.hpp @@ -0,0 +1,56 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP +#define ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template +struct TileBroadcastOneBlk { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileBroadcastOneBlk() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn) { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + AscendC::BrcbRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; + + constexpr uint32_t eleNumPerCompute = + RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); + for (uint32_t offset = 0; offset < COMPUTE_LENGTH; + offset += eleNumPerCompute) { + uint32_t residueM = COMPUTE_LENGTH - offset; + uint32_t computeM = + (residueM > eleNumPerCompute) ? eleNumPerCompute : residueM; + uint8_t repeatTimes = + static_cast(CeilDiv(computeM)); + AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, + repeatParams); + } + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_cast.hpp b/act/epilogue/tile/tile_cast.hpp new file mode 100644 index 00000000..c0fa588d --- /dev/null +++ b/act/epilogue/tile/tile_cast.hpp @@ -0,0 +1,45 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_CAST_HPP +#define ACT_EPILOGUE_TILE_TILE_CAST_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class DstType_, class SrcType_, + /// Length of the compute buffer + class TileShape_> +struct TileCast { + using ArchTag = ArchTag_; + using ElementDst = typename DstType_::Element; + using ElementSrc = typename SrcType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileCast() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn) { + AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_copy.hpp b/act/epilogue/tile/tile_copy.hpp new file mode 100644 index 00000000..abc7c96f --- /dev/null +++ b/act/epilogue/tile/tile_copy.hpp @@ -0,0 +1,111 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_COPY_HPP +#define ACT_EPILOGUE_TILE_TILE_COPY_HPP + +#include "../../../act/epilogue/tile/copy_gm_to_ub.hpp" +#include "../../../act/epilogue/tile/copy_ub_to_gm.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag, class... 
Args> +struct TileCopy { + static_assert(DEPENDENT_FALSE, + "Unsupporteded tile copy, can not find the specialization."); +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; + using CopyGmToUbY = CopyGm2Ub; + using CopyGmToUbTemp = CopyGm2Ub; + using CopyUbToGmZ = CopyUb2Gm; +}; + +template +struct TileCopy { + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementY = typename YType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbY = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopyBf16 { + using ElementC = typename CType::Element; + using ElementX = bfloat16_t; + using ElementY = bfloat16_t; + using ElementD = bfloat16_t; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = + CopyGm2Ub>; + using CopyGmToUbY = + CopyGm2Ub>; + using CopyUbToGmD = + CopyUb2Gm>; +}; + +template +struct TileCopyPerTokenDequant { + using ElementC = typename CType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = + CopyPerTokenScale2Ub; + using CopyUbToGmD = CopyUb2Gm; +}; + +template +struct TileCopyPerTokenDequantGemm { + using ElementX = typename XType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementBias = typename BiasType::Element; + using ElementC = typename CType::Element; + + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyGm2Ub; + using CopyGmToUbBias = CopyGm2Ub; + using CopyUbToGmC = CopyUb2Gm; +}; + +} // namespace Act::Epilogue::Tile + +#endif // ACT_EPILOGUE_TILE_TILE_COPY_HPP diff --git a/act/epilogue/tile/tile_elemwise_add.hpp b/act/epilogue/tile/tile_elemwise_add.hpp new file mode 100644 index 00000000..047fefc6 --- /dev/null +++ b/act/epilogue/tile/tile_elemwise_add.hpp @@ -0,0 +1,47 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP +#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + uint32_t COMPUTE_LENGTH_> +struct TileElemWiseAdd { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileElemWiseAdd() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) { + // Do the calculation + AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_elemwise_mul.hpp b/act/epilogue/tile/tile_elemwise_mul.hpp new file mode 100644 index 00000000..f79ea98e --- /dev/null +++ b/act/epilogue/tile/tile_elemwise_mul.hpp @@ -0,0 +1,46 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP +#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP + +#include "../../../act/act.hpp" + +namespace Act::Epilogue::Tile { + +template < + /// Tag indicating architecture + class ArchTag_, + /// Compute data type + class ComputeType_, + /// Length of the compute buffer + class TileShape_> +struct TileElemwiseMul { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileElemwiseMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) { + // Do the calculation + AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif diff --git a/act/epilogue/tile/tile_elemwise_muls.hpp b/act/epilogue/tile/tile_elemwise_muls.hpp new file mode 100644 index 00000000..8af5d5c7 --- /dev/null +++ b/act/epilogue/tile/tile_elemwise_muls.hpp @@ -0,0 +1,38 @@ +/* + * Copyright (c) 2024 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP +#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP + +#include "../../../act/gemm/helper.hpp" + +namespace Act::Epilogue::Tile { +template +struct TileElemWiseMuls { + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileElemWiseMuls() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstLocal, + AscendC::LocalTensor srcTensor, + ElementCompute scalar) { + AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH); + } +}; +} // namespace Act::Epilogue::Tile + +#endif // ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP diff --git a/act/epilogue/tile/tile_swizzle.hpp b/act/epilogue/tile/tile_swizzle.hpp new file mode 100644 index 00000000..13c05298 --- /dev/null +++ b/act/epilogue/tile/tile_swizzle.hpp @@ -0,0 +1,82 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP +#define ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP + +#include "../../../act/act.hpp" +#include "../../../act/detail/alignment.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Epilogue::Tile { + +struct EpilogueIdentityTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueIdentityTileSwizzle() = default; + + ACT_DEVICE + EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, + MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const { return loopsMN.row() * loopsMN.column(); } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const { + return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +struct EpilogueHorizontalTileSwizzle { + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle() = default; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, + MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const { return loopsMN.row() * loopsMN.column(); } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const { + return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } +}; + +} // namespace Act::Epilogue::Tile + +#endif // ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP diff --git a/act/gemm/block/block_mmad.hpp b/act/gemm/block/block_mmad.hpp new file mode 100644 index 00000000..0a8d8d8f --- 
/dev/null
+++ b/act/gemm/block/block_mmad.hpp
@@ -0,0 +1,67 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_HPP
+#define ACT_GEMM_BLOCK_BLOCK_MMAD_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/gemm/tile/tile_copy.hpp"
+#include "../../../act/gemm/tile/tile_mmad.hpp"
+
+namespace Act::Gemm::Block {
+
+template <class DispatchPolicy, class L1TileShape, class L0TileShape,
+          class AType, class BType, class CType, class BiasType = void,
+          class TileCopy = Gemm::Tile::TileCopy<
+              typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>,
+          class TileMmad = Gemm::Tile::TileMmad<
+              typename DispatchPolicy::ArchTag, AType, BType, BiasType>>
+struct BlockMmad {
+  static_assert(DEPENDENT_FALSE<DispatchPolicy>,
+                "BlockMmad is not implemented for this DispatchPolicy");
+};
+
+template <class DispatchPolicy, class L1TileShape, class L0TileShape,
+          class AType, class BType, class CType, class BiasType = void,
+          class TileCopy = Gemm::Tile::TileCopyTla<
+              typename DispatchPolicy::ArchTag, AType, BType, CType, BiasType>,
+          class TileMmad = Gemm::Tile::TileMmadTla<
+              typename DispatchPolicy::ArchTag, typename TileCopy::TensorL0A,
+              typename TileCopy::TensorL0B, typename TileCopy::TensorL0C>>
+struct BlockMmadTla {
+  static_assert(DEPENDENT_FALSE<DispatchPolicy>,
+                "BlockMmadTla is not implemented for this DispatchPolicy");
+};
+
+/// BlockGemm is newly added because this kernel reuses the same dispatch
+/// policy as optimized_matmul; a separate class avoids a specialization
+/// conflict with the existing BlockMmad.
+template <
+    class DispatchPolicy, class L1TileShape, class L0TileShape, class AType,
+    class BType, class CType, class BiasType = void,
+    class TileCopy =
+        Gemm::Tile::TileCopyGemm<typename DispatchPolicy::ArchTag, AType,
+                                 BType, CType, BiasType>, // renamed tile copy
+    class TileMmad = Gemm::Tile::TileMmad<typename DispatchPolicy::ArchTag,
+                                          AType, BType, BiasType>>
+struct BlockGemm {
+  static_assert(DEPENDENT_FALSE<DispatchPolicy>,
+                "BlockGemm is not implemented for this DispatchPolicy");
+};
+
+} // namespace Act::Gemm::Block
+
+#include "../../../act/gemm/block/block_mmad_preload_async_with_callback.hpp"
+
+#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_HPP
diff --git a/act/gemm/block/block_mmad_preload_async_with_callback.hpp b/act/gemm/block/block_mmad_preload_async_with_callback.hpp
new file mode 100644
index 00000000..7b28f80b
--- /dev/null
+++ b/act/gemm/block/block_mmad_preload_async_with_callback.hpp
@@ -0,0 +1,455 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP +#define ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP + +#include "../../../act/act.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/coord.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/gemm/dispatch_policy.hpp" +#include "../../../act/gemm/helper.hpp" +#include "../../../act/gemm_coord.hpp" + +namespace Act::Gemm::Block { + +template +struct BlockMmad, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, + TileCopy_, TileMmad_> { +public: + // Type Aliases + using DispatchPolicy = + MmadAtlasA2PreloadAsyncWithCallback; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< + ElementA, ElementB>::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = + L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = + L1TileShape::N * L1TileShape::K * sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = + L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = + L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = + L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, + "LayoutC only support RowMajor yet!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeding the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, + "L0TileShape exceeding the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, + "L0TileShape exceeding the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, + "L0TileShape exceeding the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && + 
L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on " + "the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout( + L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout( + L1TileShape::K, L1TileShape::N); + + ACT_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); + } + + ACT_DEVICE + ~BlockMmad() { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, + LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, + LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutC, GemmCoord const &actualShape, + Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) + ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = (kTileIdx < kTileCount - 1) + ? L1TileShape::K + : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emission load instruction from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load first matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = + layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load first matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = + layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform + // an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = + (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? 
(l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = + layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) + ? (l1TileMmadParamsId + 1) + : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + ACT_DEVICE + void SynchronizeBlock() { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) + ? (l1TileMmadParamsId + 1) + : 0; + --preloadCount; + } + } + +private: + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + ACT_DEVICE + L1TileMmadParams() = default; + }; + + ACT_DEVICE + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte( + l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte( + l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } + } + + ACT_DEVICE + void InitL0A(Arch::Resource &resource) { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = + resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } + } + + ACT_DEVICE + void InitL0B(Arch::Resource &resource) { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = + resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); + } + } + + ACT_DEVICE + void InitL0C(Arch::Resource &resource) { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = + resource.l0CBuf.template GetBufferByByte( + L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); + } + } + + ACT_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = + LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); + } + } + + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; 
++mPartIdx) { + uint32_t mPartActual = (mPartIdx < mPartLoop - 1) + ? L0TileShape::M + : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) + ? L0TileShape::K + : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout( + mPartActual, kPartActual); + auto l1AOffset = + MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag( + l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag( + l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) + ? L0TileShape::N + : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout( + kPartActual, nPartActual); + auto l1BOffset = + MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag( + l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag( + l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag( + l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = + MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the + // accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the + // calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && + (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, + kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } + } + + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; + + params.callbackBeforeFixpipe(); + + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0;
+
+      params.callbackAfterFixpipe();
+    }
+  }
+
+  AscendC::LocalTensor l1ATensorList[L1_STAGES];
+  AscendC::LocalTensor l1BTensorList[L1_STAGES];
+  int32_t l1AEventList[L1_STAGES];
+  int32_t l1BEventList[L1_STAGES];
+  uint32_t l1ListId{0};
+
+  AscendC::LocalTensor l0ATensorList[L0A_STAGES];
+  int32_t l0AEventList[L0A_STAGES];
+  uint32_t l0AListId{0};
+
+  AscendC::LocalTensor l0BTensorList[L0B_STAGES];
+  int32_t l0BEventList[L0B_STAGES];
+  uint32_t l0BListId{0};
+
+  AscendC::LocalTensor l0CTensorList[L0C_STAGES];
+  int32_t l0CEventList[L0C_STAGES];
+  uint32_t l0CListId{0};
+
+  L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES];
+  uint32_t l1TileMmadParamsId{0};
+  uint32_t preloadCount{0};
+
+  TileMmad tileMmad;
+  CopyGmToL1A copyGmToL1A;
+  CopyGmToL1B copyGmToL1B;
+  CopyL1ToL0A copyL1ToL0A;
+  CopyL1ToL0B copyL1ToL0B;
+  CopyL0CToGm copyL0CToGm;
+};
+
+} // namespace Act::Gemm::Block
+
+#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP
diff --git a/act/gemm/block/block_swizzle.hpp b/act/gemm/block/block_swizzle.hpp
new file mode 100644
index 00000000..81b3df23
--- /dev/null
+++ b/act/gemm/block/block_swizzle.hpp
@@ -0,0 +1,234 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP +#define ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP + +#include "../../../act/act.hpp" +#include "../../../act/detail/alignment.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Gemm::Block { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Block swizzling function for Gemms +template +struct GemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + MatrixCoord tileMN; + MatrixCoord loopsMN; + + /// Methods + + ACT_DEVICE + GemmIdentityBlockSwizzle() {} + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, + MatrixCoord const &tileMN_) + : problemShape(problemShape_), tileMN(tileMN_) { + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, + MatrixCoord const &tileMN_, + MatrixCoord const &loopsMN_) + : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) {} + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) { + problemShape = problemShape_; + tileMN = tileMN_; + + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, + MatrixCoord const &loopsMN_) { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = loopsMN_; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const { return loopsMN.row() * loopsMN.column(); } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) { return taskIdx / (GetCoreLoops()); } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) { + uint32_t innerIdx = taskIdx % GetCoreLoops(); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } + } + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord) { + uint32_t mActual = (blockCoord.m() == (loopsMN.row() - 1)) + ? (problemShape.m() - blockCoord.m() * tileMN.row()) + : tileMN.row(); + uint32_t nActual = + (blockCoord.n() == (loopsMN.column() - 1)) + ? 
(problemShape.n() - blockCoord.n() * tileMN.column()) + : tileMN.column(); + uint32_t kActual = problemShape.k(); + return GemmCoord{mActual, nActual, kActual}; + } +}; + +/// Block swizzling function for Splitk Gemms +template +struct SplitkGemmIdentityBlockSwizzle { + /// Data members + + GemmCoord problemShape; + GemmCoord tileShape; + GemmCoord loopsMNK; + uint32_t splitkFactor = 1; // split k dim into virtual cores + + /// Methods + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle() {} + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, + GemmCoord const &tileShape_, + uint32_t splitkFactor_ = 1) + : problemShape(problemShape_), tileShape(tileShape_), + splitkFactor(splitkFactor_) { + loopsMNK = CeilDiv(problemShape, tileShape); + } + + ACT_DEVICE + uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const { + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; + } else { + return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + + loopsMNK.k() % splitkFactor; + } + } + + ACT_DEVICE + uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const { + uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); + return taskIdx % GetCoreLoops() / mnLoops; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const { + return loopsMNK.m() * loopsMNK.n() * splitkFactor; + } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) { return taskIdx / GetCoreLoops(); } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) { + uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); + uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); + + uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMNK.n() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMNK.m() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } + } + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) { + uint32_t splitkSliceLen; + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); + } else { + splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); + } + uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) + ? (problemShape.m() - blockCoord.m() * tileShape.m()) + : tileShape.m(); + uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) + ? 
(problemShape.n() - blockCoord.n() * tileShape.n()) + : tileShape.n(); + uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) + ? (problemShape.k() - blockCoord.k() * tileShape.k()) + : splitkSliceLen; + return GemmCoord{mActual, nActual, kActual}; + } +}; + +} // namespace Act::Gemm::Block + +#endif // ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/act/gemm/dispatch_policy.hpp b/act/gemm/dispatch_policy.hpp new file mode 100644 index 00000000..df0abfe2 --- /dev/null +++ b/act/gemm/dispatch_policy.hpp @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_DISPATCH_POLICY_HPP +#define ACT_GEMM_DISPATCH_POLICY_HPP + +#include "../../act/act.hpp" + +namespace Act::Gemm { + +// Block Mmad Policies + +template struct MmadAtlasA2Base { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t ASYNC = ASYNC_; +}; + +using MmadAtlasA2 = MmadAtlasA2Base; +using MmadAtlasA2Async = MmadAtlasA2Base; + +// Now ENABLE_UNIT_FLAG_ must be false when input element is int8 +template +struct MmadAtlasA2Pingpong : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; +}; + +template +struct MmadAtlasA2Preload : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +struct MmadAtlasA2FAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2FAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQK : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPV : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAQKTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +struct MmadAtlasA2MLAPVTp1Spec : public MmadAtlasA2 { + static constexpr uint32_t STAGES = 2; +}; + +template +struct MmadAtlasA2PreloadAsync : public MmadAtlasA2Async { + static constexpr uint32_t PRELOAD_STAGES = + PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1_STAGES = L1_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; +}; + +template +struct MmadAtlasA2PreloadAsyncWithCallback + : public MmadAtlasA2PreloadAsync {}; +} // namespace Act::Gemm + +#endif // ACT_GEMM_DISPATCH_POLICY_HPP diff --git a/act/gemm/gemm_type.hpp b/act/gemm/gemm_type.hpp new file mode 100644 index 00000000..6b71040f --- /dev/null +++ b/act/gemm/gemm_type.hpp @@ -0,0 +1,30 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMM_GEMM_TYPE_HPP
+#define ACT_GEMM_GEMM_TYPE_HPP
+
+namespace Act::Gemm {
+
+////////////////////////////////////////////////////////////////////
+
+template
+struct GemmType {
+  using Element = Element_;
+  using Layout = Layout_;
+  static constexpr AscendC::TPosition POSITION = POSITION_;
+};
+
+} // namespace Act::Gemm
+
+#endif // ACT_GEMM_GEMM_TYPE_HPP
diff --git a/act/gemm/helper.hpp b/act/gemm/helper.hpp
new file mode 100644
index 00000000..bb448a8e
--- /dev/null
+++ b/act/gemm/helper.hpp
@@ -0,0 +1,274 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMM_HELPER_HPP
+#define ACT_GEMM_HELPER_HPP
+
+#include "../../act/act.hpp"
+#include "../../act/layout/layout.hpp"
+#include "../../tla/layout.hpp"
+
+namespace Act::Gemm::helper {
+
+template struct L1AlignHelper {
+  static_assert(DEPENDENT_FALSE,
+                "Unsupported align helper, cannot find the specialization.");
+};
+
+template struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0;
+};
+
+template struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL;
+};
+
+template
+struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0;
+};
+
+template
+struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL;
+};
+
+template struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0;
+};
+
+template struct L1AlignHelper {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL;
+};
+
+template struct ElementAccumulatorSelector {
+  static_assert(DEPENDENT_FALSE,
+                "Unsupported element accumulator selector, cannot find the "
+                "specialization.");
+};
+
+template <> struct ElementAccumulatorSelector {
+  using ElementAccumulator = float;
+};
+
+template <> struct ElementAccumulatorSelector {
+  using ElementAccumulator = float;
+};
+
+template <> struct ElementAccumulatorSelector {
+  using ElementAccumulator = int32_t;
+};
+
+template <> struct ElementAccumulatorSelector {
+  using ElementAccumulator = float;
+};
+
+template struct L1ATypeSelector {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported layout selector, cannot find the specialization.");
+};
+
+template
+struct L1ATypeSelector> {
+  using L1AType = Gemm::GemmType;
+};
+
+template
+struct L1ATypeSelector> {
+  using L1AType = Gemm::GemmType;
+};
+
+template
+struct L1ATypeSelector> {
+  using L1AType = Gemm::GemmType;
+};
+
+template
+struct L1ATypeSelector> {
+  using L1AType = Gemm::GemmType;
+};
+
+template struct L1BTypeSelector {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported layout selector, cannot find the specialization.");
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelector> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1AlignHelperTla {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported align helper tla, cannot find the specialization.");
+};
+
+template
+struct L1AlignHelperTla<
+    Element, Layout, std::enable_if_t::value>> {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0;
+};
+
+template
+struct L1AlignHelperTla<
+    Element, Layout,
+    std::enable_if_t::value>> {
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+  static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0;
+  static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL;
+};
+
+///////////////////////////////////////
+// Additional type selectors used by the GEMM-specific paths below
+template struct L1ATypeSelectorGemm {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported layout selector, cannot find the specialization.");
+};
+
+template
+struct L1ATypeSelectorGemm> {
+  using L1AType = Gemm::GemmType;
+};
+
+template <>
+struct L1ATypeSelectorGemm> {
+  using L1AType = Gemm::GemmType;
+};
+
+template
+struct L1ATypeSelectorGemm> {
+  using L1AType = Gemm::GemmType;
+};
+
+template struct L1BTypeSelectorGemm {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported layout selector, cannot find the specialization.");
+};
+
+template
+struct L1BTypeSelectorGemm> {
+  using L1BType = Gemm::GemmType;
+};
+
+template <>
+struct L1BTypeSelectorGemm> {
+  using L1BType = Gemm::GemmType;
+};
+
+template
+struct L1BTypeSelectorGemm> {
+  using L1BType = Gemm::GemmType;
+};
+
+template struct L0ATypeSelector {};
+
+template
+struct L0ATypeSelector> {
+  using L0AType =
Gemm::GemmType; +}; + +template +struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; +}; + +template <> struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; +}; + +template struct L0BTypeSelectorGemm {}; + +template +struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template <> struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; +}; + +template struct L0BTypeSelectorGemv {}; + +template +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; + +template +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; + +template <> struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; +}; +} // namespace Act::Gemm::helper + +#endif // ACT_GEMM_HELPER_HPP diff --git a/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp new file mode 100644 index 00000000..baf7c00d --- /dev/null +++ b/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp @@ -0,0 +1,388 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP + +#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" +#include "../../../act/act.hpp" +#include "../../../act/arch/cross_core_sync.hpp" +#include "../../../act/arch/resource.hpp" +#include "../../../act/coord.hpp" +#include "../../../act/detail/callback.hpp" +#include "../../../act/gemm_coord.hpp" +#include "../../../act/matrix_coord.hpp" + +namespace Act::Gemm::Kernel { + +template +class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace { +public: + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; + using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + ACT_DEVICE + Params() {} + + ACT_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, + GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, + LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, + LayoutD layoutD_, GM_ADDR ptrWorkspace_, void *combiner_) + : problemShape(problemShape_), problemCount(problemCount_), + ptrGroupList( + reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>( + ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), combiner(combiner_) {} + }; + + // Methods + ACT_DEVICE + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = 
Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; + } + } + + template + ACT_DEVICE void operator()(Params const ¶ms); + + template <> ACT_DEVICE void operator()(Params const ¶ms) { + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer( + reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, + L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - + groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), + params.problemShape.k()}; + + LayoutA layoutA = + params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, + MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = + ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - + startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; + loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = + blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, + blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, + blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, + gmB[gmGroupOffsetB + gmOffsetB], layoutB, gmC[gmOffsetC], + layoutC, actualBlockShape, callbackBeforeFixpipe, + callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, + gmB[gmGroupOffsetB + gmOffsetB], layoutB, gmC[gmOffsetC], + layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } + + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? (stageId - stageUsed) + : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } + } + + template <> ACT_DEVICE void operator()(Params const ¶ms) { + auto *combiner = + (MoeDistributeCombineImpl::CamMoeDistributeCombine + *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>( + MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer( + reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{ + L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) + ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - + groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), + params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout( + inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = + params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = + ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - + startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; + loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = + blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, + 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = + layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, + actualBlockShapeMNK, gmBlockC, layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>( + flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } + +private: + friend struct AicWaitFunc; + friend struct AicSetFunc; + + struct AicWaitFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, + WORKSPACE_STAGES, ElementGroupList>; + + ACT_DEVICE + AicWaitFunc() = default; + + ACT_DEVICE + void operator()() const { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + struct AicSetFunc { + using MatmulKernel = GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< + TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, + WORKSPACE_STAGES, ElementGroupList>; + + ACT_DEVICE + AicSetFunc() = default; + + ACT_DEVICE + void operator()() const { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>( + ptr->flagAicFinishStoreList[stageId]); + } + + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; + + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; +}; + +} // namespace Act::Gemm::Kernel + +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/act/gemm/tile/copy_gm_to_l1.hpp b/act/gemm/tile/copy_gm_to_l1.hpp new file mode 100644 index 00000000..ba5e8207 --- /dev/null +++ b/act/gemm/tile/copy_gm_to_l1.hpp @@ -0,0 +1,856 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_GM_TO_L1_HPP +#define ACT_GEMM_TILE_COPY_GM_TO_L1_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template struct CopyGmToL1 { + static_assert(DEPENDENT_FALSE, + "Unsupported copy gm to l1, can not find the specialization."); +}; + +/// Partial specialization for AtlasA2, half, RowMajor in and zN out. 
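+/// The tile is moved with a single hardware ND-to-NZ transfer
+/// (AscendC::Nd2NzParams): rows of the row-major GM tile are repacked into zN
+/// fractals in L1 (for half, one fractal is 16 rows x 16 elements). When the
+/// GM row stride no longer fits the srcDValue field (stride >= STRIDE_LIMIT),
+/// the copy degrades to one DataCopy per source row.
+/// A minimal call-site sketch (names taken from BlockMmad above):
+///   copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA);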
+/// Typically used for loading matrix A tiles.
+template
+struct CopyGmToL1,
+                  Gemm::GemmType> {
+  using LayoutDst = layout::zN;
+  using LayoutSrc = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+
+  // Methods
+
+  ACT_DEVICE
+  CopyGmToL1() {};
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor const &dstTensor,
+                  AscendC::GlobalTensor const &srcTensor,
+                  LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) {
+    AscendC::Nd2NzParams intriParams;
+
+    intriParams.ndNum = 1;
+    intriParams.dValue = layoutSrc.shape(1);
+    intriParams.srcNdMatrixStride = 0;
+    intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0;
+    intriParams.dstNzMatrixStride = 0;
+
+    if (layoutSrc.stride(0) < STRIDE_LIMIT) {
+      intriParams.nValue = layoutSrc.shape(0);
+      intriParams.srcDValue = layoutSrc.stride(0);
+      intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0;
+      AscendC::DataCopy(dstTensor, srcTensor, intriParams);
+    } else {
+      intriParams.nValue = 1;
+      intriParams.srcDValue = 0;
+      intriParams.dstNzNStride = 0;
+      for (uint32_t i = 0; i < layoutSrc.shape(0); i++) {
+        AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0],
+                          srcTensor[i * layoutSrc.stride(0)], intriParams);
+      }
+    }
+  }
+};
+
+template
+struct CopyGmToL1,
+                  Gemm::GemmType> {
+  using LayoutDst = layout::zZ;
+  using LayoutSrc = layout::RowMajor;
+
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+
+  // Methods
+
+  ACT_DEVICE
+  CopyGmToL1() {};
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor const &dstTensor,
+                  AscendC::GlobalTensor const &srcTensor,
+                  LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) {
+    AscendC::Nd2NzParams intriParams;
+    uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0);
+    uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL;
+    uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL;
+    if (srcNdStride < STRIDE_LIMIT) {
+      if (ndNum) {
+        intriParams.ndNum = ndNum;
+        intriParams.nValue = C0_NUM_PER_FRACTAL;
+        intriParams.dValue = layoutSrc.shape(1);
+        intriParams.srcNdMatrixStride = srcNdStride;
+        intriParams.srcDValue = layoutSrc.stride(0);
+
+        intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0;
+        intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0;
+
+        intriParams.dstNzMatrixStride = layoutDst.stride(1);
+
+        AscendC::DataCopy(dstTensor, srcTensor, intriParams);
+      }
+
+      if (remains) {
+        AscendC::Nd2NzParams tailParams;
+        tailParams.ndNum = 1;
+        tailParams.nValue = remains;
+        tailParams.dValue = layoutSrc.shape(1);
+        tailParams.srcNdMatrixStride = srcNdStride;
+        tailParams.srcDValue = layoutSrc.stride(0);
+
+        tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0;
+        tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0;
+        tailParams.dstNzMatrixStride = 0;
+
+        AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)],
+                          srcTensor[ndNum * srcNdStride], tailParams);
+      }
+    } else if (layoutSrc.stride(0) < STRIDE_LIMIT) {
+      for (uint32_t i = 0; i < ndNum; i++) {
+        AscendC::Nd2NzParams intriParams;
+        intriParams.ndNum = 1;
+        intriParams.nValue = C0_NUM_PER_FRACTAL;
+        intriParams.dValue = layoutSrc.shape(1);
+        intriParams.srcNdMatrixStride = 0;
+        intriParams.srcDValue = layoutSrc.stride(0);
+
+        intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0;
+        intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0;
+        intriParams.dstNzMatrixStride = 0;
+
+        AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)],
+                          srcTensor[i * srcNdStride], intriParams);
+      }
+      if (remains) {
+        
AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], + srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = + i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(0); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], + intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); + uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(3); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], + srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], + srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams 
tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], + srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = + i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], + intriParams); + } + } + } +}; + +template +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], + srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, RowMajor in and zN out. 
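+/// Generic-element variant of the half specialization above. As a worked
+/// example (illustrative numbers only): for float, ELE_NUM_PER_C0 is
+/// 32 / sizeof(float) = 8, so dstNzC0Stride and dstNzNStride are the zN
+/// strides divided by 8; a GM tile whose row stride fits below STRIDE_LIMIT
+/// is issued as one Nd2Nz transfer, otherwise it is copied row by row.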
+template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], + srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } + + // layoutSrc must be the layout of one of the src matrices + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, + uint32_t ndNum, uint32_t srcNdMatrixStride, + uint32_t dstNzNStride, uint32_t dstNzMatrixStride, + uint32_t dstNzC0Stride) { + AscendC::Nd2NzParams intriParams; + + intriParams.nValue = layoutSrc.shape(0); + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = dstNzNStride; + intriParams.dstNzC0Stride = dstNzC0Stride; + if (srcNdMatrixStride < STRIDE_LIMIT) { + intriParams.ndNum = ndNum; + intriParams.srcNdMatrixStride = srcNdMatrixStride; + intriParams.dstNzMatrixStride = dstNzMatrixStride; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.ndNum = 1; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzMatrixStride = 0; + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], + srcTensor[i * srcNdMatrixStride], intriParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. 
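+/// Mirror image of the RowMajor-to-zN copy above: rows and columns swap
+/// roles, so nValue and srcDValue are taken from the column dimension of the
+/// column-major source, and the destination strides come from the nZ layout.
+/// The STRIDE_LIMIT fallback likewise issues one transfer per source column.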
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], + srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } + } +}; + +/// Partial specialization for zN in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(3); + uint64_t srcOffset = i * layoutSrc.stride(3); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], + repeatParams); + } + } + } +}; + +/// Partial specialization for nZ in and nZ out. 
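+/// As in the zN-to-zN case above, no format conversion is needed when GM
+/// already holds fractal data: a plain DataCopy with DataCopyParams replaces
+/// Nd2Nz. Following the usual DataCopyParams convention, blockLen and the
+/// src/dst strides are expressed in 32-byte blocks (hence the divisions by
+/// ELE_NUM_PER_C0), and blockCount walks the fractal columns (zN) or rows
+/// (nZ).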
+template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(1); + uint64_t srcOffset = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], + repeatParams); + } + } + } +}; + +/// Partial specialization for AtlasA2, PaddingRowMajor in and zN out. +template +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::PaddingRowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. +template +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::PaddingColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(1); + intriParams.srcDValue = layoutSrc.stride(2); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } +}; + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
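+/// This row-major-to-row-major copy picks one of three paths: fully
+/// contiguous tiles collapse into a single linear DataCopy of rows * cols
+/// elements; strided tiles within hardware limits use DataCopyParams with at
+/// most MAX_REPEAT (4095) rows per instruction, so, as an illustrative
+/// example, 10000 rows issue as 4095 + 4095 + 1810; anything else falls back
+/// to one DataCopy per row.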
+template +struct CopyGmToL1< + Arch::AtlasA2, Gemm::GemmType, + Gemm::GemmType> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = + (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = + (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && + (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && + (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = + (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, + srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], + dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], + srcTensor[i * layoutSrc.stride(0)], cols); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyGmToL1, AtlasA2, RowMajor in and zN out. 
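+/// The TileCopyTla variants below express the same GM-to-L1 copies in terms
+/// of tla::Tensor objects: shapes and strides are read from the tensors via
+/// get<>, and the enable_if_t constraints dispatch on layout traits
+/// (row-major source, zN destination, and so on) instead of explicit layout
+/// tags.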
+template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::GM>, + Tensor, LayoutDst_, + AscendC::TPosition::A1>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t nValue = get<0>(srcTensor.shape()); + const uint32_t dValue = get<1>(srcTensor.shape()); + const uint32_t srcDValue = get<0>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<0, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<1, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], + srcTensor.data()[i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for CopyGmToL1, AtlasA2, ColumnMajor in and nZ out. +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::GM>, + Tensor, LayoutDst_, + AscendC::TPosition::A1>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t nValue = get<1>(srcTensor.shape()); + const uint32_t dValue = get<0>(srcTensor.shape()); + const uint32_t srcDValue = get<1>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<1, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<0, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], + srcTensor.data()[i * srcDValue], intriParams); + } + } + } +}; + +/// Partial specialization for CopyGmToL1, AtlasA2, PaddingRowMajor in and zN +/// out. 
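+// ---------------------------------------------------------------------------
+// Illustrative note: both TileCopyTla specializations above fall back to a
+// slice-by-slice loop once srcDValue reaches STRIDE_LIMIT (65536 elements),
+// because Nd2NzParams::srcDValue cannot describe a larger source stride. For
+// example, a source whose outer stride is 70000 elements is staged as
+// nValue = 1 slices, one DataCopy per fractal row, at the cost of issuing
+// nValue separate MTE2 transfers; 70000 is an assumed value used only to
+// show when the slow path triggers.
+// ---------------------------------------------------------------------------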
+template +struct TileCopyTlaExt, LayoutSrc_, + AscendC::TPosition::GM>, + Tensor, LayoutDst_, + AscendC::TPosition::A1>, + layout::PaddingRowMajor, layout::zN> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<1>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<0>(srcTensor.orgShape()); + intriParams.srcDValue = get<0, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } +}; + +/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, +/// PaddingColumnMajor in and nZ out. +template +struct TileCopyTlaExt, LayoutSrc_, + AscendC::TPosition::GM>, + Tensor, LayoutDst_, + AscendC::TPosition::A1>, + layout::PaddingColumnMajor, layout::nZ> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<0>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<1>(srcTensor.orgShape()); + intriParams.srcDValue = get<1, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_GM_TO_L1_HPP diff --git a/act/gemm/tile/copy_gm_to_ub.hpp b/act/gemm/tile/copy_gm_to_ub.hpp new file mode 100644 index 00000000..6e690d94 --- /dev/null +++ b/act/gemm/tile/copy_gm_to_ub.hpp @@ -0,0 +1,64 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_GEMM_TILE_COPY_GM_TO_UB_HPP +#define ACT_GEMM_TILE_COPY_GM_TO_UB_HPP + +#include "../../../act/act.hpp" +#include "../../../tla/tensor.hpp" + +namespace Act::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::GM>, + Tensor, LayoutDst_, + AscendC::TPosition::VECCALC>, + std::enable_if_t::value && + tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::VECCALC>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(srcTensor.shape()), + get<1>(srcTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) * + sizeof(ElementSrc), + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) / + ELE_NUM_PER_BLK, + 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams, + padParams); + }; +}; + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_GM_TO_UB_HPP diff --git a/act/gemm/tile/copy_l0c_to_gm.hpp b/act/gemm/tile/copy_l0c_to_gm.hpp new file mode 100644 index 00000000..534af09a --- /dev/null +++ b/act/gemm/tile/copy_l0c_to_gm.hpp @@ -0,0 +1,257 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */
+
+#ifndef ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP
+#define ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP
+
+#include "../../../act/gemm/gemm_type.hpp"
+
+namespace Act::Gemm::Tile {
+
+enum class ScaleGranularity {
+  UNDEFINED = -1,
+  NO_QUANT = 0,
+  PER_TENSOR,
+  PER_CHANNEL,
+  PER_GROUP
+};
+
+template
+struct CopyL0CToGmQuantMode {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported copy l0c to gm, cannot find the specialization.");
+};
+
+// CopyL0CToGm cast fp32 to fp16
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::F322F16;
+};
+
+// CopyL0CToGm cast fp32 to bf16
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::F322BF16;
+};
+
+// CopyL0CToGm output fp32
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::NoQuant;
+};
+
+// CopyL0CToGm output int32
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::NoQuant;
+};
+
+// CopyL0CToGm cast int32_t to fp16
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::DEQF16;
+};
+
+template <>
+struct CopyL0CToGmQuantMode {
+  static constexpr auto VALUE = QuantMode_t::VDEQF16;
+};
+
+template
+struct CopyL0CToGm {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported copy l0c to gm, cannot find the specialization.");
+};
+
+template
+struct CopyL0CToGm,
+                   ScaleGranularity::NO_QUANT, ReluEnable_> {
+  using ArchTag = Act::Arch::AtlasA2;
+  using ElementDst = ElementDst_;
+  using ElementSrc = ElementAccumulator_;
+  using LayoutSrc = Act::layout::zN;
+  using LayoutDst = Act::layout::RowMajor;
+  static constexpr auto quantPre =
+      CopyL0CToGmQuantMode::VALUE;
+  static constexpr auto reluEn = ReluEnable_;
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor const &dst,
+                  AscendC::LocalTensor const &src,
+                  LayoutDst const &dstLayout, LayoutSrc const &srcLayout,
+                  uint8_t unitFlag = 0) {
+    AscendC::FixpipeParamsV220 intriParams;
+
+    // Fixpipe layout information
+    intriParams.nSize = dstLayout.shape(1);
+    intriParams.mSize = dstLayout.shape(0);
+    intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0);
+    intriParams.dstStride = dstLayout.stride(0);
+
+    // Fixpipe auxiliary arguments
+    intriParams.quantPre = quantPre;
+    intriParams.reluEn = reluEn;
+    intriParams.unitFlag = unitFlag;
+
+    // Call AscendC Fixpipe
+    AscendC::Fixpipe(
+        dst, src, intriParams);
+  }
+};
+
+template
+struct CopyL0CToGm,
+                   ScaleGranularity::NO_QUANT, ReluEnable_> {
+  using ArchTag = Act::Arch::AtlasA2;
+  using ElementDst = ElementDst_;
+  using ElementSrc = ElementAccumulator_;
+  using LayoutSrc = Act::layout::zN;
+  using LayoutDst = Act::layout::ColumnMajor;
+  static constexpr auto quantPre =
+      CopyL0CToGmQuantMode::VALUE;
+  static constexpr auto reluEn = ReluEnable_;
+
+  static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementDst);
+
+  ACT_DEVICE
+  CopyL0CToGm() {}
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor dstTensor,
+                  AscendC::LocalTensor srcTensor,
+                  LayoutDst const &dstLayout, LayoutSrc const &srcLayout,
+                  uint8_t unitFlag = 0) {
+    AscendC::DataCopyCO12DstParams params;
+
+    params.nSize = dstLayout.shape(0);
+    params.mSize = dstLayout.shape(1);
+    params.dstStride = dstLayout.stride(1);
+    params.srcStride = srcLayout.shape(2) * srcLayout.shape(3);
+    params.quantPre = quantPre;
+    params.reluPre = 0;
+    params.channelSplit = false;
+    params.nz2ndEn = true;
+    AscendC::DataCopy(dstTensor, srcTensor, params);
+  }
+};
+
+template
+struct CopyL0CToGm,
+                   ScaleGranularity::NO_QUANT, ReluEnable_> {
+  using ArchTag = Act::Arch::AtlasA2;
+  using ElementDst = ElementDst_;
+  using ElementSrc = ElementAccumulator_;
+  using LayoutSrc = Act::layout::zN;
+  using LayoutDst = Act::layout::zN;
+  static constexpr auto quantPre =
+      CopyL0CToGmQuantMode::VALUE;
+  static constexpr auto reluEn = ReluEnable_;
+
+  ACT_DEVICE
+  void operator()(AscendC::GlobalTensor const &dst,
+                  AscendC::LocalTensor const &src,
+                  LayoutDst const &dstLayout, LayoutSrc const &srcLayout,
+                  uint8_t unitFlag = 0) {
+    AscendC::FixpipeParamsV220 intriParams;
+
+    // Fixpipe layout information
+    intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3);
+    intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1);
+    intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2);
+    intriParams.dstStride =
+        dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst));
+
+    // Fixpipe auxiliary arguments
+    intriParams.quantPre = quantPre;
+    intriParams.reluEn = reluEn;
+    intriParams.unitFlag = unitFlag;
+
+    // Call AscendC Fixpipe
+    AscendC::Fixpipe(dst, src,
+                     intriParams);
+  }
+};
+
+///////////////////////////////////////////CopyL0CToGmTla/////////////////////////////////////////////////
+template
+struct CopyL0CToGmTla {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported copy l0c to gm, cannot find the specialization.");
+};
+
+template
+struct CopyL0CToGmTla<
+    Act::Arch::AtlasA2, TensorSrc_,
+    Tensor, LayoutDst_,
+    AscendC::TPosition::GM>,
+    ScaleGranularity::NO_QUANT, ReluEnable_,
+    std::enable_if_t::value>> {
+  using ArchTag = Act::Arch::AtlasA2;
+  using TensorDst = Tensor, LayoutDst_,
+                           AscendC::TPosition::GM>;
+  using ElementDst = ElementDst_;
+  using TensorSrc = TensorSrc_;
+  using ElementSrc = typename TensorSrc::Element;
+  static constexpr auto quantPre =
+      CopyL0CToGmQuantMode::VALUE;
+  static constexpr auto reluEn = ReluEnable_;
+
+  ACT_DEVICE
+  void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor,
+                  uint8_t unitFlag = 0) {
+    AscendC::FixpipeParamsV220 intriParams;
+
+    // Fixpipe layout information
+    intriParams.nSize = get<1>(dstTensor.shape());
+    intriParams.mSize = get<0>(dstTensor.shape());
+    intriParams.srcStride =
+        get<1, 1>(srcTensor.stride()) / get<0, 0>(srcTensor.stride());
+    intriParams.dstStride = get<0>(dstTensor.stride());
+
+    // Fixpipe auxiliary arguments
+    intriParams.quantPre = quantPre;
+    intriParams.reluEn = reluEn;
+    intriParams.unitFlag = unitFlag;
+
+    // Call AscendC Fixpipe
+    AscendC::Fixpipe(
+        dstTensor.data(), srcTensor.data(), intriParams);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace Act::Gemm::Tile
+
+#endif // ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP
diff --git a/act/gemm/tile/copy_l1_to_l0a.hpp b/act/gemm/tile/copy_l1_to_l0a.hpp
new file mode 100644
index 00000000..ba88499b
--- /dev/null
+++ b/act/gemm/tile/copy_l1_to_l0a.hpp
@@ -0,0 +1,438 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE.
See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP +#define ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template struct CopyL1ToL0A { + static_assert( + DEPENDENT_FALSE, + "Unsupporteded copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0A, + Act::Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutSrc.stride(3)], + srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = + static_cast(layoutSrc.shape(1) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(3)], + srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A, + Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = 
BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + uint32_t MRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(MRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose( + dstTensor[i * MRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } + } +}; +////////////////////////////////////////// + +/// Partial specialization for zN in and zZ out. +template +struct CopyL1ToL0A< + ArchTag, Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0A< + ArchTag, Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast( + CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); + i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for int8_t, nZ in and zZ out. 
(Transpose A) +template +struct CopyL1ToL0A> { + using Element = int8_t; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = + static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = + CeilDiv(layoutDst.orgShape(1)) - 1; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); + i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], + srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// + +/// Partial specialization for CopyL1ToL0A, AtlasA2, zN in and zZ out. +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::A1>, + Tensor, LayoutDst_, + AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, nZ in and zZ out. 
+/// (Transpose A) +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::A1>, + Tensor, LayoutDst_, + AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0A, AtlasA2, int8_t, nZ in and zZ out. +/// (Transpose A) +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = + Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = + Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterShapeRow = get<0, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = dstOuterShapeCol - 1; + + for (uint32_t i = 0; i < srcOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose( + dstTensor.data()[i * dstOuterStrideRow * 2], + srcTensor.data()[i * srcOuterStrideRow], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP diff --git a/act/gemm/tile/copy_l1_to_l0b.hpp b/act/gemm/tile/copy_l1_to_l0b.hpp new file mode 100644 index 00000000..f9778dcf --- /dev/null +++ b/act/gemm/tile/copy_l1_to_l0b.hpp @@ -0,0 +1,598 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. 
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP +#define ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP + +#include "../../../act/act.hpp" +#include "../../../act/gemm/gemm_type.hpp" +#include "../../../act/layout/layout.hpp" +#include "../../../tla/tensor.hpp" + +using namespace tla; + +namespace Act::Gemm::Tile { + +template struct CopyL1ToL0B { + static_assert( + DEPENDENT_FALSE, + "Unsupporteded copy l1 to l0, can not find the specialization."); +}; + +//////////////////////////////////////// +/// new add gemm +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadData(dstTensor[i * layoutSrc.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = + static_cast(layoutSrc.shape(3) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(1)], + srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + uint32_t NRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams loadDataParams; + + 
loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(NRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose( + dstTensor[i * NRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], + srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, + AscendC::LocalTensor srcTensor, LayoutDst layoutDst, + LayoutSrc layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.shape(3); + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutSrc.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], + srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } + } +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + }; +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + using Element = float; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / 
sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast( + CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = + CeilDiv(layoutDst.orgShape(0)) - 1; + + for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); + i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], + srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + }; +}; + +template +struct CopyL1ToL0B, + Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + using Element = int8_t; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = + static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); + i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], + srcTensor[i * layoutSrc.stride(3) * 2], + loadDataParams); + } + } +}; +//////////////////////////////////////////// + +/// Partial specialization for int8_t, zN in and nZ out. +template +struct CopyL1ToL0B> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = + static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); + i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } + } +}; + +/// Partial specialization for zN in and nZ out. 
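+// ---------------------------------------------------------------------------
+// Illustrative note (a reading of the code above, not text from the patch):
+// the int8_t transpose paths advance in pairs of fractals. A 512-byte
+// fractal holds 16 x 32 int8 values and the hardware transpose emits 32 x 16
+// tiles, so a full 32 x 32 element block (two fractals) is the smallest
+// self-contained unit; this is why the source stride is divided by 2 and the
+// source offsets are scaled by 2 relative to the non-transposing copies.
+// ---------------------------------------------------------------------------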
+template +struct CopyL1ToL0B< + ArchTag, Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = + static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); + i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } +}; + +/// Partial specialization for nZ in and nZ out. (Transpose B) +template +struct CopyL1ToL0B< + ArchTag, Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, + AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { + AscendC::LoadData2DParams loadDataParams; + if (layoutSrc.shape(3) == layoutDst.shape(3)) { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = + static_cast(layoutDst.shape(1) * layoutDst.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } else { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], + srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } + } + } +}; + +///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// +/// Partial specialization for CopyL1ToL0B, AtlasA2, zN in and nZ out. 
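+// ---------------------------------------------------------------------------
+// Usage sketch (illustrative only; the exact parameter list of CopyL1ToL0B
+// does not survive in this text, so the instantiation below is an assumption
+// based on how the specializations are declared):
+//
+//   using CopyB = Act::Gemm::Tile::CopyL1ToL0B<ArchTag, L1BType, L0BType>;
+//   CopyB copyL1ToL0B;
+//   copyL1ToL0B(l0BTensor, l1BTensor, layoutL0B, layoutL1B);
+//
+// where layoutL1B / layoutL0B are the fractal layouts of the L1 and L0B
+// tiles, passed destination-first to match operator() above.
+// ---------------------------------------------------------------------------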
+template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::A1>, + Tensor, LayoutDst_, + AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, nZ in and nZ out. +/// (Transpose B) +template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::A1>, + Tensor, LayoutDst_, + AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } + } +}; + +/// Partial specialization for CopyL1ToL0B, AtlasA2, int8_t, zN in and nZ out. 
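+// ---------------------------------------------------------------------------
+// Illustrative note: in the two TileCopyTla specializations above the only
+// functional difference is LoadData2DParams::ifTranspose. Loading zN-laid-out
+// L1 data into the nZ arrangement required by L0B transposes each C0 fractal
+// on the fly (ifTranspose = true), whereas nZ -> nZ is a plain
+// fractal-by-fractal move (ifTranspose = false). This summary is an
+// interpretation of the code, not a statement from the original patch.
+// ---------------------------------------------------------------------------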
+template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = + Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = + Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + const uint32_t srcOuterShapeCol = get<1, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = srcOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose( + dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow * 2], loadDataParams); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP diff --git a/act/gemm/tile/copy_ub_to_gm.hpp b/act/gemm/tile/copy_ub_to_gm.hpp new file mode 100644 index 00000000..2ab44a84 --- /dev/null +++ b/act/gemm/tile/copy_ub_to_gm.hpp @@ -0,0 +1,99 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_GEMM_TILE_COPY_UB_TO_GM_HPP +#define ACT_GEMM_TILE_COPY_UB_TO_GM_HPP + +#include "../../../act/act.hpp" +#include "../../../tla/tensor.hpp" + +namespace Act::Gemm::Tile { + +/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
+template +struct TileCopyTla< + Arch::AtlasA2, + Tensor, LayoutSrc_, + AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, + AscendC::TPosition::GM>, + std::enable_if_t::value && + tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::VECCALC>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(dstTensor.shape()), + get<1>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / + ELE_NUM_PER_C0, + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) * + sizeof(ElementSrc), + 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; +}; + +/// Partial specialization for AtlasA2, RowMajor in and PaddingRowMajor out. +template +struct TileCopyTlaExt, LayoutSrc_, + AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, + AscendC::TPosition::GM>, + layout::RowMajor, layout::PaddingRowMajor> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, + AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, + AscendC::TPosition::VECCALC>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { + AscendC::DataCopyExtParams dataCopyParams( + get<1, 1>(dstTensor.shape()), + get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / + ELE_NUM_PER_C0, + (get<1, 1>(dstTensor.stride()) - get<1, 0>(dstTensor.shape())) * + sizeof(ElementSrc), + 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; +}; + +} // namespace Act::Gemm::Tile + +#endif // ACT_GEMM_TILE_COPY_UB_TO_GM_HPP diff --git a/act/gemm/tile/tile_copy.hpp b/act/gemm/tile/tile_copy.hpp new file mode 100644 index 00000000..c9b9b69f --- /dev/null +++ b/act/gemm/tile/tile_copy.hpp @@ -0,0 +1,214 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */
+
+#ifndef ACT_GEMM_TILE_TILE_COPY_HPP
+#define ACT_GEMM_TILE_TILE_COPY_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/detail/tag_to_layout.hpp"
+
+namespace Act::Gemm::Tile {
+
+template
+struct TileCopyTla {
+  static_assert(DEPENDENT_FALSE,
+                "Unsupported TileCopyTla, cannot find the specialization.");
+};
+
+template
+struct TileCopyTlaExt {
+  static_assert(
+      DEPENDENT_FALSE,
+      "Unsupported TileCopyTlaExt, cannot find the specialization.");
+};
+} // namespace Act::Gemm::Tile
+
+#include "../../../act/gemm/helper.hpp"
+#include "../../../act/gemm/tile/copy_gm_to_l1.hpp"
+#include "../../../act/gemm/tile/copy_gm_to_ub.hpp"
+#include "../../../act/gemm/tile/copy_l0c_to_gm.hpp"
+#include "../../../act/gemm/tile/copy_l1_to_l0a.hpp"
+#include "../../../act/gemm/tile/copy_l1_to_l0b.hpp"
+#include "../../../act/gemm/tile/copy_ub_to_gm.hpp"
+
+namespace Act::Gemm::Tile {
+
+template <
+    /// Tag indicating architecture
+    class ArchTag,
+    /// GemmType for A matrix operand
+    class AType,
+    /// GemmType type for B matrix operand
+    class BType,
+    /// GemmType type for C matrix operand
+    class CType,
+    /// GemmType type for Bias operand
+    class BiasType = void>
+struct TileCopy {
+  using ElementA = typename AType::Element;
+  using ElementB = typename BType::Element;
+  using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<
+      ElementA, ElementB>::ElementAccumulator;
+
+  using CopyGmToL1A = Gemm::Tile::CopyGmToL1;
+  using CopyGmToL1B = Gemm::Tile::CopyGmToL1;
+  using CopyL1ToL0A =
+      Gemm::Tile::CopyL1ToL0A::L1AType>;
+  using CopyL1ToL0B =
+      Gemm::Tile::CopyL1ToL0B::L1BType>;
+  using CopyL0CToGm =
+      Gemm::Tile::CopyL0CToGm;
+};
+
+/// new add
+template <
+    /// Tag indicating architecture
+    class ArchTag,
+    /// GemmType for A matrix operand
+    class AType,
+    /// GemmType type for B matrix operand
+    class BType,
+    /// GemmType type for C matrix operand
+    class CType,
+    /// GemmType type for Bias operand
+    class BiasType = void>
+struct TileCopyGemm {
+  using ElementA = typename AType::Element;
+  using ElementB = typename BType::Element;
+  using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<
+      ElementA, ElementB>::ElementAccumulator;
+  // change structural
+  using L1AType = typename helper::L1ATypeSelectorGemm::L1AType;
+  using L1BType = typename helper::L1BTypeSelectorGemm::L1BType;
+  using L0AType = typename helper::L0ATypeSelector::L0AType;
+  using L0BType = typename helper::L0BTypeSelectorGemm::L0BType;
+
+  using CopyGmToL1A = Gemm::Tile::CopyGmToL1;
+  using CopyGmToL1B = Gemm::Tile::CopyGmToL1;
+  using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A;
+  using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B;
+  using CopyL0CToGm =
+      Gemm::Tile::CopyL0CToGm;
+};
+
+template <
+    /// Tag indicating architecture
+    class ArchTag, class TensorA, class LayoutTagA, class TensorB,
+    class LayoutTagB, class TensorC, class LayoutTagC, class TensorBias = void,
+    class LayoutTagBias = void>
+struct PackedTileCopyTla {
+  using ElementA = typename TensorA::Element;
+  using ElementB = typename TensorB::Element;
+  using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<
+      ElementA, ElementB>::ElementAccumulator;
+
+  using LayoutL1A = detail::TagToLayout_t<
+      ElementA, typename helper::L1ATypeSelector<
+                    Gemm::GemmType>::L1AType::Layout>;
+  using LayoutL1B = detail::TagToLayout_t<
+      ElementB, typename helper::L1BTypeSelector<
+                    Gemm::GemmType>::L1BType::Layout>;
+  using LayoutL0A = detail::TagToLayout_t;
+  using LayoutL0B = detail::TagToLayout_t;
+  using LayoutL0C = typename detail::LayoutL0C;
+
+  using TensorL1A =
+      Tensor, LayoutL1A, AscendC::TPosition::A1>;
+  using TensorL1B =
+      Tensor, LayoutL1B, AscendC::TPosition::A1>;
+  using TensorL0A =
+      Tensor, LayoutL0A, AscendC::TPosition::A2>;
+  using TensorL0B =
+      Tensor, LayoutL0B, AscendC::TPosition::B2>;
+  using TensorL0C = Tensor, LayoutL0C,
+                           AscendC::TPosition::CO1>;
+
+  using L1AAlignHelper = Gemm::helper::L1AlignHelper;
+  using L1BAlignHelper = Gemm::helper::L1AlignHelper;
+
+  using CopyGmToL1A = Gemm::Tile::TileCopyTla;
+  using CopyGmToL1B = Gemm::Tile::TileCopyTla;
+  using CopyL1ToL0A = Gemm::Tile::TileCopyTla;
+  using CopyL1ToL0B = Gemm::Tile::TileCopyTla;
+  using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla;
+};
+
+template <
+    /// Tag indicating architecture
+    class ArchTag, class TensorA, class LayoutTagA, class TensorB,
+    class LayoutTagB, class TensorC, class LayoutTagC, class TensorBias = void,
+    class LayoutTagBias = void, bool IS_PADDING_A = false,
+    bool IS_PADDING_B = false>
+struct PaddingPackedTileCopyTla {
+  static_assert(std::is_same_v ||
+                    std::is_same_v,
+                "Unsupported layout, only RowMajor and ColumnMajor are supported");
+  static_assert(std::is_same_v ||
+                    std::is_same_v,
+                "Unsupported layout, only RowMajor and ColumnMajor are supported");
+  using ElementA = typename TensorA::Element;
+  using ElementB = typename TensorB::Element;
+  using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<
+      ElementA, ElementB>::ElementAccumulator;
+
+  using LayoutTagL1A = typename helper::L1ATypeSelector<
+      Gemm::GemmType>::L1AType::Layout;
+  using LayoutTagL1B = typename helper::L1BTypeSelector<
+      Gemm::GemmType>::L1BType::Layout;
+  using LayoutL1A = detail::TagToLayout_t;
+  using LayoutL1B = detail::TagToLayout_t;
+  using LayoutL0A = detail::TagToLayout_t;
+  using LayoutL0B = detail::TagToLayout_t;
+  using LayoutL0C = typename detail::LayoutL0C;
+
+  using TensorL1A =
+      Tensor, LayoutL1A, AscendC::TPosition::A1>;
+  using TensorL1B =
+      Tensor, LayoutL1B, AscendC::TPosition::A1>;
+  using TensorL0A =
+      Tensor, LayoutL0A, AscendC::TPosition::A2>;
+  using TensorL0B =
+      Tensor, LayoutL0B, AscendC::TPosition::B2>;
+  using TensorL0C = Tensor, LayoutL0C,
+                           AscendC::TPosition::CO1>;
+
+  using L1AAlignHelper = Gemm::helper::L1AlignHelper;
+  using L1BAlignHelper = Gemm::helper::L1AlignHelper;
+
+  using LayoutPaddingTagA =
+      std::conditional_t,
+                         layout::PaddingRowMajor, layout::PaddingColumnMajor>;
+  using LayoutPaddingTagB =
+      std::conditional_t,
+                         layout::PaddingRowMajor, layout::PaddingColumnMajor>;
+
+  using CopyGmToL1A = std::conditional_t<
+      IS_PADDING_A,
+      Gemm::Tile::TileCopyTlaExt,
+      Gemm::Tile::TileCopyTla>;
+  using CopyGmToL1B = std::conditional_t<
+      IS_PADDING_B,
+      Gemm::Tile::TileCopyTlaExt,
+      Gemm::Tile::TileCopyTla>;
+
+  using CopyL1ToL0A = Gemm::Tile::TileCopyTla;
+  using CopyL1ToL0B = Gemm::Tile::TileCopyTla;
+  using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla;
+};
+} // namespace Act::Gemm::Tile
+
+#endif // ACT_GEMM_TILE_TILE_COPY_HPP
diff --git a/act/gemm/tile/tile_mmad.hpp b/act/gemm/tile/tile_mmad.hpp
new file mode 100644
index 00000000..44824087
--- /dev/null
+++ b/act/gemm/tile/tile_mmad.hpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License.
diff --git a/act/gemm/tile/tile_mmad.hpp b/act/gemm/tile/tile_mmad.hpp
new file mode 100644
index 00000000..44824087
--- /dev/null
+++ b/act/gemm/tile/tile_mmad.hpp
@@ -0,0 +1,114 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMM_TILE_TILE_MMAD_HPP
+#define ACT_GEMM_TILE_TILE_MMAD_HPP
+
+#include "../../../act/act.hpp"
+#include "../../../act/gemm/helper.hpp"
+namespace Act::Gemm::Tile {
+
+///////////////////////////////////////////////////////////
+
+template <
+    /// Tag indicating architecture
+    class ArchTag_,
+    /// GemmType for A matrix operand
+    class AType_,
+    /// GemmType for B matrix operand
+    class BType_,
+    /// GemmType for Bias operand
+    class BiasType_>
+struct TileMmad {
+  using ElementA = typename AType_::Element;
+  using ElementB = typename BType_::Element;
+  using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector<
+      ElementA, ElementB>::ElementAccumulator;
+
+  // Methods
+
+  ACT_DEVICE
+  TileMmad() {}
+
+  ACT_DEVICE
+  void operator()(AscendC::LocalTensor const &l0CTensor,
+                  AscendC::LocalTensor const &l0ATensor,
+                  AscendC::LocalTensor const &l0BTensor, uint32_t m,
+                  uint32_t n, uint32_t k, bool initC = true,
+                  uint8_t unitFlag = 0) {
+    AscendC::MmadParams mmadParams;
+    mmadParams.m = m;
+    mmadParams.n = n;
+    mmadParams.k = k;
+    mmadParams.unitFlag = unitFlag;
+    mmadParams.cmatrixInitVal = initC;
+
+    AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams);
+
+    const uint32_t PIPE_M_BARRIER_THRESHOLD = 10;
+    if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) <
+        PIPE_M_BARRIER_THRESHOLD) {
+      AscendC::PipeBarrier();
+    }
+  }
+};
+
+///////////////////////////////////////////TileMmadTla/////////////////////////////////////////////////
+
+template <
+    /// Tag indicating architecture
+    class ArchTag_,
+    /// Tensor type for A matrix operand
+    class TensorA,
+    /// Tensor type for B matrix operand
+    class TensorB,
+    /// Tensor type for C matrix operand
+    class TensorC,
+    /// Tensor type for Bias operand
+    class TensorBias = void>
+struct TileMmadTla {
+  // Methods
+
+  ACT_DEVICE
+  TileMmadTla() {}
+
+  ACT_DEVICE
+  void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor,
+                  TensorB const &l0BTensor, bool initC = true,
+                  uint8_t unitFlag = 0) {
+    const uint32_t m = get<0>(l0ATensor.orgShape());
+    const uint32_t n = get<1>(l0BTensor.orgShape());
+    const uint32_t k = get<1>(l0ATensor.orgShape());
+
+    AscendC::MmadParams mmadParams;
+    mmadParams.m = m;
+    mmadParams.n = n;
+    mmadParams.k = k;
+    mmadParams.unitFlag = unitFlag;
+    mmadParams.cmatrixInitVal = initC;
+
+    AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(),
+                  mmadParams);
+
+    const uint32_t PIPE_M_BARRIER_THRESHOLD = 10;
+    if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) <
+        PIPE_M_BARRIER_THRESHOLD) {
+      AscendC::PipeBarrier();
+    }
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace Act::Gemm::Tile
+
+#endif // ACT_GEMM_TILE_TILE_MMAD_HPP
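TileMmad packs m/n/k into AscendC::MmadParams, uses initC to choose between initializing and accumulating into the L0C tile, and inserts a pipe barrier for small tiles (fewer than 10 fractal products) where back-to-back MMAD issues are close enough to need serialization. A hedged call-site sketch (the tensor handles and the kIdx loop variable are assumptions, not defined in this header):

    // Accumulate one k-slice into the C tile; only the first slice clears C.
    TileMmad<ArchTag, AType, BType, void> tileMmad;
    tileMmad(l0CTensor, l0ATensor, l0BTensor, m, n, k,
             /*initC=*/(kIdx == 0));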
diff --git a/act/gemm_coord.hpp b/act/gemm_coord.hpp
new file mode 100644
index 00000000..6eb6f83a
--- /dev/null
+++ b/act/gemm_coord.hpp
@@ -0,0 +1,120 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMM_COORD_HPP
+#define ACT_GEMM_COORD_HPP
+
+#include "../act/coord.hpp"
+
+namespace Act {
+
+/// Shape of a matrix multiply-add operation
+template <
+    /// Rows of matrix product
+    uint32_t M_ = 1,
+    /// Columns of matrix product
+    uint32_t N_ = 1,
+    /// Inner dimension of matrix product
+    uint32_t K_ = 1>
+struct GemmShape {
+  static constexpr uint32_t M = M_;
+  static constexpr uint32_t N = N_;
+  static constexpr uint32_t K = K_;
+
+  static constexpr int64_t MN = M * N;
+  static constexpr int64_t MK = M * K;
+  static constexpr int64_t KN = N * K;
+  static constexpr int64_t MNK = M * N * K;
+
+  static constexpr int64_t COUNT = MNK;
+
+  /// Returns a Coord object
+  ACT_HOST_DEVICE
+  static Coord<3> ToCoord() { return MakeCoord(M, N, K); }
+
+  ACT_HOST_DEVICE
+  static Coord<2> ToCoordMN() { return MakeCoord(M, N); }
+
+  ACT_HOST_DEVICE
+  static Coord<2> ToCoordMK() { return MakeCoord(M, K); }
+
+  ACT_HOST_DEVICE
+  static Coord<2> ToCoordKN() { return MakeCoord(K, N); }
+};
+
+/// GemmCoord is a structure derived from Coord<3> that specifies a location
+/// within the coordinate space of a Gemm problem.
+struct GemmCoord : public Coord<3, uint32_t> {
+  /// Integer-valued index
+  using Index = uint32_t;
+
+  /// Base type is a Coord of rank=3
+  using Base = Coord<3, Index>;
+
+  /// Gemm M dimension - rows of the output C matrix
+  static constexpr int M_INDEX = 0;
+
+  /// Gemm N dimension - columns of the output C matrix
+  static constexpr int N_INDEX = 1;
+
+  /// Gemm K dimension - inner dimension of the Gemm problem
+  static constexpr int K_INDEX = 2;
+
+  /// Default ctor
+  ACT_HOST_DEVICE
+  GemmCoord() {}
+
+  /// Constructs from a Coord<3>
+  ACT_HOST_DEVICE
+  GemmCoord(Coord<3, Index> const &coord) : Base(coord) {}
+
+  /// Helper to construct from M, N, and K indices
+  ACT_HOST_DEVICE
+  GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {}
+
+  /// Returns the Gemm M coordinate
+  ACT_HOST_DEVICE
+  Index const &m() const { return this->At(M_INDEX); }
+
+  /// Returns reference to the Gemm M coordinate
+  ACT_HOST_DEVICE
+  Index &m() { return this->At(M_INDEX); }
+
+  /// Returns the Gemm N coordinate
+  ACT_HOST_DEVICE
+  Index const &n() const { return this->At(N_INDEX); }
+
+  /// Returns reference to the Gemm N coordinate
+  ACT_HOST_DEVICE
+  Index &n() { return this->At(N_INDEX); }
+
+  /// Returns the Gemm K coordinate
+  ACT_HOST_DEVICE
+  Index const &k() const { return this->At(K_INDEX); }
+
+  /// Returns reference to the Gemm K coordinate
+  ACT_HOST_DEVICE
+  Index &k() { return this->At(K_INDEX); }
+
+  ACT_HOST_DEVICE
+  auto GetCoordMN() const { return this->GetCoordByAxis(); }
+
+  ACT_HOST_DEVICE
+  auto GetCoordMK() const { return this->GetCoordByAxis(); }
+
+  ACT_HOST_DEVICE
+  auto GetCoordKN() const { return this->GetCoordByAxis(); }
+};
+
+} // namespace Act
+
+#endif // ACT_GEMM_COORD_HPP
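A quick sketch of how the two types interact (illustrative; the block shape is arbitrary, and CeilDiv is assumed to be the two-argument rounding helper from act/detail/alignment.hpp):

    using BlockShape = GemmShape<128, 256, 64>;  // compile-time tile: M, N, K
    GemmCoord problem(1024, 1024, 512);          // runtime problem size
    uint32_t blocksM = CeilDiv(problem.m(), BlockShape::M);  // 8 blocks along M
    uint32_t blocksN = CeilDiv(problem.n(), BlockShape::N);  // 4 blocks along N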
diff --git a/act/gemv_coord.hpp b/act/gemv_coord.hpp
new file mode 100644
index 00000000..08af1180
--- /dev/null
+++ b/act/gemv_coord.hpp
@@ -0,0 +1,89 @@
+/*
+ * Copyright (c) 2024 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */
+
+#ifndef ACT_GEMV_COORD_HPP
+#define ACT_GEMV_COORD_HPP
+
+#include "../act/coord.hpp"
+
+namespace Act {
+
+/// Shape of a matrix-vector multiply operation
+template <
+    /// Rows of matrix product
+    uint32_t M_ = 1,
+    /// Columns of the matrix (number of elements in the input vector)
+    uint32_t N_ = 1>
+struct GemvShape {
+  static constexpr uint32_t M = M_;
+  static constexpr uint32_t N = N_;
+
+  static constexpr int64_t MN = M * N;
+
+  static constexpr int64_t COUNT = MN;
+
+  /// Returns a Coord object
+  ACT_HOST_DEVICE
+  static Coord<2> ToCoord() { return MakeCoord(M, N); }
+};
+
+/// GemvCoord is a structure derived from Coord<2> that specifies a location
+/// within the coordinate space of a GEMV problem.
+struct GemvCoord : public Coord<2, uint32_t> {
+  /// Integer-valued index
+  using Index = uint32_t;
+
+  /// Base type is a Coord of rank=2
+  using Base = Coord<2, Index>;
+
+  /// GEMV M dimension - rows of the output vector (y)
+  static constexpr int M_INDEX = 0;
+
+  /// GEMV N dimension - columns of the matrix (length of the input vector x)
+  static constexpr int N_INDEX = 1;
+
+  /// Default ctor
+  ACT_HOST_DEVICE
+  GemvCoord() {}
+
+  /// Constructs from a Coord<2>
+  ACT_HOST_DEVICE
+  GemvCoord(Coord<2, Index> const &coord) : Base(coord) {}
+
+  /// Helper to construct from M, N coordinates
+  ACT_HOST_DEVICE
+  GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {}
+
+  /// Returns the GEMV M coordinate (row of the result y)
+  ACT_HOST_DEVICE
+  Index const &m() const { return this->At(M_INDEX); }
+
+  /// Returns reference to the GEMV M coordinate
+  ACT_HOST_DEVICE
+  Index &m() { return this->At(M_INDEX); }
+
+  /// Returns the GEMV N coordinate (column of the matrix A or the input vector
+  /// x)
+  ACT_HOST_DEVICE
+  Index const &n() const { return this->At(N_INDEX); }
+
+  /// Returns reference to the GEMV N coordinate
+  ACT_HOST_DEVICE
+  Index &n() { return this->At(N_INDEX); }
+
+  ACT_HOST_DEVICE
+  auto GetCoordMN() const { return this->GetCoordByAxis(); }
+};
+
+} // namespace Act
+
+#endif // ACT_GEMV_COORD_HPP
diff --git a/act/layout/layout.hpp b/act/layout/layout.hpp
new file mode 100644
index 00000000..981f0d33
--- /dev/null
+++ b/act/layout/layout.hpp
@@ -0,0 +1,20 @@
+/*
+ * Copyright (c) 2025 Huawei Technologies Co., Ltd.
+ * This file is a part of the CANN Open Software.
+ * Licensed under CANN Open Software License Agreement Version 1.0 (the
+ * "License"). Please refer to the License for details. You may not use this
+ * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
+ * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
+ * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
+ * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
+ * for the full text of the License.
+ */ + +#ifndef ACT_LAYOUT_LAYOUT_HPP +#define ACT_LAYOUT_LAYOUT_HPP + +#include "../../act/act.hpp" +#include "../../act/layout/matrix.hpp" +#include "../../act/layout/vector.hpp" + +#endif // ACT_LAYOUT_LAYOUT_HPP diff --git a/act/layout/matrix.hpp b/act/layout/matrix.hpp new file mode 100644 index 00000000..2035ca9a --- /dev/null +++ b/act/layout/matrix.hpp @@ -0,0 +1,982 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_LAYOUT_MATRIX_HPP +#define ACT_LAYOUT_MATRIX_HPP + +#include "../../act/act.hpp" +#include "../../act/coord.hpp" +#include "../../act/detail/alignment.hpp" +#include "../../act/matrix_coord.hpp" + +namespace Act::layout { + +/// Mapping function for row-major matrices +struct RowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), + stride_(MakeCoord(LongIndex(cols), LongIndex(1))) {} + + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) {} + + /// Ctor + ACT_HOST_DEVICE + RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) { + return RowMajor(shape.row(), shape.column(), + RoundUp(shape.column())); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); + } + + /// Returns the layout of a tile. 
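+  /// (The tile reuses this layout's strides and only shrinks the shape: a
+  /// 32x32 tile of RowMajor(128, 256) keeps ldm == 256, so on the tile
+  /// GetOffset({1, 0}) is 256 rather than 32.)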
+ ACT_HOST_DEVICE + RowMajor GetTileLayout(MatrixCoord const &tileShape) const { + return RowMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for col-major matrices +struct ColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), + stride_(MakeCoord(LongIndex(1), LongIndex(rows))) {} + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) {} + + /// Ctor + ACT_HOST_DEVICE + ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; + } + + /// Returns the layout of a tile. 
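+  /// (Mirrors RowMajor::GetTileLayout: the tile keeps the parent's strides,
+  /// e.g. a tile of ColumnMajor(128, 256) keeps stride (1, 128), so
+  /// GetOffset({0, 1}) on the tile is still 128.)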
+ ACT_HOST_DEVICE + ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const { + return ColumnMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + // + // Data members + // + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for nZ matrices which is col-major inside fractal and +/// row-major between fractal +struct nZ { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr nZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = + 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = + 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = + 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = + 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, + colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, + strideColsInFractal, strideColsByFractal)) {} + + /// Ctor + ACT_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nZ(orgRows, orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, + C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, 1, + colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate 
in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + nZ GetTileLayout(MatrixCoord const &tileOriShape) const { + auto tileShape = + MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return nZ(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for zN matrices which is row-major inside fractal and +/// col-major between fractal +struct zN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zN( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = + 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = + 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = + 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = + 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, + colsByFractal)), + 
stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal,
+                          strideColsInFractal, strideColsByFractal)) {}
+
+  /// Ctor
+  ACT_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride)
+      : orgShape_(orgShape), shape_(shape), stride_(stride) {}
+
+  /// Make the layout of a coordinate (row, column)
+  template
+  ACT_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols) {
+    constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
+    constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element);
+    Index rowsRound = RoundUp(orgRows);
+    Index colsRound = RoundUp(orgCols);
+    return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL,
+              rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0,
+              colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL,
+              1, rowsRound * ELE_NUM_PER_C0);
+  }
+
+  ACT_HOST_DEVICE
+  static zN MakeLayoutInL0C(MatrixCoord const &shape) {
+    return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL,
+              CeilDiv(shape.row()), C0_NUM_PER_FRACTAL,
+              CeilDiv(shape.column()), C0_NUM_PER_FRACTAL,
+              C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1,
+              RoundUp(shape.row()) * C0_NUM_PER_FRACTAL);
+  }
+
+  /// Returns the offset of a coordinate in linear memory.
+  /// Assumes coordinate has convention (row, column)
+  ACT_HOST_DEVICE
+  LongIndex GetOffset(MatrixCoord const &coord) const {
+    return LongIndex(coord.row()) / shape_[0] * stride_[1] +
+           LongIndex(coord.column()) / shape_[2] * stride_[3] +
+           (LongIndex(coord.row()) % shape_[0]) * stride_[0] +
+           (LongIndex(coord.column()) % shape_[2]) * stride_[2];
+  }
+
+  /// Returns the layout of a tile.
+  ACT_HOST_DEVICE
+  zN GetTileLayout(MatrixCoord const &tileOriShape) const {
+    auto tileShape =
+        MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2),
+                  CeilDiv(tileOriShape.column(), shape(2)));
+    return zN(tileOriShape, tileShape, stride());
+  }
+
+  /// Returns the origin shape of the layout
+  ACT_HOST_DEVICE
+  typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; }
+
+  /// Returns the origin shape of the layout
+  ACT_HOST_DEVICE
+  typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; }
+
+  /// Returns the shape of the layout
+  ACT_HOST_DEVICE
+  Shape shape() const { return shape_; }
+
+  /// Returns the shape of the layout
+  ACT_HOST_DEVICE
+  Shape &shape() { return shape_; }
+
+  /// Returns the shape of the layout
+  ACT_HOST_DEVICE
+  typename Shape::Index shape(int idx) const { return shape_[idx]; }
+
+  /// Returns the shape of the layout
+  ACT_HOST_DEVICE
+  typename Shape::Index &shape(int idx) { return shape_[idx]; }
+
+  /// Returns the stride of the layout
+  ACT_HOST_DEVICE
+  Stride stride() const { return stride_; }
+
+  /// Returns the stride of the layout
+  ACT_HOST_DEVICE
+  Stride &stride() { return stride_; }
+
+  /// Returns the stride of the layout
+  ACT_HOST_DEVICE
+  typename Stride::Index stride(int idx) const { return stride_[idx]; }
+
+  /// Returns the stride of the layout
+  ACT_HOST_DEVICE
+  typename Stride::Index &stride(int idx) { return stride_[idx]; }
+
+private:
+  /// Origin Shape data member
+  OrgShape orgShape_;
+
+  /// Shape data member
+  Shape shape_;
+
+  /// Stride data member
+  Stride stride_;
+};
+
+/// Mapping function for zZ matrices, which are row-major inside each fractal
+/// and row-major between fractals
+struct zZ {
+public:
+  /// Logical rank of tensor
+  static constexpr int RANK = 4;
+
+  /// Index type used for coordinates
+  using Index = uint32_t;
+
+  /// Long index type used for offsets
+  using LongIndex = int64_t;
+
+  ///
Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = + 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = + 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = + 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = + 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, + colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, + strideColsInFractal, strideColsByFractal)) {} + + /// Ctor + ACT_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, + rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, + colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. 
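+  /// (zZ resolves the offset at fractal granularity only: assuming half
+  /// elements so ELE_NUM_PER_C0 == 16, MakeLayout(64, 64) yields strides
+  /// (16, 1024, 1, 256), and GetOffset({17, 33}) = 17 / 16 * 1024 +
+  /// 33 / 16 * 256 = 1536.)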
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding rowmajor matrices +/// A special data layout designed to improve the efficiency of matrix +/// operations in non-512B aligned scenarios. This layout is row-major within +/// blocks and also row-major between blocks. +struct PaddingRowMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + PaddingRowMajor(Index orgRows, Index orgCols, Index blockRows, + Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, + CeilDiv(orgCols, blockCols))), + stride_(MakeCoord( + (LongIndex)blockCols, + (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), + (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols)) {} + + /// Returns the offset of a coordinate in linear memory. 
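+  /// (Worked example: PaddingRowMajor(100, 100, 32, 32) has strides
+  /// (32, 4096, 1, 1024), so GetOffset({37, 70}) = 1 * 4096 + 2 * 1024 +
+  /// 5 * 32 + 6 = 6310.)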
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows * stride_[0] + + (LongIndex)coord.column() % blockCols; + } + + ACT_HOST_DEVICE + PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const { + return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], + shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/// Mapping function for padding columnmajor matrices +/// A special data layout designed to improve the efficiency of matrix +/// operations in non-512B aligned scenarios. This layout is column-major within +/// blocks and also column-major between blocks. +struct PaddingColumnMajor { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + /// Constructor + ACT_HOST_DEVICE + PaddingColumnMajor(Index orgRows, Index orgCols, Index blockRows, + Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, + CeilDiv(orgCols, blockCols))), + stride_(MakeCoord( + (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, + (LongIndex)blockRows, + (LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) {} + + /// Returns the offset of a coordinate in linear memory. 
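+  /// (Worked example: PaddingColumnMajor(100, 100, 32, 32) has strides
+  /// (1, 1024, 32, 4096), so GetOffset({37, 70}) = 1 * 1024 + 2 * 4096 +
+  /// 5 + 6 * 32 = 9413.)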
+ /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows + + (LongIndex)coord.column() % blockCols * stride_[2]; + } + + ACT_HOST_DEVICE + PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const { + return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], + shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + // + // Data members + // + + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; + +/////////////////////// +// new add layout nN +// nN layout +struct nN { +public: + /// Logical rank of tensor + static constexpr int RANK = 4; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; + + /// Logical coordinate + using OrgShape = Coord; + + /// Logical coordinate + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + +public: + // Methods + + /// Constructor + ACT_HOST_DEVICE + nN(Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideRowsInFractal = + 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = + 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = + 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = + 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, + colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, + strideColsInFractal, 
strideColsByFractal)) {} + + /// Ctor + ACT_HOST_DEVICE + nN(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = + BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nN(orgRows, orgCols, + + ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, + colsRound / C0_NUM_PER_FRACTAL, + + 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + rowsRound * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + /// Origin Shape data member + OrgShape orgShape_; + + /// Shape data member + Shape shape_; + + /// Stride data member + Stride stride_; +}; +} // namespace Act::layout + +#endif // ACT_LAYOUT_MATRIX_HPP diff --git a/act/layout/vector.hpp b/act/layout/vector.hpp new file mode 100644 index 00000000..286d0648 --- /dev/null +++ b/act/layout/vector.hpp @@ -0,0 +1,108 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. 
+ */ + +#ifndef ACT_LAYOUT_VECTOR_HPP +#define ACT_LAYOUT_VECTOR_HPP + +#include "../../act/act.hpp" +#include "../../act/coord.hpp" + +namespace Act::layout { + +struct VectorLayout { +public: + /// Logical rank of tensor + static constexpr int RANK = 1; + + /// Index type used for coordinates + using Index = uint32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Shape vector + using Shape = Coord; + + /// Stride vector + using Stride = Coord; + + /// Logical coordinate + using TensorCoord = Coord; + +public: + // Methods + + ACT_HOST_DEVICE + VectorLayout(Index size = 0) + : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} + + ACT_HOST_DEVICE + VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static VectorLayout + MakeLayoutInUb(TensorCoord const &tileShape) { + return VectorLayout{RoundUp(tileShape[0])}; + } + + ACT_HOST_DEVICE + LongIndex GetOffset(TensorCoord const &coord) const { + return stride_[0] * coord[0]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + VectorLayout GetTileLayout(TensorCoord const &tileShape) const { + return VectorLayout(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() { return shape_; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const { return shape_[idx]; } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) { return shape_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() { return stride_; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const { return stride_[idx]; } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) { return stride_[idx]; } + +private: + /// Stride data member + Shape shape_; + Stride stride_; +}; + +} // namespace Act::layout + +#endif // ACT_LAYOUT_VECTOR_HPP diff --git a/act/matrix_coord.hpp b/act/matrix_coord.hpp new file mode 100644 index 00000000..bad9545b --- /dev/null +++ b/act/matrix_coord.hpp @@ -0,0 +1,98 @@ +/* + * Copyright (c) 2025 Huawei Technologies Co., Ltd. + * This file is a part of the CANN Open Software. + * Licensed under CANN Open Software License Agreement Version 1.0 (the + * "License"). Please refer to the License for details. You may not use this + * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN + * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, + * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS + * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository + * for the full text of the License. + */ + +#ifndef ACT_MATRIX_COORD_HPP +#define ACT_MATRIX_COORD_HPP + +#include "../act/coord.hpp" + +namespace Act { + +template struct MatrixShape { + static constexpr uint32_t ROW = ROW_; + static constexpr uint32_t COLUMN = COLUMN_; + + static constexpr int64_t COUNT = ROW * COLUMN; + + ACT_HOST_DEVICE + static Coord<2> ToCoord() { return MakeCoord(ROW, COLUMN); } +}; + +/// MatrixCoord wraps Coord<2, uint32_t> to provide a helper for accessing named +/// dimensions. 
Classes expecting a coordinate in the rank=2 index space of a +/// matrix should use MatrixCoord. +struct MatrixCoord : public Coord<2, uint32_t> { + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Rows dimension + static constexpr uint32_t ROW_INDEX = 0; + + /// Columns dimension + static constexpr uint32_t COLUMN_INDEX = 1; + + /// Default ctor + ACT_HOST_DEVICE + MatrixCoord() {} + + /// Constructs from Coord<2> + ACT_HOST_DEVICE + MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a row and column + ACT_HOST_DEVICE + MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} + + /// Helper to construct from a row and column, which are LongIndex based + ACT_HOST_DEVICE + MatrixCoord(LongIndex row, LongIndex column) + : Base(MakeCoord(Index(row), Index(column))) {} + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index const &row() const { return this->At(ROW_INDEX); } + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index &row() { return this->At(ROW_INDEX); } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index const &column() const { return this->At(COLUMN_INDEX); } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index &column() { return this->At(COLUMN_INDEX); } + + /// Element-wise addition + ACT_HOST_DEVICE + MatrixCoord operator+(Base const &b) const { + return MatrixCoord(Base::operator+(b)); + } + + /// In-place addition + ACT_HOST_DEVICE + MatrixCoord &operator+=(Base const &b) { + Base::operator+=(b); + return *this; + } +}; + +} // namespace Act + +#endif From 29557a4e32cd9bde45d6c8720d55932e803499fb Mon Sep 17 00:00:00 2001 From: wangyibo1005 <74347676+wangyibo1005@users.noreply.github.com> Date: Fri, 26 Sep 2025 10:16:06 +0800 Subject: [PATCH 2/6] Add files via upload --- act/act.hpp | 7 +- act/arch/arch.hpp | 30 +- act/arch/cross_core_sync.hpp | 100 +- act/arch/local_tensor_buffer.hpp | 253 ++- act/arch/resource.hpp | 38 +- act/coord.hpp | 452 ++-- act/detail/alignment.hpp | 38 +- act/detail/callback.hpp | 44 +- act/detail/dependent_false.hpp | 2 +- act/detail/macros.hpp | 2 +- act/detail/tag_to_layout.hpp | 80 +- act/epilogue/block/block_epilogue.hpp | 11 +- .../block_epilogue_per_token_dequant.hpp | 1480 +++++++------ act/epilogue/dispatch_policy.hpp | 39 +- act/epilogue/tile/copy_gm_to_ub.hpp | 238 +-- act/epilogue/tile/copy_ub_to_gm.hpp | 168 +- .../tile/tile_broadcast_inplace_by_column.hpp | 59 +- .../tile/tile_broadcast_inplace_by_row.hpp | 51 +- act/epilogue/tile/tile_broadcast_mul.hpp | 161 +- act/epilogue/tile/tile_broadcast_one_blk.hpp | 55 +- act/epilogue/tile/tile_cast.hpp | 28 +- act/epilogue/tile/tile_copy.hpp | 113 +- act/epilogue/tile/tile_elemwise_add.hpp | 33 +- act/epilogue/tile/tile_elemwise_mul.hpp | 31 +- act/epilogue/tile/tile_elemwise_muls.hpp | 26 +- act/epilogue/tile/tile_swizzle.hpp | 118 +- act/gemm/block/block_mmad.hpp | 52 +- ...block_mmad_preload_async_with_callback.hpp | 761 ++++--- act/gemm/block/block_swizzle.hpp | 399 ++-- act/gemm/dispatch_policy.hpp | 61 +- act/gemm/gemm_type.hpp | 13 +- act/gemm/helper.hpp | 222 +- ...per_token_dequant_multistage_workspace.hpp | 638 +++--- act/gemm/tile/copy_gm_to_l1.hpp | 1404 ++++++------- act/gemm/tile/copy_gm_to_ub.hpp | 67 +- act/gemm/tile/copy_l0c_to_gm.hpp | 326 ++- act/gemm/tile/copy_l1_to_l0a.hpp | 668 +++--- 
act/gemm/tile/copy_l1_to_l0b.hpp | 947 ++++----- act/gemm/tile/copy_ub_to_gm.hpp | 107 +- act/gemm/tile/tile_copy.hpp | 245 +-- act/gemm/tile/tile_mmad.hpp | 116 +- act/gemm_coord.hpp | 207 +- act/gemv_coord.hpp | 122 +- act/layout/layout.hpp | 2 +- act/layout/matrix.hpp | 1832 +++++++++-------- act/layout/vector.hpp | 173 +- act/matrix_coord.hpp | 155 +- 47 files changed, 6018 insertions(+), 6156 deletions(-) diff --git a/act/act.hpp b/act/act.hpp index 0fc19b54..2e5fab8b 100644 --- a/act/act.hpp +++ b/act/act.hpp @@ -27,12 +27,11 @@ constexpr uint32_t BYTE_PER_FRACTAL = BYTE_PER_C0 * C0_NUM_PER_FRACTAL; constexpr uint32_t BYTE_PER_BLK = 32; constexpr uint32_t BLK_NUM_PER_VECTOR_FRACTAL = 8; -constexpr uint32_t BYTE_PER_VECTOR_FRACTAL = - BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL; +constexpr uint32_t BYTE_PER_VECTOR_FRACTAL = BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL; constexpr uint64_t L2_OFFSET = 0; constexpr uint32_t STRIDE_LIMIT = 65536; -} // namespace Act +} // namespace Act -#endif // ACT_ACT_HPP +#endif // ACT_ACT_HPP diff --git a/act/arch/arch.hpp b/act/arch/arch.hpp index bb0a2b4d..f1bb8727 100644 --- a/act/arch/arch.hpp +++ b/act/arch/arch.hpp @@ -16,39 +16,39 @@ namespace Act::Arch { struct AtlasA2 { - static constexpr uint32_t BIAS_SIZE = 1024; - static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; - static constexpr uint32_t UB_SIZE = 192 * 1024; - static constexpr uint32_t L1_SIZE = 512 * 1024; - static constexpr uint32_t L0A_SIZE = 64 * 1024; - static constexpr uint32_t L0B_SIZE = 64 * 1024; - static constexpr uint32_t L0C_SIZE = 128 * 1024; + static constexpr uint32_t BIAS_SIZE = 1024; + static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; + static constexpr uint32_t UB_SIZE = 192 * 1024; + static constexpr uint32_t L1_SIZE = 512 * 1024; + static constexpr uint32_t L0A_SIZE = 64 * 1024; + static constexpr uint32_t L0B_SIZE = 64 * 1024; + static constexpr uint32_t L0C_SIZE = 128 * 1024; }; struct PositionGM { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::GM; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::GM; }; struct PositionL1 { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A1; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A1; }; struct PositionL0A { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A2; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A2; }; struct PositionL0B { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::B2; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::B2; }; struct PositionL0C { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::CO1; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::CO1; }; struct PositionUB { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::VECCALC; + static constexpr AscendC::TPosition POSITION = AscendC::TPosition::VECCALC; }; -} // namespace Act::Arch +} // namespace Act::Arch -#endif // ACT_ARCH_ARCH_HPP +#endif // ACT_ARCH_ARCH_HPP diff --git a/act/arch/cross_core_sync.hpp b/act/arch/cross_core_sync.hpp index 1617304f..72099c4e 100644 --- a/act/arch/cross_core_sync.hpp +++ b/act/arch/cross_core_sync.hpp @@ -26,82 +26,90 @@ constexpr FlagID AIV_INTER_SUBBLOCK_BARRIER = 10; constexpr FlagID FFTS_MAX_FLAG = 7; struct CrossCoreFlag { - ACT_DEVICE - CrossCoreFlag() : id(0) {} + ACT_DEVICE + CrossCoreFlag() : id(0) {} - ACT_DEVICE - CrossCoreFlag(FlagID id) : id(id) {} + ACT_DEVICE + CrossCoreFlag(FlagID id) : 
id(id) {} - FlagID id; + FlagID id; }; template struct CrossCoreFlagWithReverse { - ACT_DEVICE - CrossCoreFlagWithReverse() : id(0), reverseId(0) {} + ACT_DEVICE + CrossCoreFlagWithReverse() : id(0), reverseId(0) {} - ACT_DEVICE - CrossCoreFlagWithReverse(FlagID id, FlagID reverseId) - : id(id), reverseId(reverseId) {} + ACT_DEVICE + CrossCoreFlagWithReverse(FlagID id, FlagID reverseId) : id(id), reverseId(reverseId) {} - FlagID id; - FlagID reverseId; - uint32_t count{0}; + FlagID id; + FlagID reverseId; + uint32_t count{0}; }; -template struct BarrierFlag { - static_assert(MODE != MODE, "Unsupporteded cross core barrier flag, can not " - "find the specialization."); +template +struct BarrierFlag { + static_assert(MODE != MODE, + "Unsupporteded cross core barrier flag, can not " + "find the specialization."); }; -template <> struct BarrierFlag<0x0, AscendC::AIV> { - static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER; +template <> +struct BarrierFlag<0x0, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER; }; -template <> struct BarrierFlag<0x0, AscendC::AIC> { - static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER; +template <> +struct BarrierFlag<0x0, AscendC::AIC> { + static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER; }; -template <> struct BarrierFlag<0x1, AscendC::AIV> { - static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER; +template <> +struct BarrierFlag<0x1, AscendC::AIV> { + static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER; }; -template ACT_DEVICE void CrossCoreBarrier() { - constexpr FlagID flagId = BarrierFlag::ID; - AscendC::CrossCoreSetFlag(flagId); - AscendC::CrossCoreWaitFlag(flagId); +template +ACT_DEVICE void CrossCoreBarrier() +{ + constexpr FlagID flagId = BarrierFlag::ID; + AscendC::CrossCoreSetFlag(flagId); + AscendC::CrossCoreWaitFlag(flagId); } template -ACT_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) { - AscendC::CrossCoreSetFlag(flag.id); +ACT_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); } ACT_DEVICE -void CrossCoreWaitFlag(CrossCoreFlag &flag) { - AscendC::CrossCoreWaitFlag(flag.id); +void CrossCoreWaitFlag(CrossCoreFlag &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); } template -ACT_DEVICE void -CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse &flag) { - AscendC::CrossCoreSetFlag(flag.id); - if (++flag.count >= REVERSE_DEPTH) { - AscendC::CrossCoreWaitFlag(flag.reverseId); - flag.count = 0; - } +ACT_DEVICE void CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreSetFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreWaitFlag(flag.reverseId); + flag.count = 0; + } } template -ACT_DEVICE void -CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse &flag) { - AscendC::CrossCoreWaitFlag(flag.id); - if (++flag.count >= REVERSE_DEPTH) { - AscendC::CrossCoreSetFlag(flag.reverseId); - flag.count = 0; - } +ACT_DEVICE void CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse &flag) +{ + AscendC::CrossCoreWaitFlag(flag.id); + if (++flag.count >= REVERSE_DEPTH) { + AscendC::CrossCoreSetFlag(flag.reverseId); + flag.count = 0; + } } -} // namespace Act::Arch +} // namespace Act::Arch -#endif // ACT_ARCH_CROSS_CORE_SYNC_HPP +#endif // ACT_ARCH_CROSS_CORE_SYNC_HPP diff --git a/act/arch/local_tensor_buffer.hpp b/act/arch/local_tensor_buffer.hpp index c94841a4..5208153f 100644 --- a/act/arch/local_tensor_buffer.hpp +++ b/act/arch/local_tensor_buffer.hpp @@ -20,213 +20,212 @@ namespace Act::Arch { struct LocalTensorBufferBase { public: - 
template - ACT_DEVICE AscendC::LocalTensor - GetBufferByByte(const uint32_t offset) const { - return tensor[offset].template ReinterpretCast(); - } + template + ACT_DEVICE AscendC::LocalTensor GetBufferByByte(const uint32_t offset) const + { + return tensor[offset].template ReinterpretCast(); + } protected: - ACT_DEVICE - LocalTensorBufferBase() = default; + ACT_DEVICE + LocalTensorBufferBase() = default; - AscendC::LocalTensor tensor; + AscendC::LocalTensor tensor; }; -template struct LocalTensorBuffer { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded local tensor buffer, can not find the specialization."); +template +struct LocalTensorBuffer { + static_assert(DEPENDENT_FALSE, "Unsupporteded local tensor buffer, can not find the specialization."); }; /// Partial specialization for TPosition::A1 template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::A1; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufA1; - GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE); - tensor = tbufA1.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::A1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA1; + GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE); + tensor = tbufA1.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::A2 template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::A2; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufA2; - GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE); - tensor = tbufA2.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::A2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufA2; + GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE); + tensor = tbufA2.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::B1 template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::B1; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufB1; - GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE); - tensor = tbufB1.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::B1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB1; + GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE); + tensor = tbufB1.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for AtlasA2, TPosition::B2 template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::B2; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufB2; - GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE); - tensor = tbufB2.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::B2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufB2; + GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE); + tensor = tbufB2.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for AtlasA2, TPosition::C1 template <> -struct 
LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C1; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufC1; - GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE); - tensor = tbufC1.Get(); - } + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC1; + GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE); + tensor = tbufC1.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for AtlasA2, TPosition::C2 template <> -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C2; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufC2; - GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE); - tensor = tbufC2.Get(); - } + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2; + GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE); + tensor = tbufC2.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::CO1 template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufCO1; - GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE); - tensor = tbufCO1.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufCO1; + GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE); + tensor = tbufCO1.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for AtlasA2, TPosition::C2PIPE2GM template <> -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufC2PIPE2GM; - GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE); - tensor = tbufC2PIPE2GM.Get(); - } + using ArchTag = Arch::AtlasA2; + static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufC2PIPE2GM; + GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE); + tensor = tbufC2PIPE2GM.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::VECIN template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufVECIN; - GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE); - tensor = tbufVECIN.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECIN; + GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE); + tensor = 
tbufVECIN.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::VECOUT template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufVECOUT; - GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE); - tensor = tbufVECOUT.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECOUT; + GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE); + tensor = tbufVECOUT.Get(); + } }; /////////////////////////////////////////////////////////// /// Partial specialization for TPosition::VECCALC template -struct LocalTensorBuffer - : LocalTensorBufferBase { +struct LocalTensorBuffer : LocalTensorBufferBase { public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC; - - ACT_DEVICE - LocalTensorBuffer() { - AscendC::TBuf tbufVECCALC; - GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE); - tensor = tbufVECCALC.Get(); - } + static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC; + + ACT_DEVICE + LocalTensorBuffer() + { + AscendC::TBuf tbufVECCALC; + GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE); + tensor = tbufVECCALC.Get(); + } }; -} // namespace Act::Arch +} // namespace Act::Arch -#endif // INCLUDE_ACT_ARCH_MEMORY_H +#endif // INCLUDE_ACT_ARCH_MEMORY_H diff --git a/act/arch/resource.hpp b/act/arch/resource.hpp index d5c8531b..71367981 100644 --- a/act/arch/resource.hpp +++ b/act/arch/resource.hpp @@ -18,25 +18,27 @@ namespace Act::Arch { -template struct Resource { +template +struct Resource { public: - AscendC::TPipe pipe; - - LocalTensorBuffer l1Buf; - LocalTensorBuffer l0ABuf; - LocalTensorBuffer l0BBuf; - LocalTensorBuffer l0CBuf; - LocalTensorBuffer ubBuf; - - ACT_DEVICE - Resource() { - // The initialization of AscendC::Tpipe will insert some synchronization - // interfaces, which may conflict with the usage by users. Therefore, the - // "destroy" interface is used for releasing. - pipe.Destroy(); - } + AscendC::TPipe pipe; + + LocalTensorBuffer l1Buf; + LocalTensorBuffer l0ABuf; + LocalTensorBuffer l0BBuf; + LocalTensorBuffer l0CBuf; + LocalTensorBuffer ubBuf; + + ACT_DEVICE + Resource() + { + // The initialization of AscendC::Tpipe will insert some synchronization + // interfaces, which may conflict with the usage by users. Therefore, the + // "destroy" interface is used for releasing. 
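Resource bundles one LocalTensorBuffer per on-chip memory (L1, L0A, L0B, L0C, UB) and is meant to be constructed exactly once per kernel; downstream components then carve their working sets out of these pools by byte offset instead of owning TBuf objects themselves, which is how the epilogue further below slices ubBuf. A minimal sketch of that pattern (the AtlasA2 arch tag, half element type, and tile size are illustrative assumptions):

    ACT_DEVICE void Example()
    {
        Act::Arch::Resource<Act::Arch::AtlasA2> resource;
        size_t ubOffset = 0;
        // Two half-precision views over UB, laid out back to back by byte offset.
        auto ubIn = resource.ubBuf.template GetBufferByByte<half>(ubOffset);
        ubOffset += 128 * 256 * sizeof(half);
        auto ubOut = resource.ubBuf.template GetBufferByByte<half>(ubOffset);
        // ... vector compute reading ubIn, writing ubOut ...
    }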
+ pipe.Destroy(); + } }; -} // namespace Act::Arch +} // namespace Act::Arch -#endif // INCLUDE_ACT_ARCH_RESOURCE_HPP +#endif // INCLUDE_ACT_ARCH_RESOURCE_HPP
diff --git a/act/coord.hpp b/act/coord.hpp index f2e065e8..5faf5be6 100644 --- a/act/coord.hpp +++ b/act/coord.hpp @@ -18,254 +18,294 @@ namespace Act { /// Statically-sized array specifying Coords within a tensor -template struct Coord { public: - // Number of elements in Coord - static const int RANK = RANK_; + // Number of elements in Coord + static const int RANK = RANK_; - // Index typen used to store elements - using Index = Index_; + // Index type used to store elements + using Index = Index_; - // Type used to represent linear offsets - using LongIndex = LongIndex_; + // Type used to represent linear offsets + using LongIndex = LongIndex_; - // Default ctor initializes uniformly - ACT_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) { - for (int i = 0; i < RANK; ++i) { - idx[i] = value; + // Default ctor initializes uniformly + ACT_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = value; + } } - } - // Constructs from an array of integers - ACT_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) { - for (int i = 0; i < RANK; ++i) { - idx[i] = idx_[i]; + // Constructs from an array of integers + ACT_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) + { + for (int i = 0; i < RANK; ++i) { + idx[i] = idx_[i]; + } } - } - - // Constructs from an array of integers - ACT_HOST_DEVICE - int Argmin() const { - int i = 0; - for (int j = 1; j < RANK; ++j) { - if (idx[j] < idx[i]) { - i = j; - } + + // Returns the index of the dimension with least value + ACT_HOST_DEVICE + int Argmin() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] < idx[i]) { + i = j; + } + } + return i; + } + + // Returns the index of the dimension with greatest value + ACT_HOST_DEVICE + int Argmax() const + { + int i = 0; + for (int j = 1; j < RANK; ++j) { + if (idx[j] > idx[i]) { + i = j; + } + } + return i; } - return i; - } - - // Returns the index of the dimension with greatest value - ACT_HOST_DEVICE - int Argmax() const { - int i = 0; - for (int j = 1; j < RANK; ++j) { - if (idx[j] > idx[i]) { - i = j; - } + + // Returns true if Coord is non-zero + ACT_HOST_DEVICE + explicit operator bool() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return true; + } + } + return false; } - return i; - } - - // Returns true if Coord is non-zero - ACT_HOST_DEVICE - explicit operator bool() const { - for (int i = 0; i < RANK; ++i) { - if (idx[i]) { + + // Return true if Coord is uniformly zero. + ACT_HOST_DEVICE + bool operator!() const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]) { + return false; + } + } return true; - } } - return false; - } - - // Return true if Coord is uniformly zero.
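All Coord arithmetic is element-wise, which is what lets kernels compose offsets by plain multiplication and addition of coordinates; the epilogue further below computes its block offset as blockCoord * blockShape in exactly this way. A small worked example built on the MakeCoord helpers defined at the end of this header (the values are illustrative):

    auto shape = Act::MakeCoord(64, 256);    // tile shape (rows, cols)
    auto coord = Act::MakeCoord(2, 1);       // tile index within a block
    auto offset = coord * shape;             // element-wise: (128, 256)
    auto end = offset + shape;               // (192, 512)
    auto clipped = decltype(shape)::Min(end, Act::MakeCoord(192, 384));
    // clipped == (192, 384); offset[0] == 128, offset.At<1>() == 256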
- ACT_HOST_DEVICE - bool operator!() const { - for (int i = 0; i < RANK; ++i) { - if (idx[i]) { - return false; - } + + // Element-wise addition + ACT_HOST_DEVICE + Coord operator+(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + b.idx[i]; + } + return c; } - return true; - } - - // Element-wise addition - ACT_HOST_DEVICE - Coord operator+(Coord const &b) const { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] + b.idx[i]; + // Add a scalar to each element + ACT_HOST_DEVICE + Coord operator+(const Index val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] + val; + } + return c; } - return c; - } - - // Add a scalar to each element - ACT_HOST_DEVICE - Coord operator+(const Index val) const { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] + val; + // Element-wise subtraction + ACT_HOST_DEVICE + Coord operator-(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] - b.idx[i]; + } + return c; } - return c; - } - - // Element-wise subtraction - ACT_HOST_DEVICE - Coord operator-(Coord const &b) const { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] - b.idx[i]; + // Subtract a scalar from each element + ACT_HOST_DEVICE + Coord operator-(Index const val) const + { + Coord c; + for (int i = 0; i < RANK; ++i) { + c.idx[i] = idx[i] - val; + } + return c; } - return c; - } - - // Subtract a scalar from each element - ACT_HOST_DEVICE - Coord operator-(Index const val) const { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] - val; + // Element-wise multiply + ACT_HOST_DEVICE + Coord operator*(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] * b.idx[i]; + } + return c; } - return c; - } - - // Element-wise multiply - ACT_HOST_DEVICE - Coord operator*(Coord const &b) const { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] * b.idx[i]; + // Element-wise division + ACT_HOST_DEVICE + Coord operator/(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] / b.idx[i]; + } + return c; } - return c; - } - - // Element-wise division - ACT_HOST_DEVICE - Coord operator/(Coord const &b) const { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] / b.idx[i]; + // Element-wise mod + ACT_HOST_DEVICE + Coord operator%(Coord const &b) const + { + Coord c; + for (int i = 0; i < RANK; i++) { + c.idx[i] = idx[i] % b.idx[i]; + } + return c; } - return c; - } - - // Element-wise mod - ACT_HOST_DEVICE - Coord operator%(Coord const &b) const { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] % b.idx[i]; + // In-place addition + ACT_HOST_DEVICE + Coord &operator+=(Coord const &b) + { + for (int i = 0; i < RANK; ++i) { + idx[i] += b.idx[i]; + } + return *this; } - return c; - } - - // In-place addition - ACT_HOST_DEVICE - Coord &operator+=(Coord const &b) { - for (int i = 0; i < RANK; ++i) { - idx[i] += b.idx[i]; + // Equality comparison + ACT_HOST_DEVICE + bool operator==(Coord const &b) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i] != b.idx[i]) { + return false; + } + } + return true; } - return *this; - } - - // In-place equal - ACT_HOST_DEVICE - bool operator==(Coord const &b) const { - for (int i = 0; i < RANK; ++i) { - if (idx[i] != b.idx[i]) { - return false; - } + + // Equality comparison with a scalar + ACT_HOST_DEVICE + bool operator==(Index const val) const + { + for (int i = 0; i < RANK; ++i) { + if (idx[i]
!= val) { + return false; + } + } + return true; } - return true; - } - - // In-place equal - ACT_HOST_DEVICE - bool operator==(Index const val) const { - for (int i = 0; i < RANK; ++i) { - if (idx[i] != val) { - return false; - } + + // Member access operator + ACT_HOST_DEVICE + Index &operator[](int dim) + { + return idx[dim]; + } + + // Member access operator + ACT_HOST_DEVICE + Index const &operator[](int dim) const + { + return idx[dim]; + } + + // Gets the index of a given Coord element + template + ACT_HOST_DEVICE Index &At() + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + ACT_HOST_DEVICE + Index &At(int dim) + { + return idx[dim]; } - return true; - } - - // Member access operator - ACT_HOST_DEVICE - Index &operator[](int dim) { return idx[dim]; } - - // Member access operator - ACT_HOST_DEVICE - Index const &operator[](int dim) const { return idx[dim]; } - - // Gets the index of a given Coord element - template ACT_HOST_DEVICE Index &At() { return idx[DIM]; } - - // Access via index; may limit unrolling potential - ACT_HOST_DEVICE - Index &At(int dim) { return idx[dim]; } - - // Gets the index of a given Coord element - template ACT_HOST_DEVICE Index const &At() const { - return idx[DIM]; - } - - // Access via index; may limit unrolling potential - ACT_HOST_DEVICE - Index const &At(int dim) const { return idx[dim]; } - - template ACT_HOST_DEVICE auto GetCoordByAxis() const { - Index idx_[sizeof...(Is)]{idx[Is]...}; - return Coord{idx_}; - } - - ACT_HOST_DEVICE - static Coord Min(Coord const &a, Coord const &b) { - Coord res; - for (int i = 0; i < RANK; ++i) { - res[i] = a[i] < b[i] ? a[i] : b[i]; + + // Gets the index of a given Coord element + template + ACT_HOST_DEVICE Index const &At() const + { + return idx[DIM]; + } + + // Access via index; may limit unrolling potential + ACT_HOST_DEVICE + Index const &At(int dim) const + { + return idx[dim]; + } + + template + ACT_HOST_DEVICE auto GetCoordByAxis() const + { + Index idx_[sizeof...(Is)]{idx[Is]...}; + return Coord{idx_}; + } + + ACT_HOST_DEVICE + static Coord Min(Coord const &a, Coord const &b) + { + Coord res; + for (int i = 0; i < RANK; ++i) { + res[i] = a[i] < b[i] ? 
a[i] : b[i]; + } + return res; } - return res; - } private: - // Indices - Index idx[RANK]; + // Indices + Index idx[RANK]; }; // Helper to make a 1-element coordinate -template ACT_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) { - T values[1] = {dim0}; - return Coord<1, T>(values); +template +ACT_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) +{ + T values[1] = {dim0}; + return Coord<1, T>(values); } /// Helper to make a 2-element coordinate template -ACT_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) { - T values[2] = {dim0, dim1}; - return Coord<2, T>(values); +ACT_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) +{ + T values[2] = {dim0, dim1}; + return Coord<2, T>(values); } /// Helper to make a 3-element coordinate template -ACT_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) { - T values[3] = {dim0, dim1, dim2}; - return Coord<3, T>(values); +ACT_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) +{ + T values[3] = {dim0, dim1, dim2}; + return Coord<3, T>(values); } /// Helper to make a 4-element coordinate template -ACT_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2, - T dim3) { - T values[4] = {dim0, dim1, dim2, dim3}; - return Coord<4, T>(values); +ACT_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2, T dim3) +{ + T values[4] = {dim0, dim1, dim2, dim3}; + return Coord<4, T>(values); } -} // namespace Act +} // namespace Act -#endif // ACT_COORD_HPP +#endif // ACT_COORD_HPP
diff --git a/act/detail/alignment.hpp b/act/detail/alignment.hpp index fe9e3e1e..db40e7ba 100644 --- a/act/detail/alignment.hpp +++ b/act/detail/alignment.hpp @@ -16,36 +16,42 @@ #include "../../act/detail/macros.hpp" template -ACT_HOST_DEVICE constexpr T RoundUp(const T &val) { - static_assert(ALIGN != 0, "ALIGN must not be 0"); - return (val + ALIGN - 1) / ALIGN * ALIGN; +ACT_HOST_DEVICE constexpr T RoundUp(const T &val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return (val + ALIGN - 1) / ALIGN * ALIGN; } template -ACT_HOST_DEVICE constexpr T RoundUp(const T &val, const T align) { - return (val + align - 1) / align * align; +ACT_HOST_DEVICE constexpr T RoundUp(const T &val, const T align) +{ + return (val + align - 1) / align * align; } template -ACT_HOST_DEVICE constexpr T RoundDown(const T val) { - static_assert(ALIGN != 0, "ALIGN must not be 0"); - return val / ALIGN * ALIGN; +ACT_HOST_DEVICE constexpr T RoundDown(const T val) +{ + static_assert(ALIGN != 0, "ALIGN must not be 0"); + return val / ALIGN * ALIGN; } template -ACT_HOST_DEVICE constexpr T RoundDown(const T val, const T align) { - return val / align * align; +ACT_HOST_DEVICE constexpr T RoundDown(const T val, const T align) +{ + return val / align * align; } template -ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend) { - static_assert(DIVISOP != 0, "DIVISOP must not be 0"); - return (dividend + DIVISOP - 1) / DIVISOP; +ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend) +{ + static_assert(DIVISOR != 0, "DIVISOR must not be 0"); + return (dividend + DIVISOR - 1) / DIVISOR; } template -ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend, const T divisor) { - return (dividend + divisor - 1) / divisor; +ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend, const T divisor) +{ + return (dividend + divisor - 1) / divisor; } -#endif // ACT_ALIGNMENT_HPP +#endif // ACT_ALIGNMENT_HPP
diff --git a/act/detail/callback.hpp b/act/detail/callback.hpp index 5c47c6f8..7475213c 100644 --- a/act/detail/callback.hpp +++
b/act/detail/callback.hpp @@ -24,32 +24,40 @@ /// it is necessary to ensure that it is used within the life cycle of the /// callable structure. struct Callback { - void const *func{nullptr}; - void (*caller)(void const *){nullptr}; + void const *func{nullptr}; + void (*caller)(void const *){nullptr}; - Callback() = default; + Callback() = default; - ACT_DEVICE - void operator()() const { - if (func) { - caller(func); + ACT_DEVICE + void operator()() const + { + if (func) { + caller(func); + } } - } - ACT_DEVICE - operator bool() const { return func != nullptr; } + ACT_DEVICE + operator bool() const + { + return func != nullptr; + } }; -template ACT_DEVICE static void FuncWrapper(void const *func) { - (*static_cast(func))(); +template +ACT_DEVICE static void FuncWrapper(void const *func) +{ + (*static_cast(func))(); } // Use this to make a callback -template ACT_DEVICE Callback MakeCallback(Func *func) { - Callback callback; - callback.func = func; - callback.caller = &FuncWrapper; - return callback; +template +ACT_DEVICE Callback MakeCallback(Func *func) +{ + Callback callback; + callback.func = func; + callback.caller = &FuncWrapper; + return callback; } -#endif // ACT_DETAIL_CALLBACK_HPP +#endif // ACT_DETAIL_CALLBACK_HPP diff --git a/act/detail/dependent_false.hpp b/act/detail/dependent_false.hpp index 9a76dd52..c9985a05 100644 --- a/act/detail/dependent_false.hpp +++ b/act/detail/dependent_false.hpp @@ -19,4 +19,4 @@ constexpr bool DEPENDENT_BOOL_VALUE = VALUE; template constexpr bool DEPENDENT_FALSE = DEPENDENT_BOOL_VALUE; -#endif // ACT_DETAIL_DEPENDENT_FALSE_HPP +#endif // ACT_DETAIL_DEPENDENT_FALSE_HPP diff --git a/act/detail/macros.hpp b/act/detail/macros.hpp index fa31d68a..a2825344 100644 --- a/act/detail/macros.hpp +++ b/act/detail/macros.hpp @@ -17,4 +17,4 @@ #define ACT_HOST_DEVICE __forceinline__[host, aicore] #define ACT_GLOBAL __global__[aicore] -#endif // ACT_DETAIL_MACROS_HPP +#endif // ACT_DETAIL_MACROS_HPP diff --git a/act/detail/tag_to_layout.hpp b/act/detail/tag_to_layout.hpp index ec649e4e..033a4ee4 100644 --- a/act/detail/tag_to_layout.hpp +++ b/act/detail/tag_to_layout.hpp @@ -22,53 +22,46 @@ using namespace tla; namespace Act::detail { //////////////////////////////////////////////////////////////////////////////////////////////////// // For each Act::layout, provides its corresponding tla layout types -template struct TagToLayout { - using type = LayoutTag; +template +struct TagToLayout { + using type = LayoutTag; }; -template struct TagToLayout { - using type = Layout, Stride>, - Shape>; +template +struct TagToLayout { + using type = Layout, Stride>, Shape>; }; -template struct TagToLayout { - using type = Layout, Stride, int64_t>, - Shape>; +template +struct TagToLayout { + using type = Layout, Stride, int64_t>, Shape>; }; -template struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - using type = - Layout, uint32_t>, - Shape, uint32_t>>, - Stride, Int>, - Stride, int64_t>>, - Shape>; +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, Int>, Stride, int64_t>>, + Shape>; }; -template struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = 
- BYTE_PER_FRACTAL / sizeof(Element); - using type = Layout, uint32_t>, - Shape, uint32_t>>, - Stride, int64_t>, - Stride, Int>>, - Shape>; +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, int64_t>, Stride, Int>>, + Shape>; }; -template struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - using type = - Layout, uint32_t>, - Shape, uint32_t>>, - Stride, int64_t>, - Stride, Int>>, - Shape>; +template +struct TagToLayout { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + using type = Layout, uint32_t>, Shape, uint32_t>>, + Stride, int64_t>, Stride, Int>>, + Shape>; }; // Convenience aliases @@ -76,15 +69,12 @@ template using TagToLayout_t = typename TagToLayout::type; constexpr uint32_t ELE_NUM_PER_FRACTAL_L0C = 256; -using LayoutL0C = - Layout, uint32_t>, - Shape, uint32_t>>, - Stride, Int>, - Stride, int64_t>>, - Shape>; +using LayoutL0C = Layout, uint32_t>, Shape, uint32_t>>, + Stride, Int>, Stride, int64_t>>, + Shape>; //////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::detail +} // namespace Act::detail -#endif // ACT_DETAIL_TAG_TO_LAYOUT_HPP +#endif // ACT_DETAIL_TAG_TO_LAYOUT_HPP diff --git a/act/epilogue/block/block_epilogue.hpp b/act/epilogue/block/block_epilogue.hpp index f7057680..bb7a6ac6 100644 --- a/act/epilogue/block/block_epilogue.hpp +++ b/act/epilogue/block/block_epilogue.hpp @@ -17,12 +17,13 @@ namespace Act::Epilogue::Block { -template class BlockEpilogue { - static_assert(DEPENDENT_FALSE, - "Could not find an epilogue specialization"); +template +class BlockEpilogue +{ + static_assert(DEPENDENT_FALSE, "Could not find an epilogue specialization"); }; -} // namespace Act::Epilogue::Block +} // namespace Act::Epilogue::Block #include "../../../act/epilogue/block/block_epilogue_per_token_dequant.hpp" -#endif // ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP +#endif // ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP diff --git a/act/epilogue/block/block_epilogue_per_token_dequant.hpp b/act/epilogue/block/block_epilogue_per_token_dequant.hpp index ac21c634..dee41a8c 100644 --- a/act/epilogue/block/block_epilogue_per_token_dequant.hpp +++ b/act/epilogue/block/block_epilogue_per_token_dequant.hpp @@ -26,842 +26,738 @@ namespace Act::Epilogue::Block { -template -class BlockEpilogue< - EpilogueAtlasA2PerTokenDequant, CType_, ScaleType_, - PerTokenScaleType_, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, - TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> { +template +class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, + DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, + EpilogueTileSwizzle_> +{ public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = typename ScaleType_::Element; - using LayoutScale = typename ScaleType_::Layout; - using ElementPerTokenScale = typename 
PerTokenScaleType_::Element; - using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; - using ElementD = typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert( - std::is_same_v && - (std::is_same_v || - std::is_same_v) && - std::is_same_v && - std::is_same_v, - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && - std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert( - TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + - TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + - TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + - TileShape::ROW) * - sizeof(float) + - TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = typename ScaleType_::Element; + using LayoutScale = typename ScaleType_::Layout; + using ElementPerTokenScale = typename PerTokenScaleType_::Element; + using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v) && + std::is_same_v && std::is_same_v, + "The element type template parameters of BlockEpilogue are wrong"); + static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape 
must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + + TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; ACT_DEVICE - Params() {}; + BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) + { + size_t ubOffset = 0; + int32_t eventVMTE2 = 0; + int32_t eventMTE2V = 0; + int32_t eventMTE3V = 0; + int32_t eventVMTE3 = 0; + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(float); + ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(float); + ubPerTokenScaleFp32Brcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubMul; + } ACT_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, - LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), layoutScale(layoutScale_), - ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), ptrD(ptrD_), - layoutD(layoutD_) {} - }; - - ACT_DEVICE - BlockEpilogue(Arch::Resource const 
&resource, - Params const ¶ms = Params{}) - : params(params) { - size_t ubOffset = 0; - int32_t eventVMTE2 = 0; - int32_t eventMTE2V = 0; - int32_t eventMTE3V = 0; - int32_t eventVMTE3 = 0; - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = - resource.ubBuf.template GetBufferByByte( - ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - AscendC::SetFlag( - eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); - } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleFp32 = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(float); - ubPerTokenScaleFp32Brcb = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubMul; - } - - ACT_DEVICE - ~BlockEpilogue() { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag( - eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } } - } - - ACT_DEVICE - void UpdateParams(Params const ¶ms_) { params = params_; } - - ACT_DEVICE - void operator()(GemmCoord const &blockShapeMNK, - GemmCoord const &blockCoordMNK, - GemmCoord const &actualBlockShapeMNK, - AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutBlockC, - Callback &&callback = Callback{}) { - if (actualBlockShapeMNK.k() == 0) { - return; + + ACT_DEVICE + void UpdateParams(Params const ¶ms_) + { + params = params_; } - callback(); - - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - 
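Every staged UB buffer above is guarded by an event pair, and the constructor pre-sets each V->MTE2 and MTE3->V flag so that the first WaitFlag issued in the loop body does not deadlock; the destructor later drains the same flags before teardown. Reduced to a single buffer, the steady-state discipline looks like this (eventId and the copy/compute steps are placeholders):

    int32_t eventId = 0;  // placeholder event ID
    // Constructor: mark the buffer free so the first wait passes.
    AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventId);

    // Per tile, in steady state:
    AscendC::WaitFlag<AscendC::HardEvent::V_MTE2>(eventId);  // vector done with the buffer
    // ... MTE2: DataCopy GM -> UB ...
    AscendC::SetFlag<AscendC::HardEvent::MTE2_V>(eventId);   // publish the copy to vector
    AscendC::WaitFlag<AscendC::HardEvent::MTE2_V>(eventId);  // before Cast/Mul on the buffer
    // ... V: Cast and broadcast multiplies ...
    AscendC::SetFlag<AscendC::HardEvent::V_MTE2>(eventId);   // buffer free for the next copy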
EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; - loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag( - eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = - LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag( - eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag( - eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = - actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = - gmPerTokenScale[params.layoutPerTokenScale.GetOffset( - perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = - params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb( - perTokenScaleTileShape); - - AscendC::WaitFlag( - eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, - layoutUbPerTokenScale, layoutGmTilePerTokenScale); - AscendC::SetFlag( - eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag( - eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, - TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag( - eventUbScaleMTE2VList[ubListId]); - AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, - TileShape::COLUMN); - AscendC::SetFlag( - eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag( - eventUbPerTokenScaleMTE2VList[ubListId]); - AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, - AscendC::RoundMode::CAST_NONE, TileShape::ROW); - AscendC::SetFlag( - eventUbPerTokenScaleVMTE2List[ubListId]); - - tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); - tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, - ubPerTokenScaleFp32Brcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag( - eventUbDMTE3VList[ubListId]); - AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, - TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; 
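Numerically, the tile pipeline in this loop is a per-column and a per-row rescale of the accumulator: C is cast to fp32, each row is multiplied by the scale vector (tileRowBroadcastMul), each token's scalar is broadcast into a block (tileBroadcastOneBlk) and multiplied in column-wise (tileOneBlkColumnBroadcastMul), and the result is cast to the output type. As a scalar reference model (plain loops standing in for the vectorized tile ops; the array names are illustrative):

    // D[r][c] = ElementD(float(C[r][c]) * scale[c] * perTokenScale[r])
    for (uint32_t r = 0; r < rows; ++r) {
        for (uint32_t c = 0; c < cols; ++c) {
            float acc = static_cast<float>(C[r * cols + c]);
            D[r * cols + c] = static_cast<ElementD>(acc * scale[c] * perTokenScale[r]);
        }
    }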
- auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag( - eventUbDVMTE3List[ubListId]); - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + + ACT_DEVICE + void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, + GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, + LayoutC const &layoutBlockC, Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; + } + callback(); + + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); + + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + 
AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); + tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + } } - } private: - Params params; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - uint32_t ubListId{0}; - - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubScaleFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleFp32; - AscendC::LocalTensor ubPerTokenScaleFp32Brcb; - AscendC::LocalTensor ubPerTokenMul; - - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; + Params params; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + uint32_t ubListId{0}; + + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubScaleFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleFp32; + AscendC::LocalTensor ubPerTokenScaleFp32Brcb; + AscendC::LocalTensor ubPerTokenMul; + + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul 
tileOneBlkColumnBroadcastMul; + + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; }; -template -class BlockEpilogue, - CType_, Gemm::GemmType, - Gemm::GemmType, DType_, - TileRowBroadcastMul_, TileBroadcastOneBlk_, - TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> { +template +class BlockEpilogue, CType_, Gemm::GemmType, + Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, + TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> +{ public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = float; - using LayoutScale = LayoutScale_; - using ElementPerTokenScale = float; - using LayoutPerTokenScale = LayoutPerTokenScale_; - using ElementD = typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert( - std::is_same_v && - (std::is_same_v || - std::is_same_v), - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && - std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert( - TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + - TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + - TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + - TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; + using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + + // Data infos + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using ElementScale = float; + using LayoutScale = LayoutScale_; + using ElementPerTokenScale = float; + using LayoutPerTokenScale = LayoutPerTokenScale_; + using ElementD = typename DType_::Element; + using LayoutD = typename DType_::Layout; + + // Check data infos + static_assert(std::is_same_v && + (std::is_same_v || std::is_same_v), + "The element type template parameters of BlockEpilogue are wrong"); + 
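The UB-capacity check in this specialization is easy to verify by hand. For a hypothetical 64 x 128 tile with two stages, half-precision C and D, and float scales, a constexpr restatement of the assert's left-hand side (assuming ArchTag::UB_SIZE is the 192 KiB unified buffer of AtlasA2):

    constexpr uint32_t ROW = 64, COLUMN = 128, COUNT = ROW * COLUMN;   // 8192
    constexpr uint32_t STAGES = 2;
    constexpr uint32_t staged = STAGES * (COUNT * sizeof(half)         // ubC
                                          + COLUMN * sizeof(float)     // ubScale
                                          + ROW * sizeof(float)        // ubPerTokenScale
                                          + COUNT * sizeof(half));     // ubD
    constexpr uint32_t shared = (COUNT + COUNT) * sizeof(float)        // ubCFp32, ubMul
                                + ROW * 32;                            // Brcb scratch, BYTE_PER_BLK per row
    static_assert(staged + shared == 134656, "about 131.5 KiB, fits under 192 KiB");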
static_assert(std::is_same_v && std::is_same_v && + std::is_same_v && + std::is_same_v, + "The layout template parameters of BlockEpilogue are wrong"); + + // Tile compute ops + using TileRowBroadcastMul = TileRowBroadcastMul_; + using TileBroadcastOneBlk = TileBroadcastOneBlk_; + using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; + + // Tile copy + using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; + using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; + using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; + using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; + + using EpilogueTileSwizzle = EpilogueTileSwizzle_; + + using TileShape = typename TileRowBroadcastMul::TileShape; + + static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && + std::is_same_v, + "TileShape must be consistent for all tile compute ops"); + + static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + + TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + + (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= + ArchTag::UB_SIZE, + "TileShape is too large to fit in UB"); + + struct Params { + __gm__ ElementScale *ptrScale{nullptr}; + LayoutScale layoutScale{}; + __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; + LayoutPerTokenScale layoutPerTokenScale{}; + __gm__ ElementD *ptrD{nullptr}; + LayoutD layoutD{}; + + ACT_DEVICE + Params() {}; + + ACT_DEVICE + Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, + __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, + __gm__ ElementD *ptrD_, LayoutD const &layoutD_) + : ptrScale(ptrScale_), + layoutScale(layoutScale_), + ptrPerTokenScale(ptrPerTokenScale_), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(ptrD_), + layoutD(layoutD_) + {} + }; + + ACT_DEVICE void AlignUbOffset() + { + size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); + if (ubMask != 0) { + ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + } + } + + ACT_DEVICE + BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, + Params const ¶ms = Params{}) + : resource(resource), calcInfo(calcInfo), params(params) + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementC); + ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COLUMN * sizeof(ElementScale); + ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); + ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(ElementD); + + eventUbCVMTE2List[i] = eventVMTE2++; + eventUbCMTE2VList[i] = eventMTE2V++; + eventUbScaleVMTE2List[i] = eventVMTE2++; + eventUbScaleMTE2VList[i] = eventMTE2V++; + eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; + eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; + eventUbDMTE3VList[i] = eventMTE3V++; + eventUbDVMTE3List[i] = eventVMTE3++; + + AscendC::SetFlag(eventUbCVMTE2List[i]); + AscendC::SetFlag(eventUbScaleVMTE2List[i]); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::SetFlag(eventUbDMTE3VList[i]); + } + ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubMul = 
resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::COUNT * sizeof(float); + ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += TileShape::ROW * BYTE_PER_BLK; + ubPerTokenMul = ubCFp32; + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + AlignUbOffset(); + epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); + AlignUbOffset(); + AscendC::GlobalTensor epSendCountGM; + epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); + uint32_t epSendCountSize = calcInfo.isSharedExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_; + AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), + 0U, 0U, 0U}; + AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; + AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); + AscendC::SetFlag(eventMTE2S); + AscendC::WaitFlag(eventMTE2S); +#if ENABLE_EP_SEND_COUNT_HASH + tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); + uint32_t maxGroupSendCount = 0; + uint32_t groupSendCount = 0; + for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { + uint32_t prevGroupSendCount = groupSendCount; + groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); + if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { + maxGroupSendCount = groupSendCount - prevGroupSendCount; + } + } + ubOffset += maxGroupSendCount * sizeof(int32_t); + AlignUbOffset(); + // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or + // AscendC::TOTAL_VEC_LOCAL_SIZE +#endif + } + } ACT_DEVICE - Params() {}; + ~BlockEpilogue() + { + for (uint32_t i = 0; i < UB_STAGES; ++i) { + AscendC::WaitFlag(eventUbCVMTE2List[i]); + AscendC::WaitFlag(eventUbScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); + AscendC::WaitFlag(eventUbDMTE3VList[i]); + } + } ACT_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, - LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), layoutScale(layoutScale_), - ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), ptrD(ptrD_), - layoutD(layoutD_) {} - }; - - ACT_DEVICE void AlignUbOffset() { - size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); - if (ubMask != 0) { - ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; + void UpdateParams(Params const ¶ms_) + { + params = params_; } - } - - ACT_DEVICE - BlockEpilogue(Arch::Resource &resource, - MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, - Params const ¶ms = Params{}) - : resource(resource), calcInfo(calcInfo), params(params) { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = - resource.ubBuf.template GetBufferByByte( - ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = 
eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - AscendC::SetFlag( - eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); + + ACT_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) + { + return (GM_ADDR)((calcInfo.epRankId_ == rankId) + ? calcInfo.epWinContext_->localWindowsIn + : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) + ->windowsIn) + + calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleBrcb = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubCFp32; - - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - AlignUbOffset(); - epSendCountLocal_ = - resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); - AlignUbOffset(); - AscendC::GlobalTensor epSendCountGM; - epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); - uint32_t epSendCountSize = calcInfo.isSharedExpert_ - ? calcInfo.epWorldSize_ - : calcInfo.moeSendNum_; - AscendC::DataCopyExtParams epSendCntParams = { - 1U, static_cast(epSendCountSize * sizeof(uint32_t)), 0U, 0U, - 0U}; - AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; - AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, - copyPadParams); - AscendC::SetFlag(eventMTE2S); - AscendC::WaitFlag(eventMTE2S); #if ENABLE_EP_SEND_COUNT_HASH - tokenToEpRankHashLocal_ = - resource.ubBuf.template GetBufferByByte(ubOffset); - uint32_t maxGroupSendCount = 0; - uint32_t groupSendCount = 0; - for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; - ++expertIdx) { - uint32_t prevGroupSendCount = groupSendCount; - groupSendCount = epSendCountLocal_.GetValue( - (expertIdx + 1) * calcInfo.epWorldSize_ - 1); - if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { - maxGroupSendCount = groupSendCount - prevGroupSendCount; + ACT_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) + { + constexpr uint32_t DUPLICATE_MASK_COUNT = 8; + uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); + if (hashOffsetMask != 0) { + uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; + if (copyLen < remainMaskCount) { + remainMaskCount = copyLen; + } + uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, + DUPLICATE_MASK_COUNT); + hashOffset += remainMaskCount; + copyLen -= remainMaskCount; + } + if (copyLen > 0) { + AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); + hashOffset += copyLen; } - } - ubOffset += maxGroupSendCount * sizeof(int32_t); - AlignUbOffset(); - // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or - // AscendC::TOTAL_VEC_LOCAL_SIZE -#endif - } - } - - ACT_DEVICE - ~BlockEpilogue() { - for (uint32_t i = 0; i < 
UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag( - eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); - } - } - - ACT_DEVICE - void UpdateParams(Params const ¶ms_) { params = params_; } - - ACT_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, - const uint8_t expertLocalId = 0U) { - return (GM_ADDR)((calcInfo.epRankId_ == rankId) - ? calcInfo.epWinContext_->localWindowsIn - : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_ - ->remoteRes[rankId] - .nextDevicePtr)) - ->windowsIn) + - calcInfo.winDataSizeOffset_ + - expertLocalId * calcInfo.expertPerSizeOnWin_ + - rankId * OPT_RANK_OFFSET; - } -#if ENABLE_EP_SEND_COUNT_HASH - ACT_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, - uint32_t epRank, - uint32_t copyLen) { - constexpr uint32_t DUPLICATE_MASK_COUNT = 8; - uint32_t hashOffsetMask = - (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); - if (hashOffsetMask != 0) { - uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; - if (copyLen < remainMaskCount) { - remainMaskCount = copyLen; - } - uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; - AscendC::Duplicate( - tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, - ©Mask, 1, 1, DUPLICATE_MASK_COUNT); - hashOffset += remainMaskCount; - copyLen -= remainMaskCount; - } - if (copyLen > 0) { - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, - copyLen); - hashOffset += copyLen; } - } #endif - ACT_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, - uint32_t &localEpRank) { - if ((calcInfo.isSharedExpert_) && - (epRank < calcInfo.sharedExpertRankNum_)) { - remoteEpRank = calcInfo.epRankId_; - localEpRank = epRank; - } else { - remoteEpRank = epRank; - localEpRank = calcInfo.epRankId_; + ACT_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) + { + if ((calcInfo.isSharedExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { + remoteEpRank = calcInfo.epRankId_; + localEpRank = epRank; + } else { + remoteEpRank = epRank; + localEpRank = calcInfo.epRankId_; + } } - } - - ACT_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, - layout::RowMajor &layoutGmTileD, - LayoutD &layoutUbD, int64_t groupOffsetD, - uint32_t expertIdx, uint32_t tileOffsetD) { - const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); - const uint32_t copyTokenSrcStride = - (layoutUbD.stride(0) - layoutUbD.shape(1)) / - (BYTE_PER_C0 / sizeof(ElementD)); - const uint32_t copyTokenDstStride = - (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); - - int64_t offsetD = groupOffsetD + tileOffsetD; - uint32_t startToken = offsetD / calcInfo.axisH_; - uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; - uint32_t itToken = startToken; - uint32_t endToken = startToken + layoutGmTileD.shape(0); + + ACT_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, + LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) + { + const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); + const uint32_t copyTokenSrcStride = + (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); + const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); + + int64_t offsetD = groupOffsetD + tileOffsetD; + uint32_t startToken = offsetD / calcInfo.axisH_; + 
uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; + uint32_t itToken = startToken; + uint32_t endToken = startToken + layoutGmTileD.shape(0); #if ENABLE_EP_SEND_COUNT_HASH - uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); + uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); #else - constexpr uint32_t epRankStart = 0; + constexpr uint32_t epRankStart = 0; #endif - uint32_t sendCount = - expertIdx == 0 && epRankStart == 0 - ? 0 - : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); - for (uint32_t epRank = epRankStart; - epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - if (prevSendCount <= itToken && itToken < sendCount) { - uint32_t copyTokenCount = - (sendCount < endToken ? sendCount : endToken) - itToken; - AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, - copyTokenSrcStride, - copyTokenDstStride, 0); - uint32_t remoteEpRank; - uint32_t localEpRank; - SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); - GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + - localEpRank * calcInfo.moeExpertPerRankNum_ * - calcInfo.expertPerSizeOnWin_; - AscendC::GlobalTensor rankWindow; - rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); - AscendC::DataCopyPad( - rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + - tokenOffset], - ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); - itToken += copyTokenCount; - } - } - } - - ACT_DEVICE - void operator()(int64_t groupOffsetD, uint32_t expertIdx, - GemmCoord const &blockShapeMNK, - GemmCoord const &blockCoordMNK, - GemmCoord const &actualBlockShapeMNK, - AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutBlockC, - Callback &&callback = Callback{}) { - if (actualBlockShapeMNK.k() == 0) { - return; + uint32_t sendCount = + expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); + for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + if (prevSendCount <= itToken && itToken < sendCount) { + uint32_t copyTokenCount = (sendCount < endToken ? sendCount : endToken) - itToken; + AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, + copyTokenDstStride, 0); + uint32_t remoteEpRank; + uint32_t localEpRank; + SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); + GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + + localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; + AscendC::GlobalTensor rankWindow; + rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); + AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], + ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); + itToken += copyTokenCount; + } + } } - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - expertOffset = expertIdx * calcInfo.epWorldSize_; -#if ENABLE_EP_SEND_COUNT_HASH - if (currentExpertIdx_ != expertIdx) { - uint32_t hashOffset = 0; - uint32_t sendCount = - expertIdx == 0 ? 
0 : epSendCountLocal_.GetValue(expertOffset - 1); - for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, - sendCount - prevSendCount); + ACT_DEVICE + void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, + GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, + Callback &&callback = Callback{}) + { + if (actualBlockShapeMNK.k() == 0) { + return; } - AscendC::SetFlag(eventVS); - AscendC::WaitFlag(eventVS); - currentExpertIdx_ = expertIdx; - } + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + expertOffset = expertIdx * calcInfo.epWorldSize_; +#if ENABLE_EP_SEND_COUNT_HASH + if (currentExpertIdx_ != expertIdx) { + uint32_t hashOffset = 0; + uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); + for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { + uint32_t prevSendCount = sendCount; + sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); + InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); + } + AscendC::SetFlag(eventVS); + AscendC::WaitFlag(eventVS); + currentExpertIdx_ = expertIdx; + } #endif - } + } + + callback(); + // Calculate the offset of the current block + MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); + MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); + MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); + MatrixCoord blockOffset = blockCoord * blockShape; + + AscendC::GlobalTensor gmScale; + gmScale.SetGlobalBuffer(params.ptrScale); + AscendC::GlobalTensor gmPerTokenScale; + gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); + AscendC::GlobalTensor gmD; + gmD.SetGlobalBuffer(params.ptrD); + + auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); + auto tileShape = TileShape::ToCoord(); + EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); + uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); + uint32_t subblockIdx = AscendC::GetSubBlockIdx(); + uint32_t subblockNum = AscendC::GetSubBlockNum(); + for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { + auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); + auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); + auto tileOffsetInBlock = tileCoord * tileShape; + auto tileOffset = blockOffset + tileOffsetInBlock; + + auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; + auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); + + auto &ubC = ubCList[ubListId]; + LayoutC layoutUbC{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); + copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); + AscendC::SetFlag(eventUbCMTE2VList[ubListId]); + + auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); + auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); + + auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; + auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - callback(); - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord 
blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; - loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag( - eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = - LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag( - eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag( - eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = - actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = - gmPerTokenScale[params.layoutPerTokenScale.GetOffset( - perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = - params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb( - perTokenScaleTileShape); - - AscendC::WaitFlag( - eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, - layoutUbPerTokenScale, layoutGmTilePerTokenScale); - AscendC::SetFlag( - eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag( - eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, - TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag( - eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubMul, ubCFp32, ubScale); - AscendC::SetFlag( - eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag( - eventUbPerTokenScaleMTE2VList[ubListId]); - tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); - AscendC::SetFlag( - eventUbPerTokenScaleVMTE2List[ubListId]); - - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag( - eventUbDMTE3VList[ubListId]); - 
AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, - TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto tileOffsetD = params.layoutD.GetOffset(tileOffset); - auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag( - eventUbDVMTE3List[ubListId]); - - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, - tileOffsetD); - } else { - auto gmTileD = gmD[tileOffsetD]; - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - } - - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; + auto &ubScale = ubScaleList[ubListId]; + auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); + + AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); + copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); + AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); + + auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); + auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); + + auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; + auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); + + auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; + auto layoutUbPerTokenScale = + LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); + + AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, + layoutGmTilePerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + + AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); + AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbCVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); + tileRowBroadcastMul(ubMul, ubCFp32, ubScale); + AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); + + AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); + tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); + AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); + + AscendC::PipeBarrier(); + tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); + AscendC::PipeBarrier(); + + auto &ubD = ubDList[ubListId]; + LayoutD layoutUbD{actualTileShape, ubTileStride}; + + AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); + AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + AscendC::SetFlag(eventUbDVMTE3List[ubListId]); + + auto tileOffsetD = params.layoutD.GetOffset(tileOffset); + auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); + + AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); + + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); + } else { + auto gmTileD = gmD[tileOffsetD]; + copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); + } + + AscendC::SetFlag(eventUbDMTE3VList[ubListId]); + + ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; + } } - } private: - Params params; - Arch::Resource &resource; - MoeDistributeCombineImpl::CombineCalcInfo calcInfo; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - AscendC::LocalTensor epSendCountLocal_; + Params params; + Arch::Resource &resource; + MoeDistributeCombineImpl::CombineCalcInfo calcInfo; + + AscendC::LocalTensor ubCList[UB_STAGES]; + AscendC::LocalTensor ubScaleList[UB_STAGES]; + AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; + AscendC::LocalTensor ubDList[UB_STAGES]; + + int32_t eventUbCVMTE2List[UB_STAGES]; + int32_t eventUbCMTE2VList[UB_STAGES]; + int32_t eventUbScaleVMTE2List[UB_STAGES]; + int32_t eventUbScaleMTE2VList[UB_STAGES]; + int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; + int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; + int32_t eventUbDMTE3VList[UB_STAGES]; + int32_t eventUbDVMTE3List[UB_STAGES]; + + AscendC::LocalTensor epSendCountLocal_; #if ENABLE_EP_SEND_COUNT_HASH - AscendC::LocalTensor tokenToEpRankHashLocal_; - uint32_t currentExpertIdx_{static_cast(-1)}; + AscendC::LocalTensor tokenToEpRankHashLocal_; + uint32_t currentExpertIdx_{static_cast(-1)}; #endif - size_t ubOffset{0}; - int32_t eventVMTE2{0}; - int32_t eventMTE2V{0}; - int32_t eventMTE3V{0}; - int32_t eventVMTE3{0}; - int32_t eventVS{0}; - int32_t eventMTE2S{0}; + size_t ubOffset{0}; + int32_t eventVMTE2{0}; + int32_t eventMTE2V{0}; + int32_t eventMTE3V{0}; + int32_t eventVMTE3{0}; + int32_t eventVS{0}; + int32_t eventMTE2S{0}; - uint32_t expertOffset; + uint32_t expertOffset; - uint32_t ubListId{0}; + uint32_t ubListId{0}; - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleBrcb; - AscendC::LocalTensor ubPerTokenMul; + AscendC::LocalTensor ubCFp32; + AscendC::LocalTensor ubMul; + AscendC::LocalTensor ubPerTokenScaleBrcb; + AscendC::LocalTensor ubPerTokenMul; - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; + TileRowBroadcastMul tileRowBroadcastMul; + TileBroadcastOneBlk tileBroadcastOneBlk; + TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; + CopyGmToUbC copyGmToUbC; + CopyGmToUbScale copyGmToUbScale; + CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; + CopyUbToGmD copyUbToGmD; }; -} // namespace Act::Epilogue::Block +} // namespace Act::Epilogue::Block -#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP +#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP diff --git a/act/epilogue/dispatch_policy.hpp b/act/epilogue/dispatch_policy.hpp index 0323c274..8d93192d 100644 --- a/act/epilogue/dispatch_policy.hpp +++ b/act/epilogue/dispatch_policy.hpp @@ -20,56 +20,57 @@ namespace Act::Epilogue { // For AtlasA2, an element wise epilogue of the form D = C + X, where X is an // additional source struct EpilogueAtlasA2ElemWiseOneSource { - using ArchTag = Arch::AtlasA2; - // 
Number of operands. Including C, X, and D 3 operands - static constexpr uint32_t OPERANDS_NUM = 3; + using ArchTag = Arch::AtlasA2; + // Number of operands: three in total (C, X, and D) + static constexpr uint32_t OPERANDS_NUM = 3; }; // For AtlasA2, FA Softmax struct EpilogueAtlasA2FASoftmax { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, FA RescaleO struct EpilogueAtlasA2FARescaleO { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, MLA Softmax struct EpilogueAtlasA2MLASoftmax { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, MLA RescaleO struct EpilogueAtlasA2MLARescaleO { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, MLA FD RescaleO -template struct EpilogueAtlasA2MLAFDRescaleO { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t KV_SPLIT_MAX = 64; - static constexpr uint32_t HEADS_PROCESS_MAX = 16; - static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_; +template +struct EpilogueAtlasA2MLAFDRescaleO { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t KV_SPLIT_MAX = 64; + static constexpr uint32_t HEADS_PROCESS_MAX = 16; + static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_; }; // For AtlasA2, MLA TP1 Softmax struct EpilogueAtlasA2MLATP1Softmax { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, MLA TP1 RescaleO struct EpilogueAtlasA2MLATP1RescaleO { - using ArchTag = Arch::AtlasA2; + using ArchTag = Arch::AtlasA2; }; // For AtlasA2, per token dequant template struct EpilogueAtlasA2PerTokenDequant { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t UB_STAGES = UB_STAGES_; + static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; }; -} // namespace Act::Epilogue +} // namespace Act::Epilogue -#endif // ACT_EPILOGUE_DISPATCH_POLICY_HPP +#endif // ACT_EPILOGUE_DISPATCH_POLICY_HPP diff --git a/act/epilogue/tile/copy_gm_to_ub.hpp b/act/epilogue/tile/copy_gm_to_ub.hpp index ede41844..1a9d3b40 100644 --- a/act/epilogue/tile/copy_gm_to_ub.hpp +++ b/act/epilogue/tile/copy_gm_to_ub.hpp @@ -19,56 +19,51 @@ namespace Act::Epilogue::Tile { -template struct CopyGm2Ub { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy gm to ub, can not find the specialization."); +template +struct CopyGm2Ub { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub, cannot find the specialization."); }; template struct CopyGm2Ub> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyGm2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - layout::RowMajor const &layoutDst, - layout::RowMajor const &layoutSrc) { - AscendC::DataCopyExtParams dataCopyParams( - layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element), - (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element), - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - }; + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + 
ACT_DEVICE + CopyGm2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element), + (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; }; template struct CopyGm2Ub> { - using LayoutSrc = layout::VectorLayout; - using LayoutDst = layout::VectorLayout; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyGm2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - layout::VectorLayout const &layoutDst, - layout::VectorLayout const &layoutSrc) { - AscendC::DataCopyExtParams dataCopyParams( - 1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - }; + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyGm2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + }; }; /// @brief This copy instruction is used to copy per token scale from GM to UB. @@ -77,98 +72,85 @@ struct CopyGm2Ub> { /// element type is float). /// @tparam ArchTag: Architecture tag. /// @tparam GmType: Type of data on GM.
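// Illustration (an editorial aside, not part of the upstream header): a minimal
// host-side sketch of how the CopyPerTokenScale2Ub specialization below fills
// in its DataCopyPad parameters. Each token contributes a single float, so each
// burst carries 4 bytes of payload and isPad fills out the rest of the 32-byte
// UB block. All sizes here are hypothetical.
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr uint32_t BYTE_PER_BLK = 32;                           // one UB block
    constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(float);
    uint32_t srcRows = 8;                                           // 8 tokens in the tile
    uint32_t srcCols = 1;                                           // per token scale has only one column
    uint32_t dstRowStride = ELE_NUM_PER_BLK;                        // assumed UB stride: one block per token
    uint32_t blockCount = srcRows;                                  // one burst per token
    uint32_t blockLen = srcCols * sizeof(float);                    // 4 bytes of payload per burst
    uint32_t dstStride = (dstRowStride - srcCols) / ELE_NUM_PER_BLK;  // 0: bursts land in adjacent blocks
    printf("blockCount=%u blockLen=%uB dstStride=%u blocks\n", blockCount, blockLen, dstStride);
    return 0;
}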
-template struct CopyPerTokenScale2Ub { - static_assert(std::is_same_v, - "Unsupporteded layout for CopyPerTokenScale2Ub."); - - using Element = typename GmType::Element; - using LayoutSrc = typename GmType::Layout; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyPerTokenScale2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::DataCopyExtParams dataCopyParams; - AscendC::DataCopyPadExtParams padParams; - - dataCopyParams.blockCount = layoutSrc.shape(0); - dataCopyParams.blockLen = - layoutSrc.shape(1) * - sizeof(Element); // per token scale has only one column - dataCopyParams.srcStride = 0; - dataCopyParams.dstStride = - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - // Pad the data to the complete block - padParams.isPad = true; - padParams.leftPadding = 0; - padParams.rightPadding = 0; - - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - } +template +struct CopyPerTokenScale2Ub { + static_assert(std::is_same_v, + "Unsupported layout for CopyPerTokenScale2Ub."); + + using Element = typename GmType::Element; + using LayoutSrc = typename GmType::Layout; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyPerTokenScale2Ub() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams; + AscendC::DataCopyPadExtParams padParams; + + dataCopyParams.blockCount = layoutSrc.shape(0); + dataCopyParams.blockLen = layoutSrc.shape(1) * sizeof(Element); // per token scale has only one column + dataCopyParams.srcStride = 0; + dataCopyParams.dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + // Pad the data to a complete block + padParams.isPad = true; + padParams.leftPadding = 0; + padParams.rightPadding = 0; + + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); + } }; -template struct CopyGm2UbAligned { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy gm to ub aligned, can not find the specialization."); +template +struct CopyGm2UbAligned { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub aligned, cannot find the specialization."); }; template -struct CopyGm2UbAligned> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - static constexpr uint32_t STRIDE_LIMIT = 65536; - - ACT_DEVICE - CopyGm2UbAligned() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - layout::RowMajor const &layoutDst, - layout::RowMajor const &layoutSrc) { - uint32_t rows = layoutSrc.shape(0); - uint32_t cols = layoutSrc.shape(1); - uint32_t srcStride = - (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && - (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } 
else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && - (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = - (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, - srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], - dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], - srcTensor[i * layoutSrc.stride(0)], cols); - } - } - }; +struct CopyGm2UbAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + ACT_DEVICE + CopyGm2UbAligned() = default; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? 
MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/copy_ub_to_gm.hpp b/act/epilogue/tile/copy_ub_to_gm.hpp index 2c584048..651f4342 100644 --- a/act/epilogue/tile/copy_ub_to_gm.hpp +++ b/act/epilogue/tile/copy_ub_to_gm.hpp @@ -19,113 +19,97 @@ namespace Act::Epilogue::Tile { -template struct CopyUb2Gm { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy ub to gm, can not find the specialization."); +template +struct CopyUb2Gm { + static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm, cannot find the specialization."); }; template struct CopyUb2Gm> { - using LayoutDst = layout::RowMajor; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyUb2Gm() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - layout::RowMajor const &layoutDst, - layout::RowMajor const &layoutSrc) { - AscendC::DataCopyExtParams dataCopyParams( - layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element), - (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0, - (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); - } + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyUb2Gm() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element), + (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0, + (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0); + AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + } }; // newly added VectorLayout version template struct CopyUb2Gm> { - using LayoutSrc = layout::VectorLayout; - using LayoutDst = layout::VectorLayout; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyUb2Gm() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - layout::VectorLayout const &layoutDst, - layout::VectorLayout const &layoutSrc) { - AscendC::DataCopyExtParams dataCopyParams( - 1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); - }; + using LayoutSrc = layout::VectorLayout; + using LayoutDst = layout::VectorLayout; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + + ACT_DEVICE + CopyUb2Gm() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) + { + AscendC::DataCopyExtParams dataCopyParams(1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0); + 
AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); + }; }; -template struct CopyUb2GmAligned { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy ub to gm aligned, can not find the specialization."); +template +struct CopyUb2GmAligned { + static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm aligned, cannot find the specialization."); }; template -struct CopyUb2GmAligned> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - static constexpr uint32_t STRIDE_LIMIT = 65536; - - ACT_DEVICE - CopyUb2GmAligned() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - layout::RowMajor const &layoutDst, - layout::RowMajor const &layoutSrc) { - uint32_t rows = layoutDst.shape(0); - uint32_t cols = layoutDst.shape(1); - uint32_t srcStride = - (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && - (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && - (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = - (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, - srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], - dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], - srcTensor[i * layoutSrc.stride(0)], cols); - } - } - }; +struct CopyUb2GmAligned> { + using LayoutSrc = layout::RowMajor; + using LayoutDst = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + static constexpr uint32_t STRIDE_LIMIT = 65536; + + ACT_DEVICE + CopyUb2GmAligned() = default; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) + { + uint32_t rows = layoutDst.shape(0); + uint32_t cols = layoutDst.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? 
MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } + }; }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp index da5eeaca..a4a9d8d6 100644 --- a/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp +++ b/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp @@ -25,47 +25,40 @@ template < /// Length of the compute buffer class TileShape_> struct TileBroadcastInplaceByColumn { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; - ACT_DEVICE - TileBroadcastInplaceByColumn() {} + ACT_DEVICE + TileBroadcastInplaceByColumn() {} - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubInOut) { - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; - constexpr uint64_t defaultMask = - BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - constexpr uint64_t tailMask = - (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; + constexpr uint64_t defaultMask = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + constexpr uint64_t tailMask = (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; - constexpr uint8_t repeatTimes = 1; + constexpr uint8_t repeatTimes = 1; - AscendC::CopyRepeatParams repeatParams; - repeatParams.dstStride = blkNumPerRow; - repeatParams.srcStride = blkNumPerRow; - repeatParams.dstRepeatSize = 1; - repeatParams.srcRepeatSize = 1; + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = blkNumPerRow; + repeatParams.srcStride = blkNumPerRow; + repeatParams.dstRepeatSize = 1; + repeatParams.srcRepeatSize = 1; - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; - rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { - uint64_t mask = - ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) - ? defaultMask - : tailMask; - for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; - colOffset += eleNumPerBlk) { - AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], - ubInOut[rowOffset * TileShape::COLUMN], mask, 1, - repeatParams); - } + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { + uint64_t mask = ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) ? 
defaultMask : tailMask; + for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; colOffset += eleNumPerBlk) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], + ubInOut[rowOffset * TileShape::COLUMN], mask, 1, repeatParams); + } + } } - } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp index f507f94c..7ea15659 100644 --- a/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp +++ b/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp @@ -25,34 +25,33 @@ template < /// Length of the compute buffer class TileShape_> struct TileBroadcastInplaceByRow { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileBroadcastInplaceByRow() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubInOut) { - constexpr uint32_t eleNumPerVectorFractal = - BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - - constexpr uint64_t mask = eleNumPerVectorFractal; - constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; - - AscendC::CopyRepeatParams repeatParams; - repeatParams.dstStride = 1; - repeatParams.srcStride = 1; - repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; - repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; - - for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { - AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, - repeatTimes, repeatParams); + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileBroadcastInplaceByRow() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubInOut) + { + constexpr uint32_t eleNumPerVectorFractal = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + + constexpr uint64_t mask = eleNumPerVectorFractal; + constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; + + AscendC::CopyRepeatParams repeatParams; + repeatParams.dstStride = 1; + repeatParams.srcStride = 1; + repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; + + for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { + AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, repeatTimes, repeatParams); + } } - } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_broadcast_mul.hpp b/act/epilogue/tile/tile_broadcast_mul.hpp index 9e31d69a..93b6125f 100644 --- a/act/epilogue/tile/tile_broadcast_mul.hpp +++ b/act/epilogue/tile/tile_broadcast_mul.hpp @@ -28,48 +28,44 @@ namespace Act::Epilogue::Tile { /// @tparam TileShape_ is the shape (m, n). 
template struct TileRowBroadcastMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileRowBroadcastMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) { - constexpr uint32_t maxRepeatTimes = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; - AscendC::BinaryRepeatParams repeatParams; - repeatParams.dstBlkStride = 1; - repeatParams.src0BlkStride = 1; - repeatParams.src1BlkStride = 1; - repeatParams.dstRepStride = blkNumPerColumn; - repeatParams.src0RepStride = blkNumPerColumn; - repeatParams.src1RepStride = 0; - - constexpr uint32_t rowNumPerCompute = maxRepeatTimes; - constexpr uint32_t colNumPerCompute = - BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; - rowOffset += rowNumPerCompute) { - uint32_t residueM = TileShape::ROW - rowOffset; - uint8_t repeatTimes = static_cast( - (residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); - for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; - colOffset += colNumPerCompute) { - uint32_t residueN = TileShape::COLUMN - colOffset; - uint64_t mask = - (residueN > colNumPerCompute) ? colNumPerCompute : residueN; - AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], - ubIn0[rowOffset * TileShape::COLUMN + colOffset], - ubIn1[colOffset], mask, repeatTimes, repeatParams); - } + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileRowBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatTimes = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.src0BlkStride = 1; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = blkNumPerColumn; + repeatParams.src0RepStride = blkNumPerColumn; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = maxRepeatTimes; + constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[colOffset], mask, repeatTimes, + repeatParams); + } + } } - } }; /// @brief Compute the elementwise multiplication of a tensor of shape (m, n) @@ -80,52 +76,47 @@ struct TileRowBroadcastMul { /// @tparam TileShape_ is the shape (m, n). 
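// Illustration (an editorial aside, not part of the upstream header): scalar
// references for the two broadcast multiplies in this file. TileRowBroadcastMul
// above computes ubOut[i][j] = ubIn0[i][j] * ubIn1[j] by setting src1RepStride
// to 0; TileOneBlkColumnBroadcastMul below computes
// ubOut[i][j] = ubIn0[i][j] * ubIn1[i], where ubIn1 holds one value per row
// replicated across a 32-byte block by Brcb. Plain C++ sketch, hypothetical names.
#include <cstdint>

// Row-vector broadcast: every row of the (m, n) tile is scaled by rowVec[j].
void RowBroadcastMulRef(float *out, float const *in, float const *rowVec, uint32_t m, uint32_t n)
{
    for (uint32_t i = 0; i < m; ++i) {
        for (uint32_t j = 0; j < n; ++j) {
            out[i * n + j] = in[i * n + j] * rowVec[j];
        }
    }
}

// Column-vector broadcast: every column of the (m, n) tile is scaled by colVec[i].
void ColumnBroadcastMulRef(float *out, float const *in, float const *colVec, uint32_t m, uint32_t n)
{
    for (uint32_t i = 0; i < m; ++i) {
        for (uint32_t j = 0; j < n; ++j) {
            out[i * n + j] = in[i * n + j] * colVec[i];
        }
    }
}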
template struct TileOneBlkColumnBroadcastMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileOneBlkColumnBroadcastMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) { - constexpr uint32_t maxRepeatNum = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; - AscendC::BinaryRepeatParams repeatParams; - repeatParams.dstBlkStride = blkNumPerColumn; - repeatParams.src0BlkStride = blkNumPerColumn; - repeatParams.src1BlkStride = 1; - repeatParams.dstRepStride = 1; - repeatParams.src0RepStride = 1; - repeatParams.src1RepStride = 0; - - constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; - constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; - rowOffset += rowNumPerCompute) { - uint32_t residueM = TileShape::ROW - rowOffset; - uint64_t mask = - ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * - eleNumPerBlk; - for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; - colOffset += colNumPerCompute) { - uint32_t residueN = TileShape::COLUMN - colOffset; - uint8_t repeatTimes = static_cast( - ((residueN > colNumPerCompute) ? colNumPerCompute : residueN) / - eleNumPerBlk); - AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], - ubIn0[rowOffset * TileShape::COLUMN + colOffset], - ubIn1[rowOffset * eleNumPerBlk], mask, repeatTimes, - repeatParams); - } + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileOneBlkColumnBroadcastMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; + AscendC::BinaryRepeatParams repeatParams; + repeatParams.dstBlkStride = blkNumPerColumn; + repeatParams.src0BlkStride = blkNumPerColumn; + repeatParams.src1BlkStride = 1; + repeatParams.dstRepStride = 1; + repeatParams.src0RepStride = 1; + repeatParams.src1RepStride = 0; + + constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; + constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; + for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { + uint32_t residueM = TileShape::ROW - rowOffset; + uint64_t mask = ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * eleNumPerBlk; + for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { + uint32_t residueN = TileShape::COLUMN - colOffset; + uint8_t repeatTimes = + static_cast(((residueN > colNumPerCompute) ? 
colNumPerCompute : residueN) / eleNumPerBlk); + AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], + ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[rowOffset * eleNumPerBlk], mask, + repeatTimes, repeatParams); + } + } } - } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_broadcast_one_blk.hpp b/act/epilogue/tile/tile_broadcast_one_blk.hpp index 799a1bd1..d8f7d79d 100644 --- a/act/epilogue/tile/tile_broadcast_one_blk.hpp +++ b/act/epilogue/tile/tile_broadcast_one_blk.hpp @@ -19,38 +19,33 @@ namespace Act::Epilogue::Tile { template struct TileBroadcastOneBlk { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - - ACT_DEVICE - TileBroadcastOneBlk() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn) { - constexpr uint32_t maxRepeatNum = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - AscendC::BrcbRepeatParams repeatParams; - repeatParams.dstBlkStride = 1; - repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; - - constexpr uint32_t eleNumPerCompute = - RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); - for (uint32_t offset = 0; offset < COMPUTE_LENGTH; - offset += eleNumPerCompute) { - uint32_t residueM = COMPUTE_LENGTH - offset; - uint32_t computeM = - (residueM > eleNumPerCompute) ? eleNumPerCompute : residueM; - uint8_t repeatTimes = - static_cast(CeilDiv(computeM)); - AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, - repeatParams); + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileBroadcastOneBlk() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + constexpr uint32_t maxRepeatNum = 255; + constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); + + AscendC::BrcbRepeatParams repeatParams; + repeatParams.dstBlkStride = 1; + repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; + + constexpr uint32_t eleNumPerCompute = RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); + for (uint32_t offset = 0; offset < COMPUTE_LENGTH; offset += eleNumPerCompute) { + uint32_t residueM = COMPUTE_LENGTH - offset; + uint32_t computeM = (residueM > eleNumPerCompute) ? 
eleNumPerCompute : residueM; + uint8_t repeatTimes = static_cast(CeilDiv(computeM)); + AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, repeatParams); + } } - } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_cast.hpp b/act/epilogue/tile/tile_cast.hpp index c0fa588d..50162516 100644 --- a/act/epilogue/tile/tile_cast.hpp +++ b/act/epilogue/tile/tile_cast.hpp @@ -25,21 +25,21 @@ template < /// Length of the compute buffer class TileShape_> struct TileCast { - using ArchTag = ArchTag_; - using ElementDst = typename DstType_::Element; - using ElementSrc = typename SrcType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileCast() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn) { - AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - } + using ArchTag = ArchTag_; + using ElementDst = typename DstType_::Element; + using ElementSrc = typename SrcType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileCast() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) + { + AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); + } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_copy.hpp b/act/epilogue/tile/tile_copy.hpp index abc7c96f..2ed7c9c7 100644 --- a/act/epilogue/tile/tile_copy.hpp +++ b/act/epilogue/tile/tile_copy.hpp @@ -22,8 +22,7 @@ template < /// Tag indicating architecture class ArchTag, class... Args> struct TileCopy { - static_assert(DEPENDENT_FALSE, - "Unsupporteded tile copy, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported tile copy, cannot find the specialization."); }; template struct TileCopy { - using ElementC = typename CType::Element; - using ElementX = typename XType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = CopyGm2Ub; - using CopyUbToGmD = CopyUb2Gm; - using CopyGmToUbY = CopyGm2Ub; - using CopyGmToUbTemp = CopyGm2Ub; - using CopyUbToGmZ = CopyUb2Gm; + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; + using CopyGmToUbY = CopyGm2Ub; + using CopyGmToUbTemp = CopyGm2Ub; + using CopyUbToGmZ = CopyUb2Gm; }; template struct TileCopy { - using ElementC = typename CType::Element; - using ElementX = typename XType::Element; - using ElementY = typename YType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = CopyGm2Ub; - using CopyGmToUbY = CopyGm2Ub; - using CopyUbToGmD = CopyUb2Gm; + using ElementC = typename CType::Element; + using ElementX = typename XType::Element; + using ElementY = typename YType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbY = CopyGm2Ub; + using CopyUbToGmD = CopyUb2Gm; }; template struct TileCopyBf16 { - using ElementC = typename CType::Element; - using ElementX = bfloat16_t; - using ElementY = bfloat16_t; - using ElementD = bfloat16_t; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = - CopyGm2Ub>; - using CopyGmToUbY = - CopyGm2Ub>; - using CopyUbToGmD = - CopyUb2Gm>; + using 
ElementC = typename CType::Element; + using ElementX = bfloat16_t; + using ElementY = bfloat16_t; + using ElementD = bfloat16_t; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbX = CopyGm2Ub>; + using CopyGmToUbY = CopyGm2Ub>; + using CopyUbToGmD = CopyUb2Gm>; }; -template +template struct TileCopyPerTokenDequant { - using ElementC = typename CType::Element; - using ElementScale = typename ScaleType::Element; - using ElementPerTokenScale = typename PerTokenScaleType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbScale = CopyGm2Ub; - using CopyGmToUbPerTokenScale = - CopyPerTokenScale2Ub; - using CopyUbToGmD = CopyUb2Gm; + using ElementC = typename CType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementD = typename DType::Element; + + using CopyGmToUbC = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyPerTokenScale2Ub; + using CopyUbToGmD = CopyUb2Gm; }; -template +template struct TileCopyPerTokenDequantGemm { - using ElementX = typename XType::Element; - using ElementScale = typename ScaleType::Element; - using ElementPerTokenScale = typename PerTokenScaleType::Element; - using ElementBias = typename BiasType::Element; - using ElementC = typename CType::Element; - - using CopyGmToUbX = CopyGm2Ub; - using CopyGmToUbScale = CopyGm2Ub; - using CopyGmToUbPerTokenScale = CopyGm2Ub; - using CopyGmToUbBias = CopyGm2Ub; - using CopyUbToGmC = CopyUb2Gm; + using ElementX = typename XType::Element; + using ElementScale = typename ScaleType::Element; + using ElementPerTokenScale = typename PerTokenScaleType::Element; + using ElementBias = typename BiasType::Element; + using ElementC = typename CType::Element; + + using CopyGmToUbX = CopyGm2Ub; + using CopyGmToUbScale = CopyGm2Ub; + using CopyGmToUbPerTokenScale = CopyGm2Ub; + using CopyGmToUbBias = CopyGm2Ub; + using CopyUbToGmC = CopyUb2Gm; }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile -#endif // ACT_EPILOGUE_TILE_TILE_COPY_HPP +#endif // ACT_EPILOGUE_TILE_TILE_COPY_HPP diff --git a/act/epilogue/tile/tile_elemwise_add.hpp b/act/epilogue/tile/tile_elemwise_add.hpp index 047fefc6..8edcc1f9 100644 --- a/act/epilogue/tile/tile_elemwise_add.hpp +++ b/act/epilogue/tile/tile_elemwise_add.hpp @@ -25,23 +25,24 @@ template < /// Length of the compute buffer uint32_t COMPUTE_LENGTH_> struct TileElemWiseAdd { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - - ACT_DEVICE - TileElemWiseAdd() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) { - // Do the calculation - AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); - } + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + + ACT_DEVICE + TileElemWiseAdd() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); + } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_elemwise_mul.hpp b/act/epilogue/tile/tile_elemwise_mul.hpp index f79ea98e..cfc45739 100644 --- 
a/act/epilogue/tile/tile_elemwise_mul.hpp +++ b/act/epilogue/tile/tile_elemwise_mul.hpp @@ -25,22 +25,23 @@ template < /// Length of the compute buffer class TileShape_> struct TileElemwiseMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileElemwiseMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) { - // Do the calculation - AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); - } + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; + using TileShape = TileShape_; + + ACT_DEVICE + TileElemwiseMul() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &ubOut, + AscendC::LocalTensor const &ubIn0, + AscendC::LocalTensor const &ubIn1) + { + // Do the calculation + AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); + } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile #endif diff --git a/act/epilogue/tile/tile_elemwise_muls.hpp b/act/epilogue/tile/tile_elemwise_muls.hpp index 8af5d5c7..9bf10fa9 100644 --- a/act/epilogue/tile/tile_elemwise_muls.hpp +++ b/act/epilogue/tile/tile_elemwise_muls.hpp @@ -18,21 +18,21 @@ namespace Act::Epilogue::Tile { template struct TileElemWiseMuls { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; + using ArchTag = ArchTag_; + using ElementCompute = typename ComputeType_::Element; - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; + static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - ACT_DEVICE - TileElemWiseMuls() {} + ACT_DEVICE + TileElemWiseMuls() {} - ACT_DEVICE - void operator()(AscendC::LocalTensor dstLocal, - AscendC::LocalTensor srcTensor, - ElementCompute scalar) { - AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH); - } + ACT_DEVICE + void operator()(AscendC::LocalTensor dstLocal, AscendC::LocalTensor srcTensor, + ElementCompute scalar) + { + AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH); + } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile -#endif // ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP +#endif // ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP diff --git a/act/epilogue/tile/tile_swizzle.hpp b/act/epilogue/tile/tile_swizzle.hpp index 13c05298..490a2a5a 100644 --- a/act/epilogue/tile/tile_swizzle.hpp +++ b/act/epilogue/tile/tile_swizzle.hpp @@ -20,63 +20,73 @@ namespace Act::Epilogue::Tile { struct EpilogueIdentityTileSwizzle { - MatrixCoord blockShape; - MatrixCoord tileShape; - MatrixCoord loopsMN; - - ACT_DEVICE - EpilogueIdentityTileSwizzle() = default; - - ACT_DEVICE - EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, - MatrixCoord const &tileShape) - : blockShape(blockShape), tileShape(tileShape) { - loopsMN = CeilDiv(blockShape, tileShape); - } - - ACT_DEVICE - uint32_t GetLoops() const { return loopsMN.row() * loopsMN.column(); } - - ACT_DEVICE - MatrixCoord GetTileCoord(uint32_t loopIdx) const { - return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; - } - - ACT_DEVICE - MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const { - return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); - } + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueIdentityTileSwizzle() = default; + + ACT_DEVICE + EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : 
blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } }; struct EpilogueHorizontalTileSwizzle { - MatrixCoord blockShape; - MatrixCoord tileShape; - MatrixCoord loopsMN; - - ACT_DEVICE - EpilogueHorizontalTileSwizzle() = default; - - ACT_DEVICE - EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, - MatrixCoord const &tileShape) - : blockShape(blockShape), tileShape(tileShape) { - loopsMN = CeilDiv(blockShape, tileShape); - } - - ACT_DEVICE - uint32_t GetLoops() const { return loopsMN.row() * loopsMN.column(); } - - ACT_DEVICE - MatrixCoord GetTileCoord(uint32_t loopIdx) const { - return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; - } - - ACT_DEVICE - MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const { - return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); - } + MatrixCoord blockShape; + MatrixCoord tileShape; + MatrixCoord loopsMN; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle() = default; + + ACT_DEVICE + EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) + : blockShape(blockShape), tileShape(tileShape) + { + loopsMN = CeilDiv(blockShape, tileShape); + } + + ACT_DEVICE + uint32_t GetLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + MatrixCoord GetTileCoord(uint32_t loopIdx) const + { + return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; + } + + ACT_DEVICE + MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const + { + return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); + } }; -} // namespace Act::Epilogue::Tile +} // namespace Act::Epilogue::Tile -#endif // ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP +#endif // ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP diff --git a/act/gemm/block/block_mmad.hpp b/act/gemm/block/block_mmad.hpp index 0a8d8d8f..8da81c80 100644 --- a/act/gemm/block/block_mmad.hpp +++ b/act/gemm/block/block_mmad.hpp @@ -19,49 +19,39 @@ namespace Act::Gemm::Block { -template , - class TileMmad = Gemm::Tile::TileMmad< - typename DispatchPolicy::ArchTag, AType, BType, BiasType>> +template , + class TileMmad = Gemm::Tile::TileMmad> struct BlockMmad { - static_assert(DEPENDENT_FALSE, - "BlockMmad is not implemented for this DispatchPolicy"); + static_assert(DEPENDENT_FALSE, "BlockMmad is not implemented for this DispatchPolicy"); }; -template , - class TileMmad = Gemm::Tile::TileMmadTla< - typename DispatchPolicy::ArchTag, typename TileCopy::TensorL0A, - typename TileCopy::TensorL0B, typename TileCopy::TensorL0C>> +template , + class TileMmad = Gemm::Tile::TileMmadTla> struct BlockMmadTla { - static_assert(DEPENDENT_FALSE, - "BlockMmadTla is not implemented for this DispatchPolicy"); + static_assert(DEPENDENT_FALSE, "BlockMmadTla is not implemented for this DispatchPolicy"); }; /// Newly added: this kernel uses the same dispatch policy as /// optimized_matmul, so a separate class is defined here to avoid a specialization conflict. -template < - class DispatchPolicy, class L1TileShape, class L0TileShape, class AType, - class BType, class CType, class 
BiasType = void, - class TileCopy = - Gemm::Tile::TileCopyGemm, // change the name - class TileMmad = Gemm::Tile::TileMmad> +template , // TileCopyGemm: renamed to avoid clashing with optimized_matmul's TileCopy + class TileMmad = Gemm::Tile::TileMmad> struct BlockGemm { - static_assert(DEPENDENT_FALSE, - "BlockMmad is not implemented for this DispatchPolicy"); + static_assert(DEPENDENT_FALSE, "BlockGemm is not implemented for this DispatchPolicy"); }; -} // namespace Act::Gemm::Block +} // namespace Act::Gemm::Block #include "../../../act/gemm/block/block_mmad_preload_async_with_callback.hpp" -#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_HPP +#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_HPP diff --git a/act/gemm/block/block_mmad_preload_async_with_callback.hpp b/act/gemm/block/block_mmad_preload_async_with_callback.hpp index 7b28f80b..324f9799 100644 --- a/act/gemm/block/block_mmad_preload_async_with_callback.hpp +++ b/act/gemm/block/block_mmad_preload_async_with_callback.hpp @@ -23,433 +23,388 @@ namespace Act::Gemm::Block { -template -struct BlockMmad, - L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, - TileCopy_, TileMmad_> { +template +struct BlockMmad, + L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { public: - // Type Aliases - using DispatchPolicy = - MmadAtlasA2PreloadAsyncWithCallback; - using ArchTag = typename DispatchPolicy::ArchTag; - using L1TileShape = L1TileShape_; - using L0TileShape = L0TileShape_; - using ElementA = typename AType_::Element; - using LayoutA = typename AType_::Layout; - using ElementB = typename BType_::Element; - using LayoutB = typename BType_::Layout; - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using TileMmad = TileMmad_; - using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; - using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; - using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; - using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; - using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; - using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; - using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; - using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; - using LayoutCInL0 = layout::zN; - - using L1AAlignHelper = Gemm::helper::L1AlignHelper; - using L1BAlignHelper = Gemm::helper::L1AlignHelper; - - static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; - static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; - static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; - static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; - static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; - - static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; - static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; - - // L1 tile size - static constexpr uint32_t L1A_TILE_SIZE = - L1TileShape::M * L1TileShape::K * sizeof(ElementA); - static constexpr uint32_t L1B_TILE_SIZE = - L1TileShape::N * L1TileShape::K * sizeof(ElementB); - // L0 tile size - static constexpr uint32_t L0A_TILE_SIZE = - L0TileShape::M * L0TileShape::K * sizeof(ElementA); - static constexpr uint32_t L0B_TILE_SIZE = - L0TileShape::K * L0TileShape::N * sizeof(ElementB); - static constexpr uint32_t L0C_TILE_SIZE = - L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); - - // Check LayoutC - 
static_assert(std::is_same_v, - "LayoutC only support RowMajor yet!"); - - // Check L1TileShape - static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, - "L1TileShape exceeding the L1 space!"); - - // Check L0TileShape - static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, - "L0TileShape exceeding the L0A space!"); - static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, - "L0TileShape exceeding the L0B space!"); - static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, - "L0TileShape exceeding the L0C space!"); - - static_assert(L1TileShape::M == L0TileShape::M && - L1TileShape::N == L0TileShape::N, - "The situation where the basic blocks of L1 and L0 differ on " - "the m and n axes is not supported yet"); - - static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout( - L1TileShape::M, L1TileShape::K); - static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout( - L1TileShape::K, L1TileShape::N); - - ACT_DEVICE - BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) { - InitL1(resource, l1BufAddrStart); - InitL0A(resource); - InitL0B(resource); - InitL0C(resource); - } - - ACT_DEVICE - ~BlockMmad() { - SynchronizeBlock(); - for (uint32_t i = 0; i < L1_STAGES; ++i) { - AscendC::WaitFlag(l1AEventList[i]); - AscendC::WaitFlag(l1BEventList[i]); - } - for (uint32_t i = 0; i < L0A_STAGES; ++i) { - AscendC::WaitFlag(l0AEventList[i]); - } - for (uint32_t i = 0; i < L0B_STAGES; ++i) { - AscendC::WaitFlag(l0BEventList[i]); - } - for (uint32_t i = 0; i < L0C_STAGES; ++i) { - AscendC::WaitFlag(l0CEventList[i]); + // Type Aliases + using DispatchPolicy = MmadAtlasA2PreloadAsyncWithCallback; + using ArchTag = typename DispatchPolicy::ArchTag; + using L1TileShape = L1TileShape_; + using L0TileShape = L0TileShape_; + using ElementA = typename AType_::Element; + using LayoutA = typename AType_::Layout; + using ElementB = typename BType_::Element; + using LayoutB = typename BType_::Layout; + using ElementC = typename CType_::Element; + using LayoutC = typename CType_::Layout; + using TileMmad = TileMmad_; + using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; + using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; + using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; + using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; + using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; + using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; + using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; + using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; + using LayoutCInL0 = layout::zN; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; + static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; + static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; + static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; + static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; + + static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; + static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; + + // L1 tile size + static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); + static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * 
sizeof(ElementB); + // L0 tile size + static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); + static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); + static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); + + // Check LayoutC + static_assert(std::is_same_v, "LayoutC only supports RowMajor for now!"); + + // Check L1TileShape + static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, + "L1TileShape exceeds the L1 space!"); + + // Check L0TileShape + static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeds the L0A space!"); + static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeds the L0B space!"); + static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeds the L0C space!"); + + static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, + "The situation where the basic blocks of L1 and L0 differ on " + "the m and n axes is not supported yet"); + + static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); + static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); + + ACT_DEVICE + BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) + { + InitL1(resource, l1BufAddrStart); + InitL0A(resource); + InitL0B(resource); + InitL0C(resource); } - } - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &gmBlockA, - LayoutA const &layoutA, - AscendC::GlobalTensor const &gmBlockB, - LayoutB const &layoutB, - AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutC, GemmCoord const &actualShape, - Callback const &callbackBeforeFixpipe, - Callback const &callbackAfterFixpipe) { - uint32_t kTileCount = CeilDiv(actualShape.k()); - - uint32_t mRound = RoundUp(actualShape.m()); - uint32_t nRound = RoundUp(actualShape.n()); - - uint32_t startTileIdx = 0; - if constexpr (ENABLE_SHUFFLE_K) { - startTileIdx = AscendC::GetBlockIdx() % kTileCount; + + ACT_DEVICE + ~BlockMmad() + { + SynchronizeBlock(); + for (uint32_t i = 0; i < L1_STAGES; ++i) { + AscendC::WaitFlag(l1AEventList[i]); + AscendC::WaitFlag(l1BEventList[i]); + } + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + AscendC::WaitFlag(l0AEventList[i]); + } + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + AscendC::WaitFlag(l0BEventList[i]); + } + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + AscendC::WaitFlag(l0CEventList[i]); + } } - for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { - uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) - ? (startTileIdx + kLoopIdx) - : (startTileIdx + kLoopIdx - kTileCount); - - uint32_t kActual = (kTileIdx < kTileCount - 1) - ? 
L1TileShape::K - : (actualShape.k() - kTileIdx * L1TileShape::K); - - // Emission load instruction from GM to L1 - MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; - MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; - auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; - auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; - // Load first matrix A tile from GM to L1 - AscendC::WaitFlag(l1AEventList[l1ListId]); - auto layoutTileA = - layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); - copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); - AscendC::SetFlag(l1AEventList[l1ListId]); - // Load first matrix B tile from GM to L1 - AscendC::WaitFlag(l1BEventList[l1ListId]); - auto layoutTileB = - layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); - copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); - AscendC::SetFlag(l1BEventList[l1ListId]); - - // If the number of preload instructions reaches the upper limit, perform - // an mmad calculation on L1 tile - if (preloadCount == PRELOAD_STAGES) { - L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); - } - - // Store the current load status - uint32_t preloadL1TileMmadParamsId = - (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) - ? (l1TileMmadParamsId + preloadCount) - : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); - auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; - l1TileMmadParams.l1ListId = l1ListId; - l1TileMmadParams.mRound = mRound; - l1TileMmadParams.nRound = nRound; - l1TileMmadParams.kActual = kActual; - l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); - l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); - if (kLoopIdx == kTileCount - 1) { - l1TileMmadParams.gmBlockC = gmBlockC; - l1TileMmadParams.layoutCInGm = - layoutC.GetTileLayout(actualShape.GetCoordMN()); - l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; - l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; - } - - if (preloadCount < PRELOAD_STAGES) { - ++preloadCount; - } else { - l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) - ? (l1TileMmadParamsId + 1) - : 0; - } - l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, + AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, + AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, + GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, + Callback const &callbackAfterFixpipe) + { + uint32_t kTileCount = CeilDiv(actualShape.k()); + + uint32_t mRound = RoundUp(actualShape.m()); + uint32_t nRound = RoundUp(actualShape.n()); + + uint32_t startTileIdx = 0; + if constexpr (ENABLE_SHUFFLE_K) { + startTileIdx = AscendC::GetBlockIdx() % kTileCount; + } + + for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { + uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) + : (startTileIdx + kLoopIdx - kTileCount); + + uint32_t kActual = + (kTileIdx < kTileCount - 1) ? 
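/* every k tile except the last spans the full L1TileShape::K; the last takes the remainder */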
L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); + + // Emit the load instructions from GM to L1 + MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; + MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; + auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; + auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; + // Load the matrix A tile from GM to L1 + AscendC::WaitFlag(l1AEventList[l1ListId]); + auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); + copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); + AscendC::SetFlag(l1AEventList[l1ListId]); + // Load the matrix B tile from GM to L1 + AscendC::WaitFlag(l1BEventList[l1ListId]); + auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); + copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); + AscendC::SetFlag(l1BEventList[l1ListId]); + + // If the number of preload instructions reaches the upper limit, perform + // an mmad calculation on L1 tile + if (preloadCount == PRELOAD_STAGES) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + } + + // Store the current load status + uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) + ? (l1TileMmadParamsId + preloadCount) + : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); + auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; + l1TileMmadParams.l1ListId = l1ListId; + l1TileMmadParams.mRound = mRound; + l1TileMmadParams.nRound = nRound; + l1TileMmadParams.kActual = kActual; + l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); + l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); + if (kLoopIdx == kTileCount - 1) { + l1TileMmadParams.gmBlockC = gmBlockC; + l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); + l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; + l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; + } + + if (preloadCount < PRELOAD_STAGES) { + ++preloadCount; + } else { + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; + } + l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; + } + } + + ACT_DEVICE + void SynchronizeBlock() + { + while (preloadCount > 0) { + L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); + l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? 
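/* advance the preload ring index, wrapping back to slot 0 */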
(l1TileMmadParamsId + 1) : 0; + --preloadCount; + } } - } private: - struct L1TileMmadParams { - uint32_t l1ListId; - uint32_t mRound; - uint32_t nRound; - uint32_t kActual; - bool isKLoopFirst; - bool isKLoopLast; - AscendC::GlobalTensor gmBlockC; - LayoutC layoutCInGm; - Callback callbackBeforeFixpipe; - Callback callbackAfterFixpipe; + struct L1TileMmadParams { + uint32_t l1ListId; + uint32_t mRound; + uint32_t nRound; + uint32_t kActual; + bool isKLoopFirst; + bool isKLoopLast; + AscendC::GlobalTensor gmBlockC; + LayoutC layoutCInGm; + Callback callbackBeforeFixpipe; + Callback callbackAfterFixpipe; + + ACT_DEVICE + L1TileMmadParams() = default; + }; ACT_DEVICE - L1TileMmadParams() = default; - }; - - ACT_DEVICE - void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) { - uint32_t l1AOffset = l1BufAddrStart; - uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; - for (uint32_t i = 0; i < L1_STAGES; ++i) { - l1ATensorList[i] = resource.l1Buf.template GetBufferByByte( - l1AOffset + L1A_TILE_SIZE * i); - l1BTensorList[i] = resource.l1Buf.template GetBufferByByte( - l1BOffset + L1B_TILE_SIZE * i); - l1AEventList[i] = i; - l1BEventList[i] = i + L1_STAGES; - AscendC::SetFlag(l1AEventList[i]); - AscendC::SetFlag(l1BEventList[i]); - } - } - - ACT_DEVICE - void InitL0A(Arch::Resource &resource) { - for (uint32_t i = 0; i < L0A_STAGES; ++i) { - l0ATensorList[i] = - resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); - l0AEventList[i] = i; - AscendC::SetFlag(l0AEventList[i]); - } - } - - ACT_DEVICE - void InitL0B(Arch::Resource &resource) { - for (uint32_t i = 0; i < L0B_STAGES; ++i) { - l0BTensorList[i] = - resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); - l0BEventList[i] = i + L0A_STAGES; - AscendC::SetFlag(l0BEventList[i]); - } - } - - ACT_DEVICE - void InitL0C(Arch::Resource &resource) { - for (uint32_t i = 0; i < L0C_STAGES; ++i) { - l0CTensorList[i] = - resource.l0CBuf.template GetBufferByByte( - L0C_TILE_SIZE * i); - l0CEventList[i] = i; - AscendC::SetFlag(l0CEventList[i]); + void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) + { + uint32_t l1AOffset = l1BufAddrStart; + uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; + for (uint32_t i = 0; i < L1_STAGES; ++i) { + l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); + l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); + l1AEventList[i] = i; + l1BEventList[i] = i + L1_STAGES; + AscendC::SetFlag(l1AEventList[i]); + AscendC::SetFlag(l1BEventList[i]); + } } - } - - ACT_DEVICE - void L1TileMmad(L1TileMmadParams const ¶ms) { - uint32_t mPartLoop = CeilDiv(params.mRound); - uint32_t nPartLoop = CeilDiv(params.nRound); - uint32_t kPartLoop = CeilDiv(params.kActual); - auto &l1ATensor = l1ATensorList[params.l1ListId]; - auto &l1BTensor = l1BTensorList[params.l1ListId]; - - auto &l0CTensor = l0CTensorList[l0CListId]; - LayoutCInL0 layoutCInL0 = - LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); - - if constexpr (!ENABLE_UNIT_FLAG) { - if (params.isKLoopFirst) { - AscendC::WaitFlag(l0CEventList[l0CListId]); - } + + ACT_DEVICE + void InitL0A(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0A_STAGES; ++i) { + l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); + l0AEventList[i] = i; + AscendC::SetFlag(l0AEventList[i]); + } } - for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { - uint32_t mPartActual = (mPartIdx < mPartLoop - 1) - ? 
L0TileShape::M - : (params.mRound - mPartIdx * L0TileShape::M); - - for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { - uint32_t kPartActual = - (kPartIdx < kPartLoop - 1) - ? L0TileShape::K - : (params.kActual - kPartIdx * L0TileShape::K); - - auto &l0ATile = l0ATensorList[l0AListId]; - auto layoutAInL0 = LayoutAInL0::template MakeLayout( - mPartActual, kPartActual); - auto l1AOffset = - MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); - auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; - - AscendC::WaitFlag(l0AEventList[l0AListId]); - if ((mPartIdx == 0) && (kPartIdx == 0)) { - AscendC::WaitFlag( - l1AEventList[params.l1ListId]); + ACT_DEVICE + void InitL0B(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0B_STAGES; ++i) { + l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); + l0BEventList[i] = i + L0A_STAGES; + AscendC::SetFlag(l0BEventList[i]); } - copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); - if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { - AscendC::SetFlag( - l1AEventList[params.l1ListId]); + } + + ACT_DEVICE + void InitL0C(Arch::Resource &resource) + { + for (uint32_t i = 0; i < L0C_STAGES; ++i) { + l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); + l0CEventList[i] = i; + AscendC::SetFlag(l0CEventList[i]); } + } - for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { - uint32_t nPartActual = - (nPartIdx < nPartLoop - 1) - ? L0TileShape::N - : (params.nRound - nPartIdx * L0TileShape::N); - - auto &l0BTile = l0BTensorList[l0BListId]; - auto layoutBInL0 = LayoutBInL0::template MakeLayout( - kPartActual, nPartActual); - auto l1BOffset = - MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); - auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; - - AscendC::WaitFlag( - l0BEventList[l0BListId]); - if ((kPartIdx == 0) && (nPartIdx == 0)) { - AscendC::WaitFlag( - l1BEventList[params.l1ListId]); - } - copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); - if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { - AscendC::SetFlag( - l1BEventList[params.l1ListId]); - } - - AscendC::SetFlag(EVENT_ID0); - - auto l0COffset = - MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); - auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; - - AscendC::WaitFlag(EVENT_ID0); - // If the current tile is the first tile on the k axis, the - // accumulator needs to be reset to 0 - bool initC = (params.isKLoopFirst && (kPartIdx == 0)); - // If the unit flag is enabled, the unit flag is set according to the - // calculation progress - uint8_t unitFlag = 0b00; - if constexpr (ENABLE_UNIT_FLAG) { - if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && - (kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { - unitFlag = 0b11; - } else { - unitFlag = 0b10; + ACT_DEVICE + void L1TileMmad(L1TileMmadParams const ¶ms) + { + uint32_t mPartLoop = CeilDiv(params.mRound); + uint32_t nPartLoop = CeilDiv(params.nRound); + uint32_t kPartLoop = CeilDiv(params.kActual); + auto &l1ATensor = l1ATensorList[params.l1ListId]; + auto &l1BTensor = l1BTensorList[params.l1ListId]; + + auto &l0CTensor = l0CTensorList[l0CListId]; + LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); + + if constexpr (!ENABLE_UNIT_FLAG) { + if (params.isKLoopFirst) { + AscendC::WaitFlag(l0CEventList[l0CListId]); } - } - tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, - kPartActual, initC, unitFlag); + } - 
AscendC::SetFlag(l0BEventList[l0BListId]); - l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { + uint32_t mPartActual = + (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); + + for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { + uint32_t kPartActual = + (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); + + auto &l0ATile = l0ATensorList[l0AListId]; + auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); + auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); + auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; + + AscendC::WaitFlag(l0AEventList[l0AListId]); + if ((mPartIdx == 0) && (kPartIdx == 0)) { + AscendC::WaitFlag(l1AEventList[params.l1ListId]); + } + copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); + if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { + AscendC::SetFlag(l1AEventList[params.l1ListId]); + } + + for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { + uint32_t nPartActual = + (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); + + auto &l0BTile = l0BTensorList[l0BListId]; + auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); + auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); + auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; + + AscendC::WaitFlag(l0BEventList[l0BListId]); + if ((kPartIdx == 0) && (nPartIdx == 0)) { + AscendC::WaitFlag(l1BEventList[params.l1ListId]); + } + copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); + if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { + AscendC::SetFlag(l1BEventList[params.l1ListId]); + } + + AscendC::SetFlag(EVENT_ID0); + + auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); + auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; + + AscendC::WaitFlag(EVENT_ID0); + // If the current tile is the first tile on the k axis, the + // accumulator needs to be reset to 0 + bool initC = (params.isKLoopFirst && (kPartIdx == 0)); + // If the unit flag is enabled, the unit flag is set according to the + // calculation progress + uint8_t unitFlag = 0b00; + if constexpr (ENABLE_UNIT_FLAG) { + if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && + (nPartIdx == nPartLoop - 1)) { + unitFlag = 0b11; + } else { + unitFlag = 0b10; + } + } + tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); + + AscendC::SetFlag(l0BEventList[l0BListId]); + l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; + } + AscendC::SetFlag(l0AEventList[l0AListId]); + l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; + } } - AscendC::SetFlag(l0AEventList[l0AListId]); - l0AListId = (l0AListId + 1 < L0A_STAGES) ? 
(l0AListId + 1) : 0; - } - } - if (params.isKLoopLast) { - auto layoutCInGm = params.layoutCInGm; + if (params.isKLoopLast) { + auto layoutCInGm = params.layoutCInGm; - params.callbackBeforeFixpipe(); + params.callbackBeforeFixpipe(); - if constexpr (!ENABLE_UNIT_FLAG) { - AscendC::SetFlag(l0CEventList[l0CListId]); - AscendC::WaitFlag(l0CEventList[l0CListId]); - copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); - AscendC::SetFlag(l0CEventList[l0CListId]); - } else { - copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); - } - l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; + if constexpr (!ENABLE_UNIT_FLAG) { + AscendC::SetFlag(l0CEventList[l0CListId]); + AscendC::WaitFlag(l0CEventList[l0CListId]); + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); + AscendC::SetFlag(l0CEventList[l0CListId]); + } else { + copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); + } + l0CListId = (l0CListId + 1 < L0C_STAGES) ? (l0CListId + 1) : 0; - params.callbackAfterFixpipe(); + params.callbackAfterFixpipe(); + } } - } - - AscendC::LocalTensor l1ATensorList[L1_STAGES]; - AscendC::LocalTensor l1BTensorList[L1_STAGES]; - int32_t l1AEventList[L1_STAGES]; - int32_t l1BEventList[L1_STAGES]; - uint32_t l1ListId{0}; - - AscendC::LocalTensor l0ATensorList[L0A_STAGES]; - int32_t l0AEventList[L0A_STAGES]; - uint32_t l0AListId{0}; - - AscendC::LocalTensor l0BTensorList[L0B_STAGES]; - int32_t l0BEventList[L0B_STAGES]; - uint32_t l0BListId{0}; - - AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; - int32_t l0CEventList[L0C_STAGES_]; - uint32_t l0CListId{0}; - - L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; - uint32_t l1TileMmadParamsId{0}; - uint32_t preloadCount{0}; - - TileMmad tileMmad; - CopyGmToL1A copyGmToL1A; - CopyGmToL1B copyGmToL1B; - CopyL1ToL0A copyL1ToL0A; - CopyL1ToL0B copyL1ToL0B; - CopyL0CToGm copyL0CToGm; + + AscendC::LocalTensor l1ATensorList[L1_STAGES]; + AscendC::LocalTensor l1BTensorList[L1_STAGES]; + int32_t l1AEventList[L1_STAGES]; + int32_t l1BEventList[L1_STAGES]; + uint32_t l1ListId{0}; + + AscendC::LocalTensor l0ATensorList[L0A_STAGES]; + int32_t l0AEventList[L0A_STAGES]; + uint32_t l0AListId{0}; + + AscendC::LocalTensor l0BTensorList[L0B_STAGES]; + int32_t l0BEventList[L0B_STAGES]; + uint32_t l0BListId{0}; + + AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; + int32_t l0CEventList[L0C_STAGES_]; + uint32_t l0CListId{0}; + + L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; + uint32_t l1TileMmadParamsId{0}; + uint32_t preloadCount{0}; + + TileMmad tileMmad; + CopyGmToL1A copyGmToL1A; + CopyGmToL1B copyGmToL1B; + CopyL1ToL0A copyL1ToL0A; + CopyL1ToL0B copyL1ToL0B; + CopyL0CToGm copyL0CToGm; }; -} // namespace Act::Gemm::Block +} // namespace Act::Gemm::Block -#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP +#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP diff --git a/act/gemm/block/block_swizzle.hpp b/act/gemm/block/block_swizzle.hpp index 81b3df23..36662d2a 100644 --- a/act/gemm/block/block_swizzle.hpp +++ b/act/gemm/block/block_swizzle.hpp @@ -25,210 +25,219 @@ namespace Act::Gemm::Block { /// Block swizzling function for Gemms template struct GemmIdentityBlockSwizzle { - /// Data members - - GemmCoord problemShape; - MatrixCoord tileMN; - MatrixCoord loopsMN; - - /// Methods - - ACT_DEVICE - GemmIdentityBlockSwizzle() {} - - ACT_DEVICE - GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, - MatrixCoord const &tileMN_) - : 
problemShape(problemShape_), tileMN(tileMN_) { - loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); - } - - ACT_DEVICE - GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, - MatrixCoord const &tileMN_, - MatrixCoord const &loopsMN_) - : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) {} - - ACT_DEVICE - void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) { - problemShape = problemShape_; - tileMN = tileMN_; - - loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); - } - - ACT_DEVICE - void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, - MatrixCoord const &loopsMN_) { - problemShape = problemShape_; - tileMN = tileMN_; - loopsMN = loopsMN_; - } - - ACT_DEVICE - uint32_t GetCoreLoops() const { return loopsMN.row() * loopsMN.column(); } - - ACT_DEVICE - uint32_t GetBatchIdx(uint32_t taskIdx) { return taskIdx / (GetCoreLoops()); } - - ACT_DEVICE - GemmCoord GetBlockCoord(uint32_t taskIdx) { - uint32_t innerIdx = taskIdx % GetCoreLoops(); - if constexpr (SwizzleDirection == 0) { // Zn - uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); - - uint32_t nRow = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; - uint32_t nIdx = inTileBlockIdx / nRow; - if (tileBlockIdx % 2 == 1) { - nIdx = loopsMN.column() - nIdx - 1; - } - return GemmCoord{mIdx, nIdx, 0}; - } else if constexpr (SwizzleDirection == 1) { // Nz - uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); - - uint32_t nCol = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = inTileBlockIdx / nCol; - uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; - if (tileBlockIdx % 2 == 1) { - mIdx = loopsMN.row() - mIdx - 1; - } - return GemmCoord{mIdx, nIdx, 0}; + /// Data members + + GemmCoord problemShape; + MatrixCoord tileMN; + MatrixCoord loopsMN; + + /// Methods + + ACT_DEVICE + GemmIdentityBlockSwizzle() {} + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + : problemShape(problemShape_), tileMN(tileMN_) + { + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) + {} + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + + loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); + } + + ACT_DEVICE + void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) + { + problemShape = problemShape_; + tileMN = tileMN_; + loopsMN = loopsMN_; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMN.row() * loopsMN.column(); + } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / (GetCoreLoops()); + } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + 
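/* strip the batch component; innerIdx indexes within one pass over the m-n tile grid */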
uint32_t innerIdx = taskIdx % GetCoreLoops(); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMN.column() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMN.row() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, 0}; + } + } + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord) + { + uint32_t mActual = + (blockCoord.m() == (loopsMN.row() - 1)) ? (problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row(); + uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) + ? (problemShape.n() - blockCoord.n() * tileMN.column()) + : tileMN.column(); + uint32_t kActual = problemShape.k(); + return GemmCoord{mActual, nActual, kActual}; } - } - - ACT_DEVICE - GemmCoord GetActualBlockShape(GemmCoord blockCoord) { - uint32_t mActual = (blockCoord.m() == (loopsMN.row() - 1)) - ? (problemShape.m() - blockCoord.m() * tileMN.row()) - : tileMN.row(); - uint32_t nActual = - (blockCoord.n() == (loopsMN.column() - 1)) - ? 
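/* the last column block takes the remainder of the n extent */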
(problemShape.n() - blockCoord.n() * tileMN.column()) - : tileMN.column(); - uint32_t kActual = problemShape.k(); - return GemmCoord{mActual, nActual, kActual}; - } }; /// Block swizzling function for Splitk Gemms template struct SplitkGemmIdentityBlockSwizzle { - /// Data members - - GemmCoord problemShape; - GemmCoord tileShape; - GemmCoord loopsMNK; - uint32_t splitkFactor = 1; // split k dim into virtual cores - - /// Methods - - ACT_DEVICE - SplitkGemmIdentityBlockSwizzle() {} - - ACT_DEVICE - SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, - GemmCoord const &tileShape_, - uint32_t splitkFactor_ = 1) - : problemShape(problemShape_), tileShape(tileShape_), - splitkFactor(splitkFactor_) { - loopsMNK = CeilDiv(problemShape, tileShape); - } - - ACT_DEVICE - uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const { - if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { - return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; - } else { - return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + - loopsMNK.k() % splitkFactor; + /// Data members + + GemmCoord problemShape; + GemmCoord tileShape; + GemmCoord loopsMNK; + uint32_t splitkFactor = 1; // split k dim into virtual cores + + /// Methods + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle() {} + + ACT_DEVICE + SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, GemmCoord const &tileShape_, + uint32_t splitkFactor_ = 1) + : problemShape(problemShape_), tileShape(tileShape_), splitkFactor(splitkFactor_) + { + loopsMNK = CeilDiv(problemShape, tileShape); + } + + ACT_DEVICE + uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const + { + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; + } else { + return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + loopsMNK.k() % splitkFactor; + } } - } - - ACT_DEVICE - uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const { - uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); - return taskIdx % GetCoreLoops() / mnLoops; - } - - ACT_DEVICE - uint32_t GetCoreLoops() const { - return loopsMNK.m() * loopsMNK.n() * splitkFactor; - } - - ACT_DEVICE - uint32_t GetBatchIdx(uint32_t taskIdx) { return taskIdx / GetCoreLoops(); } - - ACT_DEVICE - GemmCoord GetBlockCoord(uint32_t taskIdx) { - uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); - uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); - - uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); - if constexpr (SwizzleDirection == 0) { // Zn - uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); - - uint32_t nRow = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; - uint32_t nIdx = inTileBlockIdx / nRow; - if (tileBlockIdx % 2 == 1) { - nIdx = loopsMNK.n() - nIdx - 1; - } - return GemmCoord{mIdx, nIdx, kIdx}; - } else if constexpr (SwizzleDirection == 1) { // Nz - uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); - - uint32_t nCol = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = inTileBlockIdx / nCol; - uint32_t nIdx = 
tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; - if (tileBlockIdx % 2 == 1) { - mIdx = loopsMNK.m() - mIdx - 1; - } - return GemmCoord{mIdx, nIdx, kIdx}; + + ACT_DEVICE + uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const + { + uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); + return taskIdx % GetCoreLoops() / mnLoops; + } + + ACT_DEVICE + uint32_t GetCoreLoops() const + { + return loopsMNK.m() * loopsMNK.n() * splitkFactor; + } + + ACT_DEVICE + uint32_t GetBatchIdx(uint32_t taskIdx) + { + return taskIdx / GetCoreLoops(); + } + + ACT_DEVICE + GemmCoord GetBlockCoord(uint32_t taskIdx) + { + uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); + uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); + + uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); + if constexpr (SwizzleDirection == 0) { // Zn + uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); + + uint32_t nRow = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; + uint32_t nIdx = inTileBlockIdx / nRow; + if (tileBlockIdx % 2 == 1) { + nIdx = loopsMNK.n() - nIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } else if constexpr (SwizzleDirection == 1) { // Nz + uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); + uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); + uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); + + uint32_t nCol = SwizzleOffset; + if (tileBlockIdx == tileBlockLoop - 1) { + nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; + } + uint32_t mIdx = inTileBlockIdx / nCol; + uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; + if (tileBlockIdx % 2 == 1) { + mIdx = loopsMNK.m() - mIdx - 1; + } + return GemmCoord{mIdx, nIdx, kIdx}; + } } - } - - ACT_DEVICE - GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) { - uint32_t splitkSliceLen; - if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { - splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); - } else { - splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); + + ACT_DEVICE + GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) + { + uint32_t splitkSliceLen; + if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { + splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); + } else { + splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); + } + uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) ? (problemShape.m() - blockCoord.m() * tileShape.m()) + : tileShape.m(); + uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) ? (problemShape.n() - blockCoord.n() * tileShape.n()) + : tileShape.n(); + uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) ? (problemShape.k() - blockCoord.k() * tileShape.k()) + : splitkSliceLen; + return GemmCoord{mActual, nActual, kActual}; } - uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) - ? (problemShape.m() - blockCoord.m() * tileShape.m()) - : tileShape.m(); - uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) - ? (problemShape.n() - blockCoord.n() * tileShape.n()) - : tileShape.n(); - uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) - ? 
(problemShape.k() - blockCoord.k() * tileShape.k()) - : splitkSliceLen; - return GemmCoord{mActual, nActual, kActual}; - } }; -} // namespace Act::Gemm::Block +} // namespace Act::Gemm::Block -#endif // ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP +#endif // ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/act/gemm/dispatch_policy.hpp b/act/gemm/dispatch_policy.hpp index df0abfe2..4ec7433f 100644 --- a/act/gemm/dispatch_policy.hpp +++ b/act/gemm/dispatch_policy.hpp @@ -19,9 +19,10 @@ namespace Act::Gemm { // Block Mmad Policies -template struct MmadAtlasA2Base { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t ASYNC = ASYNC_; +template +struct MmadAtlasA2Base { + using ArchTag = Arch::AtlasA2; + static constexpr uint32_t ASYNC = ASYNC_; }; using MmadAtlasA2 = MmadAtlasA2Base; @@ -30,62 +31,58 @@ using MmadAtlasA2Async = MmadAtlasA2Base; // Now ENABLE_UNIT_FLAG_ must be false when input element is int8 template struct MmadAtlasA2Pingpong : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; }; template struct MmadAtlasA2Preload : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; - static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; + static constexpr uint32_t STAGES = 2; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; }; struct MmadAtlasA2FAQK : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; struct MmadAtlasA2FAPV : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; struct MmadAtlasA2MLAQK : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; struct MmadAtlasA2MLAPV : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; struct MmadAtlasA2MLAQKTp1Spec : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; struct MmadAtlasA2MLAPVTp1Spec : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; + static constexpr uint32_t STAGES = 2; }; -template +template struct MmadAtlasA2PreloadAsync : public MmadAtlasA2Async { - static constexpr uint32_t PRELOAD_STAGES = - PRELOAD_STAGES_; // Stages of emitting load instruction in advance - static constexpr uint32_t L1_STAGES = L1_STAGES_; - static constexpr uint32_t L0A_STAGES = L0A_STAGES_; - static constexpr uint32_t L0B_STAGES = L0B_STAGES_; - static constexpr uint32_t L0C_STAGES = L0C_STAGES_; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; - static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; + static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance + static constexpr uint32_t L1_STAGES = L1_STAGES_; + static constexpr uint32_t L0A_STAGES = L0A_STAGES_; + static constexpr uint32_t L0B_STAGES = L0B_STAGES_; + static constexpr uint32_t L0C_STAGES = L0C_STAGES_; + static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; + static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; }; -template +template struct MmadAtlasA2PreloadAsyncWithCallback - : public MmadAtlasA2PreloadAsync {}; -} // namespace Act::Gemm +} // namespace Act::Gemm -#endif // ACT_GEMM_DISPATCH_POLICY_HPP +#endif // 
ACT_GEMM_DISPATCH_POLICY_HPP diff --git a/act/gemm/gemm_type.hpp b/act/gemm/gemm_type.hpp index 6b71040f..145c3964 100644 --- a/act/gemm/gemm_type.hpp +++ b/act/gemm/gemm_type.hpp @@ -17,14 +17,13 @@ namespace Act::Gemm { //////////////////////////////////////////////////////////////////// -template +template struct GemmType { - using Element = Element_; - using Layout = Layout_; - static constexpr AscendC::TPosition POSITION = POSITION_; + using Element = Element_; + using Layout = Layout_; + static constexpr AscendC::TPosition POSITION = POSITION_; }; -} // namespace Act::Gemm +} // namespace Act::Gemm -#endif // ACT_GEMM_GEMM_TYPE_HPP +#endif // ACT_GEMM_GEMM_TYPE_HPP diff --git a/act/gemm/helper.hpp b/act/gemm/helper.hpp index bb448a8e..bb634f9b 100644 --- a/act/gemm/helper.hpp +++ b/act/gemm/helper.hpp @@ -19,256 +19,262 @@ namespace Act::Gemm::helper { -template struct L1AlignHelper { - static_assert(DEPENDENT_FALSE, - "Unsupporteded align helper, can not find the specialization."); +template +struct L1AlignHelper { + static_assert(DEPENDENT_FALSE, "Unsupported align helper, cannot find the specialization."); }; -template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; }; -template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; }; template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; }; template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; }; -template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = 
ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; }; -template struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +template +struct L1AlignHelper { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; }; -template struct ElementAccumulatorSelector { - static_assert(DEPENDENT_FALSE, - "Unsupporteded element accumulator selector, can not find the " - "specialization."); +template +struct ElementAccumulatorSelector { + static_assert(DEPENDENT_FALSE, + "Unsupported element accumulator selector, cannot find the " + "specialization."); }; -template <> struct ElementAccumulatorSelector { - using ElementAccumulator = float; +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; }; -template <> struct ElementAccumulatorSelector { - using ElementAccumulator = float; +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; }; -template <> struct ElementAccumulatorSelector { - using ElementAccumulator = int32_t; +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = int32_t; }; -template <> struct ElementAccumulatorSelector { - using ElementAccumulator = float; +template <> +struct ElementAccumulatorSelector { + using ElementAccumulator = float; }; -template struct L1ATypeSelector { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded layout selector, can not find the specialization."); +template +struct L1ATypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, cannot find the specialization."); }; template struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; template struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; template struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; template struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; -template struct L1BTypeSelector { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded layout selector, can not find the specialization."); +template +struct L1BTypeSelector { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, cannot find the specialization."); }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; + using 
L1BType = Gemm::GemmType; }; template struct L1AlignHelperTla { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded align helper tla, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported align helper tla, cannot find the specialization."); }; template -struct L1AlignHelperTla< - Element, Layout, std::enable_if_t::value>> { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; }; template -struct L1AlignHelperTla< - Element, Layout, - std::enable_if_t::value>> { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; +struct L1AlignHelperTla::value>> { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; + static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; }; /////////////////////////////////////// // new add -template struct L1ATypeSelectorGemm { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded layout selector, can not find the specialization."); +template +struct L1ATypeSelectorGemm { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, cannot find the specialization."); }; template struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; template <> struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; template struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; + using L1AType = Gemm::GemmType; }; -template struct L1BTypeSelectorGemm { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded layout selector, can not find the specialization."); +template +struct L1BTypeSelectorGemm { + static_assert(DEPENDENT_FALSE, "Unsupported layout selector, cannot find the specialization."); }; template struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template <> struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; template struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; + using L1BType = Gemm::GemmType; }; -template struct L0ATypeSelector {}; +template +struct L0ATypeSelector {}; template struct L0ATypeSelector> { - using L0AType = Gemm::GemmType; + using L0AType = Gemm::GemmType; }; template struct L0ATypeSelector> { - using L0AType = Gemm::GemmType; + using L0AType = Gemm::GemmType; }; -template <> struct L0ATypeSelector> { - using L0AType = Gemm::GemmType; +template <> +struct L0ATypeSelector> { + using L0AType = Gemm::GemmType; }; -template struct L0BTypeSelectorGemm {}; +template +struct L0BTypeSelectorGemm {}; template struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; + using L0BType = Gemm::GemmType; }; -template <> struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; +template <> +struct 
L0BTypeSelectorGemm> { + using L0BType = Gemm::GemmType; }; template struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; + using L0BType = Gemm::GemmType; }; -template struct L0BTypeSelectorGemv {}; +template +struct L0BTypeSelectorGemv {}; template struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; + using L0BType = Gemm::GemmType; }; template struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; + using L0BType = Gemm::GemmType; }; -template <> struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; +template <> +struct L0BTypeSelectorGemv> { + using L0BType = Gemm::GemmType; }; -} // namespace Act::Gemm::helper +} // namespace Act::Gemm::helper -#endif // ACT_GEMM_HELPER_HPP +#endif // ACT_GEMM_HELPER_HPP diff --git a/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp index baf7c00d..4a59ac9b 100644 --- a/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp +++ b/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp @@ -24,365 +24,339 @@ namespace Act::Gemm::Kernel { -template -class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace { +template +class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace +{ public: - using BlockMmad = BlockMmad_; - using ArchTag = typename BlockMmad::ArchTag; - using L1TileShape = typename BlockMmad::L1TileShape; - using ElementA = typename BlockMmad::ElementA; - using LayoutA = typename BlockMmad::LayoutA; - using ElementB = typename BlockMmad::ElementB; - using LayoutB = typename BlockMmad::LayoutB; - using ElementC = typename BlockMmad::ElementC; - using LayoutC = typename BlockMmad::LayoutC; - using ElementAccumulator = typename BlockMmad::ElementAccumulator; - - using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; - using LayoutScale = typename BlockEpilogue::LayoutScale; - using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; - using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; - using ElementD = typename BlockEpilogue::ElementD; - using LayoutD = typename BlockEpilogue::LayoutD; - using EpilogueParams = typename BlockEpilogue::Params; - - using BlockScheduler = BlockScheduler_; - static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; - using ElementGroupList = ElementGroupList_; - - /// Parameters structure - struct Params { - // Data members - GemmCoord problemShape; - uint32_t problemCount; - __gm__ ElementGroupList_ *ptrGroupList; - __gm__ ElementA *ptrA; - LayoutA layoutA; - __gm__ ElementB *ptrB; - LayoutB layoutB; - __gm__ ElementScale *ptrScale; - LayoutScale layoutScale; - __gm__ ElementPerTokenScale *ptrPerTokenScale; - LayoutPerTokenScale layoutPerTokenScale; - __gm__ ElementD *ptrD; - LayoutD layoutD; - GM_ADDR ptrWorkspace; - void *combiner; + using BlockMmad = BlockMmad_; + using ArchTag = typename BlockMmad::ArchTag; + using L1TileShape = typename BlockMmad::L1TileShape; + using ElementA = typename BlockMmad::ElementA; + using LayoutA = typename BlockMmad::LayoutA; + using ElementB = typename BlockMmad::ElementB; + using LayoutB = typename BlockMmad::LayoutB; + using ElementC = typename BlockMmad::ElementC; + using LayoutC = typename BlockMmad::LayoutC; + using ElementAccumulator = typename BlockMmad::ElementAccumulator; + + using BlockEpilogue = BlockEpilogue_; + using ElementScale = typename BlockEpilogue::ElementScale; 
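    // A sketch of the rescaling these scale types presumably feed (inferred from
    // the type names and from the group offsets below, which advance by n() for
    // scale and by m() for perTokenScale; the actual tile pipeline lives in
    // BlockEpilogue): the workspace accumulator C is rescaled once per output
    // channel and once per token row before the final cast,
    //   D[i][j] = ElementD(C[i][j] * scale[j] * perTokenScale[i]),
    // i.e. scale broadcasts along rows and perTokenScale along columns.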
+ using LayoutScale = typename BlockEpilogue::LayoutScale; + using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; + using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; + using ElementD = typename BlockEpilogue::ElementD; + using LayoutD = typename BlockEpilogue::LayoutD; + using EpilogueParams = typename BlockEpilogue::Params; + + using BlockScheduler = BlockScheduler_; + static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; + using ElementGroupList = ElementGroupList_; + + /// Parameters structure + struct Params { + // Data members + GemmCoord problemShape; + uint32_t problemCount; + __gm__ ElementGroupList_ *ptrGroupList; + __gm__ ElementA *ptrA; + LayoutA layoutA; + __gm__ ElementB *ptrB; + LayoutB layoutB; + __gm__ ElementScale *ptrScale; + LayoutScale layoutScale; + __gm__ ElementPerTokenScale *ptrPerTokenScale; + LayoutPerTokenScale layoutPerTokenScale; + __gm__ ElementD *ptrD; + LayoutD layoutD; + GM_ADDR ptrWorkspace; + void *combiner; + + // Methods + ACT_DEVICE + Params() {} + + ACT_DEVICE + Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, + GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, + LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, + void *combiner_) + : problemShape(problemShape_), + problemCount(problemCount_), + ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), + ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), + layoutA(layoutA_), + ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), + layoutB(layoutB_), + ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), + layoutScale(layoutScale_), + ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), + layoutPerTokenScale(layoutPerTokenScale_), + ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), + layoutD(layoutD_), + ptrWorkspace(ptrWorkspace_), + combiner(combiner_) + {} + }; // Methods ACT_DEVICE - Params() {} - - ACT_DEVICE - Params(GemmCoord problemShape_, uint32_t problemCount_, - GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, - GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, - LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, - LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, - LayoutD layoutD_, GM_ADDR ptrWorkspace_, void *combiner_) - : problemShape(problemShape_), problemCount(problemCount_), - ptrGroupList( - reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), - ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), layoutA(layoutA_), - ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), layoutB(layoutB_), - ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), - layoutScale(layoutScale_), - ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>( - ptrPerTokenScale_)), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), layoutD(layoutD_), - ptrWorkspace(ptrWorkspace_), combiner(combiner_) {} - }; - - // Methods - ACT_DEVICE - GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() { - Arch::FlagID flagId = 0; - for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { - flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); - flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); - aicWaitFuncList[stageId] = {this, stageId}; - aicSetFuncList[stageId] = {this, stageId}; - } - } - - template - ACT_DEVICE void 
operator()(Params const ¶ms); - - template <> ACT_DEVICE void operator()(Params const ¶ms) { - BlockScheduler blockScheduler; - BlockMmad blockMmad(resource); - - // Represent the full gm - AscendC::GlobalTensor gmA; - gmA.SetGlobalBuffer(params.ptrA); - AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); - AscendC::GlobalTensor groupList; - groupList.SetGlobalBuffer(params.ptrGroupList); - - uint32_t coreIdx = AscendC::GetBlockIdx(); - uint32_t coreNum = AscendC::GetBlockNum(); - int64_t gmGroupOffsetA = 0; - int64_t gmGroupOffsetB = 0; - - AscendC::GlobalTensor gmC; - gmC.SetGlobalBuffer( - reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); - auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, - L1TileShape::N}; - - uint32_t stageId = 0; - uint32_t stageUsed = 0; - uint32_t startCoreIdx = 0; - for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { - uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) - : (groupList.GetValue(groupIdx) - - groupList.GetValue(groupIdx - 1)); - GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), - params.problemShape.k()}; - - LayoutA layoutA = - params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); - LayoutB layoutB = params.layoutB; - - blockScheduler.Update(inGroupProblemShape, - MakeCoord(L1TileShape::M, L1TileShape::N)); - uint32_t coreLoops = blockScheduler.GetCoreLoops(); - - // Determine the starting loopIdx of the current core under the current - // groupIdx - uint32_t startLoopIdx = - ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - - startCoreIdx; - // Loop through the matmul of each groupIdx - for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; - loopIdx += coreNum) { - // Compute block location - GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); - GemmCoord actualBlockShape = - blockScheduler.GetActualBlockShape(blockCoord); - - Callback callbackBeforeFixpipe{}; - if (stageUsed == WORKSPACE_STAGES) { - callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); - } else { - ++stageUsed; - } - Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); - - // Compute initial location in logical coordinates - MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, - blockCoord.k() * L1TileShape::K}; - MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, - blockCoord.n() * L1TileShape::N}; - MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; - int64_t gmOffsetA = layoutA.GetOffset(offsetA); - int64_t gmOffsetB = layoutB.GetOffset(offsetB); - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - - // Compute block-scoped matrix multiply-add - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { - blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, - gmB[gmGroupOffsetB + gmOffsetB], layoutB, gmC[gmOffsetC], - layoutC, actualBlockShape, callbackBeforeFixpipe, - callbackAfterFixpipe); - } else { - callbackBeforeFixpipe(); - blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, - gmB[gmGroupOffsetB + gmOffsetB], layoutB, gmC[gmOffsetC], - layoutC, actualBlockShape); - callbackAfterFixpipe(); + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() + { + Arch::FlagID flagId = 0; + for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { + flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); + flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); + aicWaitFuncList[stageId] = {this, stageId}; + aicSetFuncList[stageId] = {this, stageId}; } - - stageId = (stageId + 1 
< WORKSPACE_STAGES) ? (stageId + 1) : 0; - } - - gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); - - startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } - - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { - blockMmad.SynchronizeBlock(); } - while (stageUsed > 0) { - uint32_t aivComputeStageId = - (stageId >= stageUsed) ? (stageId - stageUsed) - : (stageId + WORKSPACE_STAGES - stageUsed); - Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); - --stageUsed; - } - } + template + ACT_DEVICE void operator()(Params const ¶ms); - template <> ACT_DEVICE void operator()(Params const ¶ms) { - auto *combiner = - (MoeDistributeCombineImpl::CamMoeDistributeCombine - *)params.combiner; + template <> + ACT_DEVICE void operator()(Params const ¶ms) { - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - if (get_subblockid() == 0) { - AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>( - MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); - } - } - BlockScheduler blockScheduler; - BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); - - uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); - uint32_t coreNum = AscendC::GetBlockNum(); - int64_t gmGroupOffsetScale = 0; - int64_t gmGroupOffsetPerTokenScale = 0; - int64_t gmGroupOffsetD = 0; - - AscendC::GlobalTensor groupList; - groupList.SetGlobalBuffer(params.ptrGroupList); - - AscendC::GlobalTensor gmC; - gmC.SetGlobalBuffer( - reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); - auto layoutC = layout::RowMajor{ - L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; - - uint32_t stageId = 0; - uint32_t startCoreIdx = 0; - for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { - uint32_t currentM = (groupIdx == 0) - ? groupList.GetValue(groupIdx) - : (groupList.GetValue(groupIdx) - - groupList.GetValue(groupIdx - 1)); - GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), - params.problemShape.k()}; - - LayoutScale layoutScale = params.layoutScale; - LayoutPerTokenScale layoutPerTokenScale = - params.layoutPerTokenScale.GetTileLayout( - inGroupProblemShape.template GetCoordByAxis<0>()); - LayoutD layoutD = - params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); - - EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, - layoutScale, - params.ptrPerTokenScale + - gmGroupOffsetPerTokenScale, - layoutPerTokenScale, - params.ptrD + gmGroupOffsetD, - layoutD}; - - blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); - blockEpilogue.UpdateParams(epilogueParams); - uint32_t coreLoops = blockScheduler.GetCoreLoops(); - - GemmCoord blockShapeMNK = L1TileShape::ToCoord(); - uint32_t startLoopIdx = - ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - - startCoreIdx; - for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; - loopIdx += coreNum) { - GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); - GemmCoord actualBlockShapeMNK = - blockScheduler.GetActualBlockShape(blockCoordMNK); - - MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, - 0}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - auto gmBlockC = gmC[gmOffsetC]; - auto layoutBlockC = - layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); - - Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); - blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, - actualBlockShapeMNK, gmBlockC, layoutBlockC); - Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>( - flagAivFinishComputeList[stageId]); - - stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + BlockScheduler blockScheduler; + BlockMmad blockMmad(resource); + + // Represent the full gm + AscendC::GlobalTensor gmA; + gmA.SetGlobalBuffer(params.ptrA); + AscendC::GlobalTensor gmB; + gmB.SetGlobalBuffer(params.ptrB); + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + uint32_t coreIdx = AscendC::GetBlockIdx(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetA = 0; + int64_t gmGroupOffsetB = 0; + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t stageUsed = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); + LayoutB layoutB = params.layoutB; + + blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + // Determine the starting loopIdx of the current core under the current + // groupIdx + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + // Loop through the matmul of each groupIdx + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + // Compute block location + GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); + + Callback callbackBeforeFixpipe{}; + if (stageUsed == WORKSPACE_STAGES) { + callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); + } else { + ++stageUsed; + } + Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); + + // Compute initial location in logical coordinates + MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; + MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetA = layoutA.GetOffset(offsetA); + int64_t gmOffsetB = layoutB.GetOffset(offsetB); + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + + // Compute block-scoped matrix multiply-add + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); + } else { + callbackBeforeFixpipe(); + blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, + gmC[gmOffsetC], layoutC, actualBlockShape); + callbackAfterFixpipe(); + } + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); + gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; } - gmGroupOffsetScale += inGroupProblemShape.n(); - gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); - gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + if constexpr (BlockMmad::DispatchPolicy::ASYNC) { + blockMmad.SynchronizeBlock(); + } - startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } + while (stageUsed > 0) { + uint32_t aivComputeStageId = + (stageId >= stageUsed) ? 
(stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); + Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); + --stageUsed; + } } - icache_preload(4); - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - if (get_subblockid() == 0) { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->AllToAllSend(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); - } else { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->ReducePermute(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); - } - } else { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->Process(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); + template <> + ACT_DEVICE void operator()(Params const ¶ms) + { + auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; + { + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); + } + } + BlockScheduler blockScheduler; + BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); + + uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); + uint32_t coreNum = AscendC::GetBlockNum(); + int64_t gmGroupOffsetScale = 0; + int64_t gmGroupOffsetPerTokenScale = 0; + int64_t gmGroupOffsetD = 0; + + AscendC::GlobalTensor groupList; + groupList.SetGlobalBuffer(params.ptrGroupList); + + AscendC::GlobalTensor gmC; + gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); + auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; + + uint32_t stageId = 0; + uint32_t startCoreIdx = 0; + for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { + uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) + : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); + GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; + + LayoutScale layoutScale = params.layoutScale; + LayoutPerTokenScale layoutPerTokenScale = + params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); + LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); + + EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, + layoutScale, + params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, + layoutPerTokenScale, + params.ptrD + gmGroupOffsetD, + layoutD}; + + blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); + blockEpilogue.UpdateParams(epilogueParams); + uint32_t coreLoops = blockScheduler.GetCoreLoops(); + + GemmCoord blockShapeMNK = L1TileShape::ToCoord(); + uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; + for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { + GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); + GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); + + MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; + int64_t gmOffsetC = layoutC.GetOffset(offsetC); + auto gmBlockC = gmC[gmOffsetC]; + auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); + + Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); + blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, + layoutBlockC); + Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); + + stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; + } + + gmGroupOffsetScale += inGroupProblemShape.n(); + gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); + gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); + + startCoreIdx = (startCoreIdx + coreLoops) % coreNum; + } + } + + icache_preload(4); + if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { + if (get_subblockid() == 0) { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->AllToAllSend(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->ReducePermute(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } + } else { + resource.pipe.Init(); + combiner->TPipeSet(&resource.pipe); + combiner->Process(); + combiner->TPipeSet(nullptr); + resource.pipe.Destroy(); + } } - } private: - friend struct AicWaitFunc; - friend struct AicSetFunc; + friend struct AicWaitFunc; + friend struct AicSetFunc; - struct AicWaitFunc { - using MatmulKernel = GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< - TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, - WORKSPACE_STAGES, ElementGroupList>; + struct AicWaitFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; - ACT_DEVICE - AicWaitFunc() = default; + ACT_DEVICE + AicWaitFunc() = default; - ACT_DEVICE - void operator()() const { - Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); - } + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); + } - MatmulKernel *ptr{nullptr}; - uint32_t stageId; - }; + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; - struct AicSetFunc { - using MatmulKernel = GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace< - TemplateMC2TypeFunc, BlockMmad, BlockEpilogue, BlockScheduler, - WORKSPACE_STAGES, ElementGroupList>; + struct AicSetFunc { + using MatmulKernel = + GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; - ACT_DEVICE - AicSetFunc() = default; + ACT_DEVICE + AicSetFunc() = default; - ACT_DEVICE - void operator()() const { - Arch::CrossCoreSetFlag<0x2, PIPE_FIX>( - ptr->flagAicFinishStoreList[stageId]); - } + ACT_DEVICE + void operator()() const + { + Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); + } - MatmulKernel *ptr{nullptr}; - uint32_t stageId; - }; + MatmulKernel *ptr{nullptr}; + uint32_t stageId; + }; - Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; - Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; + Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; - AicWaitFunc 
aicWaitFuncList[WORKSPACE_STAGES]; - AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; - Arch::Resource resource; + AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; + AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; + Arch::Resource resource; }; -} // namespace Act::Gemm::Kernel +} // namespace Act::Gemm::Kernel -#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP +#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/act/gemm/tile/copy_gm_to_l1.hpp b/act/gemm/tile/copy_gm_to_l1.hpp index ba5e8207..5100d46f 100644 --- a/act/gemm/tile/copy_gm_to_l1.hpp +++ b/act/gemm/tile/copy_gm_to_l1.hpp @@ -22,835 +22,777 @@ using namespace tla; namespace Act::Gemm::Tile { -template struct CopyGmToL1 { - static_assert(DEPENDENT_FALSE, - "Unsupported copy gm to l1, can not find the specialization."); +template +struct CopyGmToL1 { + static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); }; /// Partial specialization for AtlasA2, half, RowMajor in and zN out. /// Matrix A confirm template -struct CopyGmToL1, - Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(0) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], - srcTensor[i * layoutSrc.stride(0)], intriParams); - } - } - } -}; +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; -template -struct CopyGmToL1, - Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0); - uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL; - uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL; - if (srcNdStride < STRIDE_LIMIT) { - if (ndNum) { - intriParams.ndNum = ndNum; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = srcNdStride; - intriParams.srcDValue = layoutSrc.stride(0); + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + // Methods - intriParams.dstNzMatrixStride 
= layoutDst.stride(1); + ACT_DEVICE + CopyGmToL1() {}; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } - - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(1); - tailParams.srcNdMatrixStride = srcNdStride; - tailParams.srcDValue = layoutSrc.stride(0); - - tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; //` - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], - srcTensor[ndNum * srcNdStride], tailParams); - } - } else if (layoutSrc.stride(0) < STRIDE_LIMIT) { - for (uint32_t i = 0; i < ndNum; i++) { + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; - intriParams.nValue = C0_NUM_PER_FRACTAL; intriParams.dValue = layoutSrc.shape(1); intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; intriParams.dstNzMatrixStride = 0; - AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * srcNdStride], intriParams); - } - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(1); - tailParams.srcNdMatrixStride = 0; - tailParams.srcDValue = layoutSrc.stride(0); - - tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], - srcTensor[ndNum * srcNdStride], tailParams); - } - } else { - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; - uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } + } +}; - AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = 0; +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::RowMajor; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = 0; - intriParams.dstNzMatrixStride = 0; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; - uint32_t offsetDst = - i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; - uint32_t offsetSrc = i * layoutSrc.stride(0); - AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], - intriParams); - } + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const 
&layoutSrc) + { + AscendC::Nd2NzParams intriParams; + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0); + uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(1); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; //` + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(0) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(0); + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(1); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(0); + + tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(0); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } } - } }; template -struct CopyGmToL1, - Gemm::GemmType> { - using LayoutDst = layout::nN; - using LayoutSrc = layout::ColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - 
AscendC::Nd2NzParams intriParams; - uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); - uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; - uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; - if (srcNdStride < STRIDE_LIMIT) { - if (ndNum) { - intriParams.ndNum = ndNum; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = srcNdStride; - intriParams.srcDValue = layoutSrc.stride(1); +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::ColumnMajor; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - intriParams.dstNzMatrixStride = layoutDst.stride(3); + // Methods - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } - - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(0); - tailParams.srcNdMatrixStride = srcNdStride; - tailParams.srcDValue = layoutSrc.stride(1); - - tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], - srcTensor[ndNum * srcNdStride], tailParams); - } - } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { - for (uint32_t i = 0; i < ndNum; i++) { + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = layoutSrc.stride(1); + uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); + uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; + uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; + if (srcNdStride < STRIDE_LIMIT) { + if (ndNum) { + intriParams.ndNum = ndNum; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = srcNdStride; + intriParams.srcDValue = layoutSrc.stride(1); + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + + intriParams.dstNzMatrixStride = layoutDst.stride(3); + + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } + + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = srcNdStride; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = C0_NUM_PER_FRACTAL; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = layoutSrc.stride(1); + + 
intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], srcTensor[i * srcNdStride], intriParams); + } + if (remains) { + AscendC::Nd2NzParams tailParams; + tailParams.ndNum = 1; + tailParams.nValue = remains; + tailParams.dValue = layoutSrc.shape(0); + tailParams.srcNdMatrixStride = 0; + tailParams.srcDValue = layoutSrc.stride(1); + + tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + tailParams.dstNzMatrixStride = 0; + + AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); + } + } else { + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; + uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + + AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; + intriParams.nValue = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.srcDValue = 0; + + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzNStride = 0; + intriParams.dstNzMatrixStride = 0; + + uint32_t offsetDst = i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; + uint32_t offsetSrc = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); + } + } + } +}; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; +template +struct CopyGmToL1, Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], - srcTensor[i * srcNdStride], intriParams); - } - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(0); - tailParams.srcNdMatrixStride = 0; - tailParams.srcDValue = layoutSrc.stride(1); - - tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], - srcTensor[ndNum * srcNdStride], tailParams); - } - } else { - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; - uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; + // Methods + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { AscendC::Nd2NzParams intriParams; + intriParams.ndNum = 1; - intriParams.nValue = 1; intriParams.dValue = layoutSrc.shape(0); intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = 0; intriParams.dstNzMatrixStride = 0; - uint32_t offsetDst = - i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; - uint32_t offsetSrc = i * layoutSrc.stride(1); - AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], - intriParams); - } - } - } -}; - -template -struct CopyGmToL1, - Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::ColumnMajor; - - 
static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(1) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(1); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], - srcTensor[i * layoutSrc.stride(1)], intriParams); - } + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } } - } }; /// Partial specialization for AtlasA2, RowMajor in and zN out. template struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(0) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], - srcTensor[i * layoutSrc.stride(0)], intriParams); - } + using LayoutDst = layout::zN; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(0) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(0); + intriParams.srcDValue = 
layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); + } + } } - } - - // layoutSrc must be the layout of one of the src matrices - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, - uint32_t ndNum, uint32_t srcNdMatrixStride, - uint32_t dstNzNStride, uint32_t dstNzMatrixStride, - uint32_t dstNzC0Stride) { - AscendC::Nd2NzParams intriParams; - - intriParams.nValue = layoutSrc.shape(0); - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = dstNzNStride; - intriParams.dstNzC0Stride = dstNzC0Stride; - if (srcNdMatrixStride < STRIDE_LIMIT) { - intriParams.ndNum = ndNum; - intriParams.srcNdMatrixStride = srcNdMatrixStride; - intriParams.dstNzMatrixStride = dstNzMatrixStride; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.ndNum = 1; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzMatrixStride = 0; - for (uint32_t i = 0; i < ndNum; i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], - srcTensor[i * srcNdMatrixStride], intriParams); - } + + // layoutSrc must be the layout of one of the src matrices + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, uint32_t ndNum, uint32_t srcNdMatrixStride, + uint32_t dstNzNStride, uint32_t dstNzMatrixStride, uint32_t dstNzC0Stride) + { + AscendC::Nd2NzParams intriParams; + + intriParams.nValue = layoutSrc.shape(0); + intriParams.dValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = dstNzNStride; + intriParams.dstNzC0Stride = dstNzC0Stride; + if (srcNdMatrixStride < STRIDE_LIMIT) { + intriParams.ndNum = ndNum; + intriParams.srcNdMatrixStride = srcNdMatrixStride; + intriParams.dstNzMatrixStride = dstNzMatrixStride; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.ndNum = 1; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzMatrixStride = 0; + for (uint32_t i = 0; i < ndNum; i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * srcNdMatrixStride], intriParams); + } + } } - } }; /// Partial specialization for AtlasA2, ColumnMajor in and nZ out. 
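/// In the ColumnMajor variant that follows, the two extents swap roles relative
/// to the RowMajor case above: nValue walks the columns (layoutSrc.shape(1))
/// while dValue is the column length (layoutSrc.shape(0)), and the same
/// per-column fallback fires once layoutSrc.stride(1) reaches STRIDE_LIMIT.
/// A worked example of the C0 arithmetic used throughout, from BYTE_PER_C0 == 32:
/// ELE_NUM_PER_C0 is 16 for half, 8 for float and 32 for int8, so a dValue of
/// 128 half elements spans 128 / 16 = 8 C0 blocks of the nZ destination.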
template struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::ColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(1) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(1); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], - srcTensor[i * layoutSrc.stride(1)], intriParams); - } + using LayoutDst = layout::nZ; + using LayoutSrc = layout::ColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.shape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (layoutSrc.stride(1) < STRIDE_LIMIT) { + intriParams.nValue = layoutSrc.shape(1); + intriParams.srcDValue = layoutSrc.stride(1); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { + AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); + } + } } - } }; /// Partial specialization for zN in and zN out. 
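The zN to zN and nZ to nZ specializations below describe the transfer with AscendC::DataCopyParams, whose srcStride/dstStride fields are gaps between contiguous runs rather than distances between run starts, hence the stride / ELE_NUM_PER_C0 - blockLen expressions. A small model of that relationship (a sketch; the struct and values are illustrative, not the real API):

#include <cstdint>

struct GapParams {
    uint32_t blockCount;  // number of contiguous runs
    uint32_t blockLen;    // run length, in 32-byte blocks
    uint32_t srcGap;      // blocks skipped between runs on the source side
    uint32_t dstGap;      // blocks skipped between runs on the destination side
};

// strideBlocks: distance between run *starts*, in 32-byte blocks.
constexpr GapParams makeGapParams(uint32_t runs, uint32_t runLen,
                                  uint32_t srcStrideBlocks, uint32_t dstStrideBlocks)
{
    return {runs, runLen, srcStrideBlocks - runLen, dstStrideBlocks - runLen};
}

static_assert(makeGapParams(4, 8, 10, 8).srcGap == 2, "gap = stride - runLen");
static_assert(makeGapParams(4, 8, 10, 8).dstGap == 0, "a packed destination has no gap");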
template struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); - uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); - - AscendC::DataCopyParams repeatParams; - - if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { - repeatParams.blockCount = blockCount; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; - repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; - AscendC::DataCopy(dstTensor, srcTensor, repeatParams); - } else { - repeatParams.blockCount = 1; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = 0; - repeatParams.dstStride = 0; - for (uint32_t i = 0; i < blockCount; i++) { - uint64_t dstOffset = i * layoutDst.stride(3); - uint64_t srcOffset = i * layoutSrc.stride(3); - AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], - repeatParams); - } + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(3); + uint64_t srcOffset = i * layoutSrc.stride(3); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } } - } }; /// Partial specialization for nZ in and nZ out. 
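Both fractal formats referenced here tile the matrix into 16-row by (32 / sizeof(Element))-column fractals; zN and nZ differ only in the traversal order inside a fractal and across fractals. A generic, purely illustrative model of the index math (the function and parameter names are inventions for this sketch):

#include <cstdint>

constexpr uint64_t fractalOffset(uint64_t r, uint64_t c, uint64_t rowsPadded, uint64_t colsPadded,
                                 uint64_t fr, uint64_t fc, bool innerRowMajor, bool outerRowMajor)
{
    const uint64_t fRow = r / fr, fCol = c / fc;   // which fractal holds (r, c)
    const uint64_t inR = r % fr, inC = c % fc;     // position inside that fractal
    const uint64_t fractalsPerRow = colsPadded / fc;
    const uint64_t fractalsPerCol = rowsPadded / fr;
    const uint64_t fIdx = outerRowMajor ? fRow * fractalsPerRow + fCol
                                        : fCol * fractalsPerCol + fRow;
    const uint64_t inIdx = innerRowMajor ? inR * fc + inC : inC * fr + inR;
    return fIdx * (fr * fc) + inIdx;               // fractals are stored contiguously
}

// Element (17, 3) of a 32 x 32 tile, row-major inside, column-major across:
static_assert(fractalOffset(17, 3, 32, 32, 16, 16, true, false) == 275, "");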
template struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); - uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); - - AscendC::DataCopyParams repeatParams; - - if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { - repeatParams.blockCount = blockCount; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; - repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; - AscendC::DataCopy(dstTensor, srcTensor, repeatParams); - } else { - repeatParams.blockCount = 1; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = 0; - repeatParams.dstStride = 0; - for (uint32_t i = 0; i < blockCount; i++) { - uint64_t dstOffset = i * layoutDst.stride(1); - uint64_t srcOffset = i * layoutSrc.stride(1); - AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], - repeatParams); - } + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); + uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); + + AscendC::DataCopyParams repeatParams; + + if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { + repeatParams.blockCount = blockCount; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; + repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; + AscendC::DataCopy(dstTensor, srcTensor, repeatParams); + } else { + repeatParams.blockCount = 1; + repeatParams.blockLen = blockLen; + repeatParams.srcStride = 0; + repeatParams.dstStride = 0; + for (uint32_t i = 0; i < blockCount; i++) { + uint64_t dstOffset = i * layoutDst.stride(1); + uint64_t srcOffset = i * layoutSrc.stride(1); + AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); + } + } } - } }; /// Partial specialization for AtlasA2, PaddingRowMajor in and zN out. 
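Unlike the specializations above, the Padding* variants below issue a single DataCopy with no stride-limit branch, presumably because the padded layouts already guarantee fractal-aligned, in-range strides. Typical round-up math for such padding (a sketch with illustrative values):

#include <cstdint>

constexpr uint32_t roundUp(uint32_t x, uint32_t align)
{
    return (x + align - 1) / align * align;
}

static_assert(roundUp(100, 16) == 112, "rows padded to 16-row fractals");
static_assert(roundUp(100, 32) == 128, "bytes padded to a 32-byte C0 block");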
template -struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::PaddingRowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.orgShape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = layoutSrc.orgShape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } +struct CopyGmToL1> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::PaddingRowMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(1); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(0); + intriParams.srcDValue = layoutSrc.stride(0); + intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } }; /// Partial specialization for AtlasA2, PaddingColumnMajor in and nZ out. 
template -struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::PaddingColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.orgShape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = layoutSrc.orgShape(1); - intriParams.srcDValue = layoutSrc.stride(2); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } +struct CopyGmToL1> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::PaddingColumnMajor; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = layoutSrc.orgShape(0); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = layoutSrc.orgShape(1); + intriParams.srcDValue = layoutSrc.stride(2); + intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor, srcTensor, intriParams); + } }; /// Partial specialization for AtlasA2, RowMajor in and RowMajor out. template -struct CopyGmToL1< - Arch::AtlasA2, Gemm::GemmType, - Gemm::GemmType> { - using LayoutDst = layout::RowMajor; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - uint32_t rows = layoutSrc.shape(0); - uint32_t cols = layoutSrc.shape(1); - uint32_t srcStride = - (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && - (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && - (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = - (i < rLoops - 1) ? 
MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, - srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], - dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], - srcTensor[i * layoutSrc.stride(0)], cols); - } +struct CopyGmToL1, + Gemm::GemmType> { + using LayoutDst = layout::RowMajor; + using LayoutSrc = layout::RowMajor; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); + static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; + static constexpr uint32_t MAX_REPEAT = 4095; + + // Methods + + ACT_DEVICE + CopyGmToL1() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + uint32_t rows = layoutSrc.shape(0); + uint32_t cols = layoutSrc.shape(1); + uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; + uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; + + if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { + DataCopy(dstTensor, srcTensor, rows * cols); + } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { + uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); + for (uint32_t i = 0; i < rLoops; ++i) { + uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; + AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); + DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], + srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); + } + } else { + for (uint32_t i = 0; i < rows; ++i) { + DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); + } + } } - } }; ///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// /// Partial specialization for CopyGmToL1, AtlasA2, RowMajor in and zN out. 
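The TileCopyTla specializations that follow read hierarchical shapes and strides through tla::get<I>(...) and tla::get<I, J>(...). As a rough stand-in, the nesting behaves like std::get on nested tuples (this analogy is editorial; the tla implementation itself is not shown in this patch):

#include <tuple>

int main()
{
    // ((innerRow, outerRow), (innerCol, outerCol)) stride, zN-like numbers:
    auto stride = std::make_tuple(std::make_tuple(16, 4096),
                                  std::make_tuple(1, 256));
    auto innerStrideRow = std::get<0>(std::get<0>(stride));  // ~ tla::get<0, 0>
    auto outerStrideCol = std::get<1>(std::get<1>(stride));  // ~ tla::get<1, 1>
    return (innerStrideRow == 16 && outerStrideCol == 256) ? 0 : 1;
}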
-template +template struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::GM>, - Tensor, LayoutDst_, - AscendC::TPosition::A1>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t nValue = get<0>(srcTensor.shape()); - const uint32_t dValue = get<1>(srcTensor.shape()); - const uint32_t srcDValue = get<0>(srcTensor.stride()); - const uint32_t dstInnerStrideRow = get<0, 0>(dstTensor.stride()); - const uint32_t dstOuterStrideCol = get<1, 1>(dstTensor.stride()); - - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = dValue; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (srcDValue < STRIDE_LIMIT) { - intriParams.nValue = nValue; - intriParams.srcDValue = srcDValue; - intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < nValue; i++) { - AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], - srcTensor.data()[i * srcDValue], intriParams); - } + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + std::enable_if_t::value && tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t nValue = get<0>(srcTensor.shape()); + const uint32_t dValue = get<1>(srcTensor.shape()); + const uint32_t srcDValue = get<0>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<0, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<1, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); + } + } } - } }; /// Partial specialization for CopyGmToL1, AtlasA2, ColumnMajor in and nZ out. 
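For the column-major source handled below, the ND-to-NZ roles swap: each "row" fed to the engine is one source column, so nValue comes from the column count and srcDValue from the leading dimension. A worked sketch (the struct and numbers are illustrative only):

#include <cstdint>

struct Nd2NzSketch {
    uint32_t nValue;     // ND rows fed to the engine
    uint32_t dValue;     // elements per ND row
    uint32_t srcDValue;  // stride between ND rows, in elements
};

// A rows x cols column-major tile with leading dimension ld:
constexpr Nd2NzSketch fromColumnMajor(uint32_t rows, uint32_t cols, uint32_t ld)
{
    return {cols, rows, ld};  // each source column becomes one ND "row"
}

static_assert(fromColumnMajor(128, 64, 128).nValue == 64, "n = column count");
static_assert(fromColumnMajor(128, 64, 128).srcDValue == 128, "stride = leading dimension");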
-template -struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::GM>, - Tensor, LayoutDst_, - AscendC::TPosition::A1>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t nValue = get<1>(srcTensor.shape()); - const uint32_t dValue = get<0>(srcTensor.shape()); - const uint32_t srcDValue = get<1>(srcTensor.stride()); - const uint32_t dstInnerStrideRow = get<1, 0>(dstTensor.stride()); - const uint32_t dstOuterStrideCol = get<0, 1>(dstTensor.stride()); - - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = dValue; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (srcDValue < STRIDE_LIMIT) { - intriParams.nValue = nValue; - intriParams.srcDValue = srcDValue; - intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < nValue; i++) { - AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], - srcTensor.data()[i * srcDValue], intriParams); - } +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t nValue = get<1>(srcTensor.shape()); + const uint32_t dValue = get<0>(srcTensor.shape()); + const uint32_t srcDValue = get<1>(srcTensor.stride()); + const uint32_t dstInnerStrideRow = get<1, 0>(dstTensor.stride()); + const uint32_t dstOuterStrideCol = get<0, 1>(dstTensor.stride()); + + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = dValue; + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + if (srcDValue < STRIDE_LIMIT) { + intriParams.nValue = nValue; + intriParams.srcDValue = srcDValue; + intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } else { + intriParams.nValue = 1; + intriParams.srcDValue = 0; + intriParams.dstNzNStride = 0; + for (uint32_t i = 0; i < nValue; i++) { + AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); + } + } } - } }; /// Partial specialization for CopyGmToL1, AtlasA2, PaddingRowMajor in and zN /// out. 
-template -struct TileCopyTlaExt, LayoutSrc_, - AscendC::TPosition::GM>, - Tensor, LayoutDst_, - AscendC::TPosition::A1>, +template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, layout::PaddingRowMajor, layout::zN> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTlaExt() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = get<1>(srcTensor.orgShape()); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = get<0>(srcTensor.orgShape()); - intriParams.srcDValue = get<0, 0>(srcTensor.stride()); - intriParams.dstNzNStride = get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<1>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<0>(srcTensor.orgShape()); + intriParams.srcDValue = get<0, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } }; /// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, /// PaddingColumnMajor in and nZ out. 
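As in the specialization above, the Ext variant below sizes the transfer from orgShape(), the original unpadded extents, while strides come from the padded layout, so only valid data is walked into the padded destination. A toy model of the distinction (all fields invented for illustration):

#include <cstdint>

struct PaddedDesc {
    uint32_t orgRows, orgCols;  // logical (unpadded) extents
    uint32_t strideRow;         // padded leading dimension, in elements
};

constexpr uint32_t elementsMoved(PaddedDesc d)   { return d.orgRows * d.orgCols; }
constexpr uint32_t elementsSpanned(PaddedDesc d) { return d.orgRows * d.strideRow; }

static_assert(elementsMoved({100, 60, 64}) == 6000, "only the valid region is copied");
static_assert(elementsSpanned({100, 60, 64}) == 6400, "the walk follows the padded stride");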
-template -struct TileCopyTlaExt, LayoutSrc_, - AscendC::TPosition::GM>, - Tensor, LayoutDst_, - AscendC::TPosition::A1>, +template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::A1>, layout::PaddingColumnMajor, layout::nZ> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTlaExt() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = get<0>(srcTensor.orgShape()); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = get<1>(srcTensor.orgShape()); - intriParams.srcDValue = get<1, 0>(srcTensor.stride()); - intriParams.dstNzNStride = get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTlaExt() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::Nd2NzParams intriParams; + + intriParams.ndNum = 1; + intriParams.dValue = get<0>(srcTensor.orgShape()); + intriParams.srcNdMatrixStride = 0; + intriParams.dstNzC0Stride = get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; + intriParams.dstNzMatrixStride = 0; + + intriParams.nValue = get<1>(srcTensor.orgShape()); + intriParams.srcDValue = get<1, 0>(srcTensor.stride()); + intriParams.dstNzNStride = get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; + AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_GM_TO_L1_HPP +#endif // ACT_GEMM_TILE_COPY_GM_TO_L1_HPP diff --git a/act/gemm/tile/copy_gm_to_ub.hpp b/act/gemm/tile/copy_gm_to_ub.hpp index 6e690d94..d5065005 100644 --- a/act/gemm/tile/copy_gm_to_ub.hpp +++ b/act/gemm/tile/copy_gm_to_ub.hpp @@ -19,46 +19,35 @@ namespace Act::Gemm::Tile { /// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
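Note the mixed units in the DataCopyExtParams constructed below: blockLen and the source gap are in bytes, while the destination gap is counted in 32-byte blocks. Worked numbers, assuming 2-byte elements (all values illustrative):

#include <cstdint>

constexpr uint32_t bytesPerBlk = 32;
constexpr uint32_t eleSize = 2;                            // e.g. sizeof(half)
constexpr uint32_t eleNumPerBlk = bytesPerBlk / eleSize;   // 16 elements per block

constexpr uint32_t rows = 64, cols = 128;                  // tile being copied
constexpr uint32_t srcLd = 192, dstLd = 128;               // leading dims, in elements

constexpr uint32_t blockCount = rows;                              // one run per row
constexpr uint32_t blockLenBytes = cols * eleSize;                 // 256 bytes per row
constexpr uint32_t srcGapBytes = (srcLd - cols) * eleSize;         // 128-byte source gap
constexpr uint32_t dstGapBlocks = (dstLd - cols) / eleNumPerBlk;   // packed dst: 0 blocks

static_assert(blockCount == 64 && blockLenBytes == 256, "");
static_assert(srcGapBytes == 128 && dstGapBlocks == 0, "");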
-template +template struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::GM>, - Tensor, LayoutDst_, - AscendC::TPosition::VECCALC>, - std::enable_if_t::value && - tla::detail::isRowMajor::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::VECCALC>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - AscendC::DataCopyExtParams dataCopyParams( - get<0>(srcTensor.shape()), - get<1>(srcTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) * - sizeof(ElementSrc), - (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) / - ELE_NUM_PER_BLK, - 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams, - padParams); - }; + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, + Tensor, LayoutDst_, AscendC::TPosition::VECCALC>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::VECCALC>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; + + static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(srcTensor.shape()), get<1>(srcTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) * sizeof(ElementSrc), + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) / ELE_NUM_PER_BLK, 0); + AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams, padParams); + }; }; -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_GM_TO_UB_HPP +#endif // ACT_GEMM_TILE_COPY_GM_TO_UB_HPP diff --git a/act/gemm/tile/copy_l0c_to_gm.hpp b/act/gemm/tile/copy_l0c_to_gm.hpp index 534af09a..b25e28b0 100644 --- a/act/gemm/tile/copy_l0c_to_gm.hpp +++ b/act/gemm/tile/copy_l0c_to_gm.hpp @@ -17,241 +17,203 @@ namespace Act::Gemm::Tile { -enum class ScaleGranularity { - UNDEFINED = -1, - NO_QUANT = 0, - PER_TENSOR, - PER_CHANNEL, - PER_GROUP -}; +enum class ScaleGranularity { UNDEFINED = -1, NO_QUANT = 0, PER_TENSOR, PER_CHANNEL, PER_GROUP }; template struct CopyL0CToGmQuantMode { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy l0c to gm, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); }; // CopyL0CToGm cast fp32 to fp16 template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::F322F16; +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322F16; }; // CopyL0CToGm cast fp32 to bf16 template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::F322BF16; +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::F322BF16; }; // CopyL0CToGm output fp32 template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::NoQuant; +struct 
CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; }; // CopyL0CToGm output int32 template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::NoQuant; +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::NoQuant; }; // CopyL0CToGm cast int32_t to fp16 template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::DEQF16; +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::DEQF16; }; template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::VDEQF16; +struct CopyL0CToGmQuantMode { + static constexpr auto VALUE = QuantMode_t::VDEQF16; }; template + ScaleGranularity DEQUANT_GRANULARITY = ScaleGranularity::NO_QUANT, bool ReluEnable = false> struct CopyL0CToGm { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy l0c to gm, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); }; template -struct CopyL0CToGm, +struct CopyL0CToGm, ScaleGranularity::NO_QUANT, ReluEnable_> { - using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::RowMajor; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dst, - AscendC::LocalTensor const &src, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, - uint8_t unitFlag = 0) { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = dstLayout.shape(1); - intriParams.mSize = dstLayout.shape(0); - intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); - intriParams.dstStride = dstLayout.stride(0); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe( - dst, src, intriParams); - } + using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::RowMajor; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(1); + intriParams.mSize = dstLayout.shape(0); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); + intriParams.dstStride = dstLayout.stride(0); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } }; template -struct CopyL0CToGm, +struct CopyL0CToGm, ScaleGranularity::NO_QUANT, ReluEnable_> { - using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::ColumnMajor; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / 
sizeof(ElementDst); - - ACT_DEVICE - CopyL0CToGm() {} - - ACT_DEVICE - void operator()(AscendC::GlobalTensor dstTensor, - AscendC::LocalTensor srcTensor, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, - uint8_t unitFlag = 0) { - AscendC::DataCopyCO12DstParams params; - - params.nSize = dstLayout.shape(0); - params.mSize = dstLayout.shape(1); - params.dstStride = dstLayout.stride(1); - params.srcStride = srcLayout.shape(2) * srcLayout.shape(3); - params.quantPre = quantPre; - params.reluPre = 0; - params.channelSplit = false; - params.nz2ndEn = true; - AscendC::DataCopy(dstTensor, srcTensor, params); - } + using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::ColumnMajor; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementDst); + + ACT_DEVICE + CopyL0CToGm() {} + + ACT_DEVICE + void operator()(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::DataCopyCO12DstParams params; + + params.nSize = dstLayout.shape(0); + params.mSize = dstLayout.shape(1); + params.dstStride = dstLayout.stride(1); + params.srcStride = srcLayout.shape(2) * srcLayout.shape(3); + params.quantPre = quantPre; + params.reluPre = 0; + params.channelSplit = false; + params.nz2ndEn = true; + AscendC::DataCopy(dstTensor, srcTensor, params); + } }; template -struct CopyL0CToGm, +struct CopyL0CToGm, ScaleGranularity::NO_QUANT, ReluEnable_> { - using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::zN; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dst, - AscendC::LocalTensor const &src, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, - uint8_t unitFlag = 0) { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3); - intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1); - intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); - intriParams.dstStride = - dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst)); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe(dst, src, - intriParams); - } + using ArchTag = Act::Arch::AtlasA2; + using ElementDst = ElementDst_; + using ElementSrc = ElementAccumulator_; + using LayoutSrc = Act::layout::zN; + using LayoutDst = Act::layout::zN; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, + LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3); + intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1); + intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); + intriParams.dstStride = 
dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst)); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dst, src, intriParams); + } }; ///////////////////////////////////////////CopyL0CToGmTla///////////////////////////////////////////////// template + ScaleGranularity DEQUANT_GRANULARITY = ScaleGranularity::NO_QUANT, bool ReluEnable = false, + class Enable = void> struct CopyL0CToGmTla { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy l0c to gm, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); }; -template +template struct CopyL0CToGmTla< - Act::Arch::AtlasA2, TensorSrc_, - Tensor, LayoutDst_, - AscendC::TPosition::GM>, - ScaleGranularity::NO_QUANT, ReluEnable_, - std::enable_if_t::value>> { - using ArchTag = Act::Arch::AtlasA2; - using TensorDst = Tensor, LayoutDst_, - AscendC::TPosition::GM>; - using ElementDst = ElementDst_; - using TensorSrc = TensorSrc_; - using ElementSrc = typename TensorSrc::Element; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, - uint8_t unitFlag = 0) { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = get<1>(dstTensor.shape()); - intriParams.mSize = get<0>(dstTensor.shape()); - intriParams.srcStride = - get<1, 1>(srcTensor.stride()) / get<0, 0>(srcTensor.stride()); - intriParams.dstStride = get<0>(dstTensor.stride()); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe( - dstTensor.data(), srcTensor.data(), intriParams); - } + Act::Arch::AtlasA2, TensorSrc_, Tensor, LayoutDst_, AscendC::TPosition::GM>, + ScaleGranularity::NO_QUANT, ReluEnable_, std::enable_if_t::value>> { + using ArchTag = Act::Arch::AtlasA2; + using TensorDst = Tensor, LayoutDst_, AscendC::TPosition::GM>; + using ElementDst = ElementDst_; + using TensorSrc = TensorSrc_; + using ElementSrc = typename TensorSrc::Element; + static constexpr auto quantPre = + CopyL0CToGmQuantMode::VALUE; + static constexpr auto reluEn = ReluEnable_; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, uint8_t unitFlag = 0) + { + AscendC::FixpipeParamsV220 intriParams; + + // Fixpipe layout information + intriParams.nSize = get<1>(dstTensor.shape()); + intriParams.mSize = get<0>(dstTensor.shape()); + intriParams.srcStride = get<1, 1>(srcTensor.stride()) / get<0, 0>(srcTensor.stride()); + intriParams.dstStride = get<0>(dstTensor.stride()); + + // Fixpipe auxiliary arguments + intriParams.quantPre = quantPre; + intriParams.reluEn = reluEn; + intriParams.unitFlag = unitFlag; + + // Call AscendC Fixpipe + AscendC::Fixpipe(dstTensor.data(), srcTensor.data(), + intriParams); + } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP +#endif // ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP diff --git a/act/gemm/tile/copy_l1_to_l0a.hpp b/act/gemm/tile/copy_l1_to_l0a.hpp index ba88499b..14639773 100644 --- a/act/gemm/tile/copy_l1_to_l0a.hpp +++ b/act/gemm/tile/copy_l1_to_l0a.hpp @@ 
-22,417 +22,371 @@ using namespace tla; namespace Act::Gemm::Tile { -template struct CopyL1ToL0A { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy l1 to l0, can not find the specialization."); +template +struct CopyL1ToL0A { + static_assert(DEPENDENT_FALSE, "Unsupporteded copy l1 to l0, can not find the specialization."); }; //////////////////////////////// /// new add gemm template -struct CopyL1ToL0A, - Act::Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } } - } }; template -struct CopyL1ToL0A, - Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutSrc.stride(3)], - srcTensor[i * layoutSrc.stride(3)], loadDataParams); +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor 
srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } } - } }; template -struct CopyL1ToL0A, - Act::Gemm::GemmType> { - using Element = float; - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1) / 2); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = - static_cast(layoutSrc.shape(1) / 2) - 1; - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(3)], - srcTensor[i * layoutSrc.stride(3)], - loadDataParams); +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = static_cast(layoutSrc.shape(1) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } } - } }; template -struct CopyL1ToL0A, - Act::Gemm::GemmType> { - using Element = int8_t; - using LayoutDst = layout::zN; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - uint32_t MRound = layoutSrc.shape(0) * layoutSrc.shape(1); - uint32_t KRound = layoutSrc.shape(2) * layoutSrc.shape(3); - uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; - uint32_t KLoops = CeilDiv(KRound, KL0Alignment); - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(MRound / ELE_NUM_PER_C0); - loadDataParams.srcStride = static_cast(KRound / KL0Alignment); - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < KLoops; i++) { - AscendC::LoadDataWithTranspose( - dstTensor[i * MRound * KL0Alignment], - srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); +struct CopyL1ToL0A, Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / 
sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + uint32_t MRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(MRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * MRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } } - } }; ////////////////////////////////////////// /// Partial specialization for zN in and zZ out. template -struct CopyL1ToL0A< - ArchTag, Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } } - } }; template -struct CopyL1ToL0A< - ArchTag, Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, 
LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast( - CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); - i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); +struct CopyL1ToL0A> { + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } } - } }; /// Partial specialization for int8_t, nZ in and zZ out. (Transpose A) template -struct CopyL1ToL0A> { - using Element = int8_t; - using LayoutDst = layout::zZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = - static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = - CeilDiv(layoutDst.orgShape(1)) - 1; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); - i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], - srcTensor[i * layoutSrc.stride(1)], - loadDataParams); +struct CopyL1ToL0A> { + using Element = int8_t; + using LayoutDst = layout::zZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0A() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(1)) - 1; + + for (uint32_t i = 0; i < 
CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } } - } }; ///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// /// Partial specialization for CopyL1ToL0A, AtlasA2, zN in and zZ out. -template -struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::A1>, - Tensor, LayoutDst_, - AscendC::TPosition::A2>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A2>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } } - } }; /// Partial specialization for CopyL1ToL0A, AtlasA2, nZ in and zZ out. 
/// (Transpose A) -template -struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::A1>, - Tensor, LayoutDst_, - AscendC::TPosition::A2>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::A2>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::A2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } } - } }; /// Partial specialization for CopyL1ToL0A, AtlasA2, int8_t, nZ in and zZ out. 
/// (Transpose A) template struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, Tensor, LayoutDst_, AscendC::TPosition::A2>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using Element = int8_t; - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = - Tensor, LayoutDst, AscendC::TPosition::A2>; - using TensorSrc = - Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterShapeRow = get<0, 1>(srcTensor.shape()); - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = dstOuterShapeCol - 1; - - for (uint32_t i = 0; i < srcOuterShapeRow; i++) { - AscendC::LoadDataWithTranspose( - dstTensor.data()[i * dstOuterStrideRow * 2], - srcTensor.data()[i * srcOuterStrideRow], loadDataParams); + std::enable_if_t::value && tla::detail::isnZ::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterShapeRow = get<0, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = dstOuterShapeCol - 1; + + for (uint32_t i = 0; i < srcOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow * 2], + srcTensor.data()[i * srcOuterStrideRow], loadDataParams); + } } - } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP diff --git a/act/gemm/tile/copy_l1_to_l0b.hpp b/act/gemm/tile/copy_l1_to_l0b.hpp index f9778dcf..6f1ced1d 100644 --- a/act/gemm/tile/copy_l1_to_l0b.hpp +++ b/act/gemm/tile/copy_l1_to_l0b.hpp @@ -22,577 +22,516 @@ using namespace tla; namespace Act::Gemm::Tile { -template struct CopyL1ToL0B { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded copy l1 to l0, can not find the specialization."); +template +struct CopyL1ToL0B { + static_assert(DEPENDENT_FALSE, 
"Unsupporteded copy l1 to l0, can not find the specialization."); }; //////////////////////////////////////// /// new add gemm template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N - AscendC::LoadData(dstTensor[i * layoutSrc.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadData(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } } - } }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using Element = float; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3) / 2); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = - static_cast(layoutSrc.shape(3) / 2) - 1; - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N - AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(1)], - srcTensor[i * layoutSrc.stride(1)], - loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using Element = float; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3) / 2); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = static_cast(layoutSrc.shape(3) / 2) - 1; + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N + AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + 
} } - } }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using Element = int8_t; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - uint32_t NRound = layoutSrc.shape(2) * layoutSrc.shape(3); - uint32_t KRound = layoutSrc.shape(0) * layoutSrc.shape(1); - uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; - uint32_t KLoops = CeilDiv(KRound, KL0Alignment); - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(NRound / ELE_NUM_PER_C0); - loadDataParams.srcStride = static_cast(KRound / KL0Alignment); - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < KLoops; i++) { - AscendC::LoadDataWithTranspose( - dstTensor[i * NRound * KL0Alignment], - srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + uint32_t NRound = layoutSrc.shape(2) * layoutSrc.shape(3); + uint32_t KRound = layoutSrc.shape(0) * layoutSrc.shape(1); + uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; + uint32_t KLoops = CeilDiv(KRound, KL0Alignment); + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(NRound / ELE_NUM_PER_C0); + loadDataParams.srcStride = static_cast(KRound / KL0Alignment); + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < KLoops; i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * NRound * KL0Alignment], + srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); + } } - } }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], - srcTensor[i * layoutSrc.stride(3)], loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::zN; + + static constexpr 
uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } } - } }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::nN; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, - AscendC::LocalTensor srcTensor, LayoutDst layoutDst, - LayoutSrc layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); - loadDataParams.srcStride = layoutSrc.shape(3); - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutSrc.shape(3); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], - srcTensor[i * layoutSrc.stride(3)], loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::nN; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, + LayoutDst layoutDst, LayoutSrc layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); + loadDataParams.srcStride = layoutSrc.shape(3); + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutSrc.shape(3); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); + } } - } }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - 
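// A single LoadData moves the whole tile here: repeatTimes already spans shape(1) * shape(3) fractals, and ifTranspose flips each 16 x 16 fractal from nN into zN order.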
AscendC::LoadData(dstTensor, srcTensor, loadDataParams); - }; +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + }; }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - using Element = float; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast( - CeilDiv(layoutDst.orgShape(0))); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = - CeilDiv(layoutDst.orgShape(0)) - 1; - - for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); - i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], - srcTensor[i * layoutSrc.stride(3)], - loadDataParams); - } - }; +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nN; + using Element = float; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = 1; + loadDataParams.dstGap = 0; + loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(0)) - 1; + + for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], srcTensor[i * layoutSrc.stride(3)], + loadDataParams); + } + }; }; template -struct CopyL1ToL0B, - Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nZ; - using Element = int8_t; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - 
AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = - static_cast(CeilDiv(layoutDst.orgShape(0))); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); - i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], - srcTensor[i * layoutSrc.stride(3) * 2], - loadDataParams); +struct CopyL1ToL0B, Act::Gemm::GemmType> { + using LayoutDst = layout::zN; + using LayoutSrc = layout::nZ; + using Element = int8_t; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); + loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3) * 2], + loadDataParams); + } } - } }; //////////////////////////////////////////// /// Partial specialization for int8_t, zN in and nZ out. template -struct CopyL1ToL0B> { - using Element = int8_t; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = - static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); - i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1) * 2], - loadDataParams); +struct CopyL1ToL0B> { + using Element = int8_t; + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; 
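+        // One 512 B int8_t fractal holds 16 x 32 elements, i.e. two 16 x 16 outputs per LoadDataWithTranspose; the / 2 on srcStride and dstGap = 1 account for that pairing.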
+ + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], + loadDataParams); + } } - } }; /// Partial specialization for zN in and nZ out. template -struct CopyL1ToL0B< - ArchTag, Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = - static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); - i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::zN; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); + } } - } }; /// Partial specialization for nZ in and nZ out. 
(Transpose B) template -struct CopyL1ToL0B< - ArchTag, Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, - AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) { - AscendC::LoadData2DParams loadDataParams; - if (layoutSrc.shape(3) == layoutDst.shape(3)) { - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = - static_cast(layoutDst.shape(1) * layoutDst.shape(3)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - AscendC::LoadData(dstTensor, srcTensor, loadDataParams); - } else { - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], - srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } +struct CopyL1ToL0B> { + using LayoutDst = layout::nZ; + using LayoutSrc = layout::nZ; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + CopyL1ToL0B() {}; + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, + LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) + { + AscendC::LoadData2DParams loadDataParams; + if (layoutSrc.shape(3) == layoutDst.shape(3)) { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(1) * layoutDst.shape(3)); + loadDataParams.srcStride = 1; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + AscendC::LoadData(dstTensor, srcTensor, loadDataParams); + } else { + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); + loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < layoutDst.shape(1); i++) { + AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], + loadDataParams); + } + } } - } }; ///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// /// Partial specialization for CopyL1ToL0B, AtlasA2, zN in and nZ out. 
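/// Example (illustrative sketch, not exercised elsewhere in this patch): the tla-based
/// copies bake shapes and strides into the Tensor types, so the call site passes only
/// the two tensors, unlike the CopyL1ToL0B functors above, which take layouts per call.
/// The construction of srcTensor and dstTensor is assumed here for exposition:
///
///     // srcTensor lives in L1 (TPosition::A1, zN layout); dstTensor in L0B (TPosition::B2, nZ).
///     TileCopyTla<Arch::AtlasA2, decltype(srcTensor), decltype(dstTensor)> copyL1ToL0B;
///     copyL1ToL0B(dstTensor, srcTensor); // one LoadData per outer row of fractals
///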
-template -struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::A1>, - Tensor, LayoutDst_, - AscendC::TPosition::B2>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::B2>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::iszN::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = true; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } } - } }; /// Partial specialization for CopyL1ToL0B, AtlasA2, nZ in and nZ out. 
/// (Transpose B) -template -struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::A1>, - Tensor, LayoutDst_, - AscendC::TPosition::B2>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::B2>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); +template +struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, + Tensor, LayoutDst_, AscendC::TPosition::B2>, + std::enable_if_t::value && + tla::detail::isnZ::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2DParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = dstOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; + loadDataParams.sid = 0; + loadDataParams.dstGap = 0; + loadDataParams.ifTranspose = false; + loadDataParams.addrMode = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], + loadDataParams); + } } - } }; /// Partial specialization for CopyL1ToL0B, AtlasA2, int8_t, zN in and nZ out. 
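/// Why int8_t gets a dedicated path: one 512 B fractal (BYTE_PER_FRACTAL in
/// act/act.hpp) holds 16 x 32 int8_t elements, twice the 16 x 16 block a single
/// transposed output fractal covers, so the plain 2D LoadData cannot reorder it.
/// The constants the specialization below builds on, restated as a sketch:
///
///     constexpr uint32_t ELE_NUM_PER_C0      = 32 / sizeof(int8_t);  // 32
///     constexpr uint32_t ELE_NUM_PER_FRACTAL = 512 / sizeof(int8_t); // 512
///     // Each source fractal yields two transposed 16 x 16 fractals, which is
///     // where the paired * 2 and / 2 stride factors in this file come from.
///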
template struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, AscendC::TPosition::A1>, + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, Tensor, LayoutDst_, AscendC::TPosition::B2>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using Element = int8_t; - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = - Tensor, LayoutDst, AscendC::TPosition::B2>; - using TensorSrc = - Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - const uint32_t srcOuterShapeCol = get<1, 1>(srcTensor.shape()); - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = srcOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadDataWithTranspose( - dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow * 2], loadDataParams); + std::enable_if_t::value && tla::detail::iszN::value>> { + using Element = int8_t; + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; + + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + + // Methods + + ACT_DEVICE + TileCopyTla() {}; + + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + const uint32_t srcOuterShapeCol = get<1, 1>(srcTensor.shape()); + const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); + const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); + const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); + const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); + + AscendC::LoadData2dTransposeParams loadDataParams; + + loadDataParams.startIndex = 0; + loadDataParams.repeatTimes = srcOuterShapeCol; + loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; + loadDataParams.dstGap = 1; + loadDataParams.dstFracGap = 0; + + for (uint32_t i = 0; i < dstOuterShapeRow; i++) { + AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow], + srcTensor.data()[i * srcOuterStrideRow * 2], loadDataParams); + } } - } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP +#endif // ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP diff --git a/act/gemm/tile/copy_ub_to_gm.hpp b/act/gemm/tile/copy_ub_to_gm.hpp index 2ab44a84..87d86e3b 100644 --- a/act/gemm/tile/copy_ub_to_gm.hpp +++ b/act/gemm/tile/copy_ub_to_gm.hpp @@ -19,81 +19,62 @@ namespace Act::Gemm::Tile { /// Partial specialization for AtlasA2, RowMajor in and 
RowMajor out. -template +template struct TileCopyTla< - Arch::AtlasA2, - Tensor, LayoutSrc_, - AscendC::TPosition::VECCALC>, - Tensor, LayoutDst_, - AscendC::TPosition::GM>, - std::enable_if_t::value && - tla::detail::isRowMajor::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::GM>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::VECCALC>; + Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, AscendC::TPosition::GM>, + std::enable_if_t::value && tla::detail::isRowMajor::value>> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - // Methods + // Methods - ACT_DEVICE - TileCopyTla() {}; + ACT_DEVICE + TileCopyTla() {}; - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - AscendC::DataCopyExtParams dataCopyParams( - get<0>(dstTensor.shape()), - get<1>(dstTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / - ELE_NUM_PER_C0, - (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) * - sizeof(ElementSrc), - 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); - }; + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<0>(dstTensor.shape()), get<1>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, + (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) * sizeof(ElementSrc), 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; }; /// Partial specialization for AtlasA2, RowMajor in and PaddingRowMajor out. 
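/// Both UB -> GM specializations reduce to a single DataCopyPad whose
/// DataCopyExtParams encode: burst count = rows, burst length = row payload in
/// bytes, source gap in 32 B units, destination gap in bytes. A minimal sketch
/// for a 64 x 128 half tile; ldGm (the GM leading dimension) and the two
/// tensor handles are assumptions for exposition:
///
///     AscendC::DataCopyExtParams params(
///         64,                           // blockCount: one burst per row
///         128 * sizeof(half),           // blockLen: payload bytes per row
///         0,                            // srcStride: UB rows packed back to back
///         (ldGm - 128) * sizeof(half),  // dstStride: trailing GM bytes per row
///         0);                           // reserved
///     AscendC::DataCopyPad(gmTensor.data(), ubTensor.data(), params);
///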
-template -struct TileCopyTlaExt, LayoutSrc_, - AscendC::TPosition::VECCALC>, - Tensor, LayoutDst_, - AscendC::TPosition::GM>, - layout::RowMajor, layout::PaddingRowMajor> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, - AscendC::TPosition::GM>; - using TensorSrc = Tensor, LayoutSrc, - AscendC::TPosition::VECCALC>; +template +struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::VECCALC>, + Tensor, LayoutDst_, AscendC::TPosition::GM>, layout::RowMajor, + layout::PaddingRowMajor> { + using LayoutDst = LayoutDst_; + using LayoutSrc = LayoutSrc_; + using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; + using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - // Methods + // Methods - ACT_DEVICE - TileCopyTlaExt() {}; + ACT_DEVICE + TileCopyTlaExt() {}; - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) { - AscendC::DataCopyExtParams dataCopyParams( - get<1, 1>(dstTensor.shape()), - get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / - ELE_NUM_PER_C0, - (get<1, 1>(dstTensor.stride()) - get<1, 0>(dstTensor.shape())) * - sizeof(ElementSrc), - 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); - }; + ACT_DEVICE + void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) + { + AscendC::DataCopyExtParams dataCopyParams( + get<1, 1>(dstTensor.shape()), get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc), + (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, + (get<1, 1>(dstTensor.stride()) - get<1, 0>(dstTensor.shape())) * sizeof(ElementSrc), 0); + AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); + }; }; -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_COPY_UB_TO_GM_HPP +#endif // ACT_GEMM_TILE_COPY_UB_TO_GM_HPP diff --git a/act/gemm/tile/tile_copy.hpp b/act/gemm/tile/tile_copy.hpp index c9b9b69f..c7135709 100644 --- a/act/gemm/tile/tile_copy.hpp +++ b/act/gemm/tile/tile_copy.hpp @@ -20,18 +20,14 @@ namespace Act::Gemm::Tile { template struct TileCopyTla { - static_assert(DEPENDENT_FALSE, - "Unsupporteded tileCopyTla, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported tileCopyTla, cannot find the specialization."); }; -template +template struct TileCopyTlaExt { - static_assert( - DEPENDENT_FALSE, - "Unsupporteded tileCopyTlaExt, can not find the specialization."); + static_assert(DEPENDENT_FALSE, "Unsupported tileCopyTlaExt, cannot find the specialization."); }; -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile #include "../../../act/gemm/helper.hpp" #include "../../../act/gemm/tile/copy_gm_to_l1.hpp" @@ -55,21 +51,16 @@ template < /// GemmType type for Bias operand class BiasType = void> struct TileCopy { - using ElementA = typename AType::Element; - using ElementB = typename BType::Element; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - - using CopyGmToL1A = Gemm::Tile::CopyGmToL1; - using CopyGmToL1B = Gemm::Tile::CopyGmToL1; - using CopyL1ToL0A = - Gemm::Tile::CopyL1ToL0A::L1AType>; - using CopyL1ToL0B = - Gemm::Tile::CopyL1ToL0B::L1BType>; - using CopyL0CToGm = - Gemm::Tile::CopyL0CToGm; + using ElementA = 
typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A::L1AType>; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B::L1BType>; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; }; /// new add @@ -85,130 +76,108 @@ template < /// GemmTpe type for Bias operand class BiasType = void> struct TileCopyGemm { - using ElementA = typename AType::Element; - using ElementB = typename BType::Element; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - // change structural - using L1AType = typename helper::L1ATypeSelectorGemm::L1AType; - using L1BType = typename helper::L1BTypeSelectorGemm::L1BType; - using L0AType = typename helper::L0ATypeSelector::L0AType; - using L0BType = typename helper::L0BTypeSelectorGemm::L0BType; - - using CopyGmToL1A = Gemm::Tile::CopyGmToL1; - using CopyGmToL1B = Gemm::Tile::CopyGmToL1; - using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A; - using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B; - using CopyL0CToGm = - Gemm::Tile::CopyL0CToGm; + using ElementA = typename AType::Element; + using ElementB = typename BType::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + // change structural + using L1AType = typename helper::L1ATypeSelectorGemm::L1AType; + using L1BType = typename helper::L1BTypeSelectorGemm::L1BType; + using L0AType = typename helper::L0ATypeSelector::L0AType; + using L0BType = typename helper::L0BTypeSelectorGemm::L0BType; + + using CopyGmToL1A = Gemm::Tile::CopyGmToL1; + using CopyGmToL1B = Gemm::Tile::CopyGmToL1; + using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A; + using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGm; }; template < /// Tag indicating architecture - class ArchTag, class TensorA, class LayoutTagA, class TensorB, - class LayoutTagB, class TensorC, class LayoutTagC, class TensorBias = void, - class LayoutTagBias = void> + class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC, + class TensorBias = void, class LayoutTagBias = void> struct PackedTileCopyTla { - using ElementA = typename TensorA::Element; - using ElementB = typename TensorB::Element; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - - using LayoutL1A = detail::TagToLayout_t< - ElementA, typename helper::L1ATypeSelector< - Gemm::GemmType>::L1AType::Layout>; - using LayoutL1B = detail::TagToLayout_t< - ElementB, typename helper::L1BTypeSelector< - Gemm::GemmType>::L1BType::Layout>; - using LayoutL0A = detail::TagToLayout_t; - using LayoutL0B = detail::TagToLayout_t; - using LayoutL0C = typename detail::LayoutL0C; - - using TensorL1A = - Tensor, LayoutL1A, AscendC::TPosition::A1>; - using TensorL1B = - Tensor, LayoutL1B, AscendC::TPosition::A1>; - using TensorL0A = - Tensor, LayoutL0A, AscendC::TPosition::A2>; - using TensorL0B = - Tensor, LayoutL0B, AscendC::TPosition::B2>; - using TensorL0C = Tensor, LayoutL0C, - AscendC::TPosition::CO1>; - - using L1AAlignHelper = Gemm::helper::L1AlignHelper; - using L1BAlignHelper = Gemm::helper::L1AlignHelper; - - using CopyGmToL1A = Gemm::Tile::TileCopyTla; - using CopyGmToL1B = Gemm::Tile::TileCopyTla; - using 
CopyL1ToL0A = Gemm::Tile::TileCopyTla; - using CopyL1ToL0B = Gemm::Tile::TileCopyTla; - using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; + using ElementA = typename TensorA::Element; + using ElementB = typename TensorB::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutL1A = + detail::TagToLayout_t>::L1AType::Layout>; + using LayoutL1B = + detail::TagToLayout_t>::L1BType::Layout>; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>; + using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>; + using TensorL0A = Tensor, LayoutL0A, AscendC::TPosition::A2>; + using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>; + using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using CopyGmToL1A = Gemm::Tile::TileCopyTla; + using CopyGmToL1B = Gemm::Tile::TileCopyTla; + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; }; template < /// Tag indicating architecture - class ArchTag, class TensorA, class LayoutTagA, class TensorB, - class LayoutTagB, class TensorC, class LayoutTagC, class TensorBias = void, - class LayoutTagBias = void, bool IS_PADDING_A = false, - bool IS_PADDING_B = false> + class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC, + class TensorBias = void, class LayoutTagBias = void, bool IS_PADDING_A = false, bool IS_PADDING_B = false> struct PaddingPackedTileCopyTla { - static_assert(std::is_same_v || - std::is_same_v, - "Unsupporteded layout, only can be RowMajor and ColumnMajor"); - static_assert(std::is_same_v || - std::is_same_v, - "Unsupporteded layout, only can be RowMajor and ColumnMajor"); - using ElementA = typename TensorA::Element; - using ElementB = typename TensorB::Element; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - - using LayoutTagL1A = typename helper::L1ATypeSelector< - Gemm::GemmType>::L1AType::Layout; - using LayoutTagL1B = typename helper::L1BTypeSelector< - Gemm::GemmType>::L1BType::Layout; - using LayoutL1A = detail::TagToLayout_t; - using LayoutL1B = detail::TagToLayout_t; - using LayoutL0A = detail::TagToLayout_t; - using LayoutL0B = detail::TagToLayout_t; - using LayoutL0C = typename detail::LayoutL0C; - - using TensorL1A = - Tensor, LayoutL1A, AscendC::TPosition::A1>; - using TensorL1B = - Tensor, LayoutL1B, AscendC::TPosition::A1>; - using TensorL0A = - Tensor, LayoutL0A, AscendC::TPosition::A2>; - using TensorL0B = - Tensor, LayoutL0B, AscendC::TPosition::B2>; - using TensorL0C = Tensor, LayoutL0C, - AscendC::TPosition::CO1>; - - using L1AAlignHelper = Gemm::helper::L1AlignHelper; - using L1BAlignHelper = Gemm::helper::L1AlignHelper; - - using LayoutPaddingTagA = - std::conditional_t, - layout::PaddingRowMajor, layout::PaddingColumnMajor>; - using LayoutPaddingTagB = - std::conditional_t, - layout::PaddingRowMajor, layout::PaddingColumnMajor>; - - using CopyGmToL1A = std::conditional_t< - IS_PADDING_A, - Gemm::Tile::TileCopyTlaExt, - Gemm::Tile::TileCopyTla>; - using CopyGmToL1B = std::conditional_t< - IS_PADDING_B, - Gemm::Tile::TileCopyTlaExt, - Gemm::Tile::TileCopyTla>; - - 
using CopyL1ToL0A = Gemm::Tile::TileCopyTla; - using CopyL1ToL0B = Gemm::Tile::TileCopyTla; - using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; + static_assert(std::is_same_v || std::is_same_v, + "Unsupported layout, only RowMajor and ColumnMajor are supported"); + static_assert(std::is_same_v || std::is_same_v, + "Unsupported layout, only RowMajor and ColumnMajor are supported"); + using ElementA = typename TensorA::Element; + using ElementB = typename TensorB::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + using LayoutTagL1A = typename helper::L1ATypeSelector>::L1AType::Layout; + using LayoutTagL1B = typename helper::L1BTypeSelector>::L1BType::Layout; + using LayoutL1A = detail::TagToLayout_t; + using LayoutL1B = detail::TagToLayout_t; + using LayoutL0A = detail::TagToLayout_t; + using LayoutL0B = detail::TagToLayout_t; + using LayoutL0C = typename detail::LayoutL0C; + + using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>; + using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>; + using TensorL0A = Tensor, LayoutL0A, AscendC::TPosition::A2>; + using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>; + using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>; + + using L1AAlignHelper = Gemm::helper::L1AlignHelper; + using L1BAlignHelper = Gemm::helper::L1AlignHelper; + + using LayoutPaddingTagA = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + using LayoutPaddingTagB = std::conditional_t, layout::PaddingRowMajor, + layout::PaddingColumnMajor>; + + using CopyGmToL1A = + std::conditional_t, + Gemm::Tile::TileCopyTla>; + using CopyGmToL1B = + std::conditional_t, + Gemm::Tile::TileCopyTla>; + + using CopyL1ToL0A = Gemm::Tile::TileCopyTla; + using CopyL1ToL0B = Gemm::Tile::TileCopyTla; + using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla; }; -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_TILE_COPY_HPP +#endif // ACT_GEMM_TILE_TILE_COPY_HPP diff --git a/act/gemm/tile/tile_mmad.hpp b/act/gemm/tile/tile_mmad.hpp index 44824087..7beacdf7 100644 --- a/act/gemm/tile/tile_mmad.hpp +++ b/act/gemm/tile/tile_mmad.hpp @@ -29,37 +29,35 @@ template < /// GemmType type for Bias operand class BiasType_> struct TileMmad { - using ElementA = typename AType_::Element; - using ElementB = typename BType_::Element; - using ElementAccumulator = typename Gemm::helper::ElementAccumulatorSelector< - ElementA, ElementB>::ElementAccumulator; - - // Methods - - ACT_DEVICE - TileMmad() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &l0CTensor, - AscendC::LocalTensor const &l0ATensor, - AscendC::LocalTensor const &l0BTensor, uint32_t m, - uint32_t n, uint32_t k, bool initC = true, - uint8_t unitFlag = 0) { - AscendC::MmadParams mmadParams; - mmadParams.m = m; - mmadParams.n = n; - mmadParams.k = k; - mmadParams.unitFlag = unitFlag; - mmadParams.cmatrixInitVal = initC; - - AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams); - - const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; - if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < - PIPE_M_BARRIER_THRESHOLD) { - AscendC::PipeBarrier(); + using ElementA = typename AType_::Element; + using ElementB = typename BType_::Element; + using ElementAccumulator = + typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; + + // Methods + + ACT_DEVICE + TileMmad() {} + + ACT_DEVICE + void operator()(AscendC::LocalTensor const &l0CTensor, + AscendC::LocalTensor const &l0ATensor, AscendC::LocalTensor 
const &l0BTensor, + uint32_t m, uint32_t n, uint32_t k, bool initC = true, uint8_t unitFlag = 0) + { + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + + AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } } - } }; ///////////////////////////////////////////TileMmadTla///////////////////////////////////////////////// @@ -76,39 +74,37 @@ template < /// Tensor type for Bias operand class TensorBias = void> struct TileMmadTla { - // Methods - - ACT_DEVICE - TileMmadTla() {} - - ACT_DEVICE - void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor, - TensorB const &l0BTensor, bool initC = true, - uint8_t unitFlag = 0) { - const uint32_t m = get<0>(l0ATensor.orgShape()); - const uint32_t n = get<1>(l0BTensor.orgShape()); - const uint32_t k = get<1>(l0ATensor.orgShape()); - - AscendC::MmadParams mmadParams; - mmadParams.m = m; - mmadParams.n = n; - mmadParams.k = k; - mmadParams.unitFlag = unitFlag; - mmadParams.cmatrixInitVal = initC; - - AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(), - mmadParams); - - const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; - if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < - PIPE_M_BARRIER_THRESHOLD) { - AscendC::PipeBarrier(); + // Methods + + ACT_DEVICE + TileMmadTla() {} + + ACT_DEVICE + void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor, TensorB const &l0BTensor, bool initC = true, + uint8_t unitFlag = 0) + { + const uint32_t m = get<0>(l0ATensor.orgShape()); + const uint32_t n = get<1>(l0BTensor.orgShape()); + const uint32_t k = get<1>(l0ATensor.orgShape()); + + AscendC::MmadParams mmadParams; + mmadParams.m = m; + mmadParams.n = n; + mmadParams.k = k; + mmadParams.unitFlag = unitFlag; + mmadParams.cmatrixInitVal = initC; + + AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(), mmadParams); + + const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; + if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { + AscendC::PipeBarrier(); + } } - } }; ///////////////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace Act::Gemm::Tile +} // namespace Act::Gemm::Tile -#endif // ACT_GEMM_TILE_TILE_MMAD_HPP +#endif // ACT_GEMM_TILE_TILE_MMAD_HPP diff --git a/act/gemm_coord.hpp b/act/gemm_coord.hpp index 6eb6f83a..2e8dbb56 100644 --- a/act/gemm_coord.hpp +++ b/act/gemm_coord.hpp @@ -26,95 +26,134 @@ template < /// Inner dimension of matrix product uint32_t K_ = 1> struct GemmShape { - static constexpr uint32_t M = M_; - static constexpr uint32_t N = N_; - static constexpr uint32_t K = K_; - - static constexpr int64_t MN = M * N; - static constexpr int64_t MK = M * K; - static constexpr int64_t KN = N * K; - static constexpr int64_t MNK = M * N * K; - - static constexpr int64_t COUNT = MNK; - - /// Returns a Coord object - ACT_HOST_DEVICE - static Coord<3> ToCoord() { return MakeCoord(M, N, K); } - - ACT_HOST_DEVICE - static Coord<2> ToCoordMN() { return MakeCoord(M, N); } - - ACT_HOST_DEVICE - static Coord<2> ToCoordMK() { return MakeCoord(M, K); } - - ACT_HOST_DEVICE - static Coord<2> ToCoordKN() { return MakeCoord(K, N); } + static constexpr uint32_t M = M_; + static constexpr uint32_t N = N_; + static constexpr uint32_t K = K_; + + static 
constexpr int64_t MN = M * N;
+ static constexpr int64_t MK = M * K;
+ static constexpr int64_t KN = N * K;
+ static constexpr int64_t MNK = M * N * K;
+
+ static constexpr int64_t COUNT = MNK;
+
+ /// Returns a Coord object
+ ACT_HOST_DEVICE
+ static Coord<3> ToCoord()
+ {
+ return MakeCoord(M, N, K);
+ }
+
+ ACT_HOST_DEVICE
+ static Coord<2> ToCoordMN()
+ {
+ return MakeCoord(M, N);
+ }
+
+ ACT_HOST_DEVICE
+ static Coord<2> ToCoordMK()
+ {
+ return MakeCoord(M, K);
+ }
+
+ ACT_HOST_DEVICE
+ static Coord<2> ToCoordKN()
+ {
+ return MakeCoord(K, N);
+ }
};

/// GemmCoord is a structure derived from Coord<3> that specifies a location
/// within the coordinate space of a Gemm problem.
struct GemmCoord : public Coord<3, uint32_t> {
- /// Integer-valued index
- using Index = uint32_t;
-
- /// Base type is a Coord of rank=3
- using Base = Coord<3, Index>;
-
- /// Gemm M dimension - rows of the output C matrix
- static constexpr int M_INDEX = 0;
-
- /// Gemm N dimension - columns of the output C matrix
- static constexpr int N_INDEX = 1;
-
- /// Gemm K dimension - inner dimension of the Gemm problem
- static constexpr int K_INDEX = 2;
-
- /// Default ctor
- ACT_HOST_DEVICE
- GemmCoord() {}
-
- /// Constructs from Coord<3> and a batch
- ACT_HOST_DEVICE
- GemmCoord(Coord<3, Index> const &coord) : Base(coord) {}
-
- /// Helper to construct from a K, N, M, batch variables
- ACT_HOST_DEVICE
- GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {}
-
- /// Returns the Gemm M coordinate
- ACT_HOST_DEVICE
- Index const &m() const { return this->At(M_INDEX); }
-
- /// Returns reference to the Gemm M coordinate
- ACT_HOST_DEVICE
- Index &m() { return this->At(M_INDEX); }
-
- /// Returns the Gemm N coordinate
- ACT_HOST_DEVICE
- Index const &n() const { return this->At(N_INDEX); }
-
- /// Returns reference to the Gemm N coordinate
- ACT_HOST_DEVICE
- Index &n() { return this->At(N_INDEX); }
-
- /// Returns the Gemm K coordinate
- ACT_HOST_DEVICE
- Index const &k() const { return this->At(K_INDEX); }
-
- /// Returns reference to the Gemm K coordinate
- ACT_HOST_DEVICE
- Index &k() { return this->At(K_INDEX); }
-
- ACT_HOST_DEVICE
- auto GetCoordMN() const { return this->GetCoordByAxis(); }
-
- ACT_HOST_DEVICE
- auto GetCoordMK() const { return this->GetCoordByAxis(); }
-
- ACT_HOST_DEVICE
- auto GetCoordKN() const { return this->GetCoordByAxis(); }
+ /// Integer-valued index
+ using Index = uint32_t;
+
+ /// Base type is a Coord of rank=3
+ using Base = Coord<3, Index>;
+
+ /// Gemm M dimension - rows of the output C matrix
+ static constexpr int M_INDEX = 0;
+
+ /// Gemm N dimension - columns of the output C matrix
+ static constexpr int N_INDEX = 1;
+
+ /// Gemm K dimension - inner dimension of the Gemm problem
+ static constexpr int K_INDEX = 2;
+
+ /// Default ctor
+ ACT_HOST_DEVICE
+ GemmCoord() {}
+
+ /// Constructs from a Coord<3>
+ ACT_HOST_DEVICE
+ GemmCoord(Coord<3, Index> const &coord) : Base(coord) {}
+
+ /// Helper to construct from the M, N, K values
+ ACT_HOST_DEVICE
+ GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {}
+
+ /// Returns the Gemm M coordinate
+ ACT_HOST_DEVICE
+ Index const &m() const
+ {
+ return this->At(M_INDEX);
+ }
+
+ /// Returns reference to the Gemm M coordinate
+ ACT_HOST_DEVICE
+ Index &m()
+ {
+ return this->At(M_INDEX);
+ }
+
+ /// Returns the Gemm N coordinate
+ ACT_HOST_DEVICE
+ Index const &n() const
+ {
+ return this->At(N_INDEX);
+ }
+
+ /// Returns reference to the Gemm N coordinate
+ ACT_HOST_DEVICE
+ Index &n()
+ {
+ return this->At(N_INDEX);
+ }
+
+ /// Returns the Gemm K coordinate
+ ACT_HOST_DEVICE
+ Index const &k() const
+ {
+ return this->At(K_INDEX);
+ }
+
+ /// Returns reference to the Gemm K coordinate
+ ACT_HOST_DEVICE
+ Index &k()
+ {
+ return this->At(K_INDEX);
+ }
+
+ ACT_HOST_DEVICE
+ auto GetCoordMN() const
+ {
+ return this->GetCoordByAxis();
+ }
+
+ ACT_HOST_DEVICE
+ auto GetCoordMK() const
+ {
+ return this->GetCoordByAxis();
+ }
+
+ ACT_HOST_DEVICE
+ auto GetCoordKN() const
+ {
+ return this->GetCoordByAxis();
+ }
};

-} // namespace Act
+} // namespace Act

-#endif // ACT_GEMM_COORD_HPP
+#endif // ACT_GEMM_COORD_HPP
diff --git a/act/gemv_coord.hpp b/act/gemv_coord.hpp
index 08af1180..2e925c4a 100644
--- a/act/gemv_coord.hpp
+++ b/act/gemv_coord.hpp
@@ -24,66 +24,84 @@ template <
/// Columns of the matrix (number of elements in the input vector)
uint32_t N_ = 1>
struct GemvShape {
- static constexpr uint32_t M = M_;
- static constexpr uint32_t N = N_;
+ static constexpr uint32_t M = M_;
+ static constexpr uint32_t N = N_;

- static constexpr int64_t MN = M * N;
+ static constexpr int64_t MN = M * N;

- static constexpr int64_t COUNT = MN;
+ static constexpr int64_t COUNT = MN;

- /// Returns a Coord object
- ACT_HOST_DEVICE
- static Coord<2> ToCoord() { return MakeCoord(M, N); }
+ /// Returns a Coord object
+ ACT_HOST_DEVICE
+ static Coord<2> ToCoord()
+ {
+ return MakeCoord(M, N);
+ }
};

/// GemvCoord is a structure derived from Coord<2> that specifies a location
/// within the coordinate space of a GEMV problem.
struct GemvCoord : public Coord<2, uint32_t> {
- /// Integer-valued index
- using Index = uint32_t;
-
- /// Base type is a Coord of rank=2
- using Base = Coord<2, Index>;
-
- /// GEMV M dimension - rows of the output vector (y)
- static constexpr int M_INDEX = 0;
-
- /// GEMV N dimension - columns of the matrix (length of the input vector x)
- static constexpr int N_INDEX = 1;
-
- /// Default ctor
- ACT_HOST_DEVICE
- GemvCoord() {}
-
- /// Constructs from Coord<2> and a batch
- ACT_HOST_DEVICE
- GemvCoord(Coord<2, Index> const &coord) : Base(coord) {}
-
- /// Helper to construct from M, N coordinates
- ACT_HOST_DEVICE
- GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {}
-
- /// Returns the GEMV M coordinate (row of the result y)
- ACT_HOST_DEVICE
- Index const &m() const { return this->At(M_INDEX); }
-
- /// Returns reference to the GEMV M coordinate
- ACT_HOST_DEVICE
- Index &m() { return this->At(M_INDEX); }
-
- /// Returns the GEMV N coordinate (column of the matrix A or the input vector
- /// x)
- ACT_HOST_DEVICE
- Index const &n() const { return this->At(N_INDEX); }
-
- /// Returns reference to the GEMV N coordinate
- ACT_HOST_DEVICE
- Index &n() { return this->At(N_INDEX); }
-
- ACT_HOST_DEVICE
- auto GetCoordMN() const { return this->GetCoordByAxis(); }
+ /// Integer-valued index
+ using Index = uint32_t;
+
+ /// Base type is a Coord of rank=2
+ using Base = Coord<2, Index>;
+
+ /// GEMV M dimension - rows of the output vector (y)
+ static constexpr int M_INDEX = 0;
+
+ /// GEMV N dimension - columns of the matrix (length of the input vector x)
+ static constexpr int N_INDEX = 1;
+
+ /// Default ctor
+ ACT_HOST_DEVICE
+ GemvCoord() {}
+
+ /// Constructs from a Coord<2>
+ ACT_HOST_DEVICE
+ GemvCoord(Coord<2, Index> const &coord) : Base(coord) {}
+
+ /// Helper to construct from M, N coordinates
+ ACT_HOST_DEVICE
+ GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {}
+
+ /// Returns the GEMV M coordinate (row of the result y)
+ ACT_HOST_DEVICE
+ Index
const &m() const + { + return this->At(M_INDEX); + } + + /// Returns reference to the GEMV M coordinate + ACT_HOST_DEVICE + Index &m() + { + return this->At(M_INDEX); + } + + /// Returns the GEMV N coordinate (column of the matrix A or the input vector + /// x) + ACT_HOST_DEVICE + Index const &n() const + { + return this->At(N_INDEX); + } + + /// Returns reference to the GEMV N coordinate + ACT_HOST_DEVICE + Index &n() + { + return this->At(N_INDEX); + } + + ACT_HOST_DEVICE + auto GetCoordMN() const + { + return this->GetCoordByAxis(); + } }; -} // namespace Act +} // namespace Act -#endif // ACT_GEMV_COORD_HPP +#endif // ACT_GEMV_COORD_HPP diff --git a/act/layout/layout.hpp b/act/layout/layout.hpp index 981f0d33..5282545e 100644 --- a/act/layout/layout.hpp +++ b/act/layout/layout.hpp @@ -17,4 +17,4 @@ #include "../../act/layout/matrix.hpp" #include "../../act/layout/vector.hpp" -#endif // ACT_LAYOUT_LAYOUT_HPP +#endif // ACT_LAYOUT_LAYOUT_HPP diff --git a/act/layout/matrix.hpp b/act/layout/matrix.hpp index 2035ca9a..be705ce0 100644 --- a/act/layout/matrix.hpp +++ b/act/layout/matrix.hpp @@ -23,599 +23,727 @@ namespace Act::layout { /// Mapping function for row-major matrices struct RowMajor { public: - /// Logical rank of tensor - static constexpr int RANK = 2; + /// Logical rank of tensor + static constexpr int RANK = 2; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - /// Constructor - ACT_HOST_DEVICE - RowMajor(Index rows = 0, Index cols = 0) - : shape_(MakeCoord(rows, cols)), - stride_(MakeCoord(LongIndex(cols), LongIndex(1))) {} - - /// Constructor - ACT_HOST_DEVICE - RowMajor(Index rows, Index cols, LongIndex ldm) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) {} - - /// Ctor - ACT_HOST_DEVICE - RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - template - ACT_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) { - return RowMajor(shape.row(), shape.column(), - RoundUp(shape.column())); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); - } - - /// Returns the layout of a tile. 
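/////////////////////////////////////////////////////////////////////////////
// Editorial sketch (not part of the patch): how the rank-3 GemmCoord above is
// typically used to slice a GEMM problem into operand views. Plain C++17
// stand-ins are assumed; the real Act::Coord / GetCoordByAxis machinery is
// not reproduced here.
#include <array>
#include <cstdint>
#include <cstdio>

struct Coord3 {  // stand-in for Coord<3, uint32_t> with m/n/k accessors
    std::array<uint32_t, 3> at{};
    uint32_t m() const { return at[0]; }  // M_INDEX == 0
    uint32_t n() const { return at[1]; }  // N_INDEX == 1
    uint32_t k() const { return at[2]; }  // K_INDEX == 2
};

int main()
{
    Coord3 problem{{128, 256, 64}};  // C[M,N] += A[M,K] * B[K,N]
    // GetCoordMK()-style slice: the (M, K) view addresses operand A;
    // GetCoordKN() gives the (K, N) view that addresses operand B.
    std::printf("A view: (%u, %u)\n", problem.m(), problem.k());
    std::printf("B view: (%u, %u)\n", problem.k(), problem.n());
    return 0;
}
/////////////////////////////////////////////////////////////////////////////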
- ACT_HOST_DEVICE - RowMajor GetTileLayout(MatrixCoord const &tileShape) const { - return RowMajor(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(cols), LongIndex(1))) + {} + + /// Constructor + ACT_HOST_DEVICE + RowMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) + {} + + /// Ctor + ACT_HOST_DEVICE + RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) + { + return RowMajor(shape.row(), shape.column(), RoundUp(shape.column())); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); + } + + /// Returns the layout of a tile. 
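/////////////////////////////////////////////////////////////////////////////
// Editorial sketch (not part of the patch): the addressing rule that
// RowMajor::GetOffset above implements, restated as standalone C++ so it can
// be checked on the host. The helper name below is illustrative, not an Act
// API.
#include <cassert>
#include <cstdint>

// Offset of element (row, col) when rows are contiguous with leading
// dimension ldm, i.e. stride_ == (ldm, 1).
inline int64_t RowMajorOffset(int64_t row, int64_t col, int64_t ldm)
{
    return row * ldm + col;
}

int main()
{
    // A 4 x 6 logical tile stored inside a buffer padded to ldm == 8,
    // as MakeLayoutInUb would produce by rounding the column count up.
    assert(RowMajorOffset(0, 0, 8) == 0);
    assert(RowMajorOffset(2, 3, 8) == 19);  // 2 * 8 + 3
    return 0;
}
/////////////////////////////////////////////////////////////////////////////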
+ ACT_HOST_DEVICE + RowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return RowMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - // - // Data members - // + // + // Data members + // - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /// Mapping function for col-major matrices struct ColumnMajor { public: - /// Logical rank of tensor - static constexpr int RANK = 2; + /// Logical rank of tensor + static constexpr int RANK = 2; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - // Methods - - /// Constructor - ACT_HOST_DEVICE - ColumnMajor(Index rows = 0, Index cols = 0) - : shape_(MakeCoord(rows, cols)), - stride_(MakeCoord(LongIndex(1), LongIndex(rows))) {} - - /// Constructor - ACT_HOST_DEVICE - ColumnMajor(Index rows, Index cols, LongIndex ldm) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) {} - - /// Ctor - ACT_HOST_DEVICE - ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; - } - - /// Returns the layout of a tile. 
- ACT_HOST_DEVICE - ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const { - return ColumnMajor(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows = 0, Index cols = 0) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), LongIndex(rows))) + {} + + /// Constructor + ACT_HOST_DEVICE + ColumnMajor(Index rows, Index cols, LongIndex ldm) + : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) + {} + + /// Ctor + ACT_HOST_DEVICE + ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; + } + + /// Returns the layout of a tile. 
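/////////////////////////////////////////////////////////////////////////////
// Editorial sketch (not part of the patch): ColumnMajor::GetOffset above is
// the transpose of the row-major rule; columns are contiguous and the leading
// dimension steps between them. Illustrative standalone C++ follows.
#include <cassert>
#include <cstdint>

// Offset of element (row, col) when stride_ == (1, ldm), with ldm >= rows.
inline int64_t ColumnMajorOffset(int64_t row, int64_t col, int64_t ldm)
{
    return row + col * ldm;
}

int main()
{
    // The same (2, 3) element as in the row-major sketch, now with the
    // matrix stored column by column and ldm == 8 rows per column.
    assert(ColumnMajorOffset(2, 3, 8) == 26);  // 2 + 3 * 8
    return 0;
}
/////////////////////////////////////////////////////////////////////////////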
+ ACT_HOST_DEVICE + ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return ColumnMajor(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - // - // Data members - // + // + // Data members + // - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /// Mapping function for nZ matrices which is col-major inside fractal and /// row-major between fractal struct nZ { public: - /// Logical rank of tensor - static constexpr int RANK = 4; + /// Logical rank of tensor + static constexpr int RANK = 4; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; - /// Logical coordinate - using OrgShape = Coord; + /// Logical coordinate + using OrgShape = Coord; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - // Methods - - /// Constructor - ACT_HOST_DEVICE constexpr nZ( - Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = - 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = - 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = - 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = - 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, - colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, - strideColsInFractal, strideColsByFractal)) {} - - /// Ctor - ACT_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) {} - - /// Make the layout of a coordinate (row, column) - template - 
ACT_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) { - constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return nZ(orgRows, orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, - C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, 1, - colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + - LongIndex(coord.column()) / shape_[2] * stride_[3] + - (LongIndex(coord.row()) % shape_[0]) * stride_[0] + - (LongIndex(coord.column()) % shape_[2]) * stride_[2]; - } - - /// Returns the layout of a tile. - ACT_HOST_DEVICE - nZ GetTileLayout(MatrixCoord const &tileOriShape) const { - auto tileShape = - MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), - CeilDiv(tileOriShape.column(), shape(2))); - return nZ(tileOriShape, tileShape, stride()); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr nZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) + : 
orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nZ(orgRows, orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, + colsRound / C0_NUM_PER_FRACTAL, 1, colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + + (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; + } + + /// Returns the layout of a tile. + ACT_HOST_DEVICE + nZ GetTileLayout(MatrixCoord const &tileOriShape) const + { + auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), + CeilDiv(tileOriShape.column(), shape(2))); + return nZ(tileOriShape, tileShape, stride()); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - /// Origin Shape data member - OrgShape orgShape_; + /// Origin Shape data member + OrgShape orgShape_; - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /// Mapping function for zN matrices which is row-major inside fractal and /// col-major between fractal struct zN { public: - /// Logical rank of tensor - static constexpr int RANK = 4; + /// Logical rank of tensor + static constexpr int RANK = 4; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; - /// Logical 
coordinate - using OrgShape = Coord; + /// Logical coordinate + using OrgShape = Coord; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - // Methods - - /// Constructor - ACT_HOST_DEVICE constexpr zN( - Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = - 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = - 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = - 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = - 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, - colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, - strideColsInFractal, strideColsByFractal)) {} - - /// Ctor - ACT_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols) { - constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL, - rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, - colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL, - 1, rowsRound * ELE_NUM_PER_C0); - } - - ACT_HOST_DEVICE - static zN MakeLayoutInL0C(MatrixCoord const &shape) { - return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL, - CeilDiv(shape.row()), C0_NUM_PER_FRACTAL, - CeilDiv(shape.column()), C0_NUM_PER_FRACTAL, - C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1, - RoundUp(shape.row()) * C0_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + - LongIndex(coord.column()) / shape_[2] * stride_[3] + - (LongIndex(coord.row()) % shape_[0]) * stride_[0] + - (LongIndex(coord.column()) % shape_[2]) * stride_[2]; - } - - /// Returns the layout of a tile. 
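/////////////////////////////////////////////////////////////////////////////
// Editorial sketch (not part of the patch): the blocked offset rule shared by
// the fractal layouts (nZ / zN / zZ). GetOffset splits each coordinate into a
// between-block part (coord / blockDim) and an inside-block part
// (coord % blockDim), each weighted by its own stride. Standalone C++, with
// strides taken from zN::MakeLayout for a half-precision element
// (ELE_NUM_PER_C0 == 16, C0_NUM_PER_FRACTAL == 16).
#include <cassert>
#include <cstdint>

inline int64_t FractalOffset(int64_t row, int64_t col,
                             int64_t rowsInBlk, int64_t colsInBlk,
                             int64_t sRowIn, int64_t sRowBy,
                             int64_t sColIn, int64_t sColBy)
{
    return (row / rowsInBlk) * sRowBy + (col / colsInBlk) * sColBy +
           (row % rowsInBlk) * sRowIn + (col % colsInBlk) * sColIn;
}

int main()
{
    // zN over half with rowsRound == 32: 16 x 16 blocks, row-major inside a
    // block (row stride 16, col stride 1), column-major between blocks
    // (block-row stride 256, block-col stride rowsRound * 16 == 512).
    assert(FractalOffset(17, 3, 16, 16, /*sRowIn=*/16, /*sRowBy=*/256,
                         /*sColIn=*/1, /*sColBy=*/512) == 256 + 16 + 3);
    return 0;
}
/////////////////////////////////////////////////////////////////////////////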
- ACT_HOST_DEVICE - zN GetTileLayout(MatrixCoord const &tileOriShape) const { - auto tileShape = - MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), - CeilDiv(tileOriShape.column(), shape(2))); - return zN(tileOriShape, tileShape, stride()); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zN( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL, 1, rowsRound * ELE_NUM_PER_C0); + } + + ACT_HOST_DEVICE + static zN MakeLayoutInL0C(MatrixCoord const &shape) + { + return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL, CeilDiv(shape.row()), + C0_NUM_PER_FRACTAL, CeilDiv(shape.column()), C0_NUM_PER_FRACTAL, + C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1, + RoundUp(shape.row()) * C0_NUM_PER_FRACTAL); + } + + /// 
Returns the offset of a coordinate in linear memory.
+ /// Assumes coordinate has convention (row, column)
+ ACT_HOST_DEVICE
+ LongIndex GetOffset(MatrixCoord const &coord) const
+ {
+ return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] +
+ (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2];
+ }
+
+ /// Returns the layout of a tile.
+ ACT_HOST_DEVICE
+ zN GetTileLayout(MatrixCoord const &tileOriShape) const
+ {
+ auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2),
+ CeilDiv(tileOriShape.column(), shape(2)));
+ return zN(tileOriShape, tileShape, stride());
+ }
+
+ /// Returns the origin shape of the layout
+ ACT_HOST_DEVICE
+ typename OrgShape::Index orgShape(int idx) const
+ {
+ return orgShape_[idx];
+ }
+
+ /// Returns the origin shape of the layout
+ ACT_HOST_DEVICE
+ typename OrgShape::Index &orgShape(int idx)
+ {
+ return orgShape_[idx];
+ }
+
+ /// Returns the shape of the layout
+ ACT_HOST_DEVICE
+ Shape shape() const
+ {
+ return shape_;
+ }
+
+ /// Returns the shape of the layout
+ ACT_HOST_DEVICE
+ Shape &shape()
+ {
+ return shape_;
+ }
+
+ /// Returns the shape of the layout
+ ACT_HOST_DEVICE
+ typename Shape::Index shape(int idx) const
+ {
+ return shape_[idx];
+ }
+
+ /// Returns the shape of the layout
+ ACT_HOST_DEVICE
+ typename Shape::Index &shape(int idx)
+ {
+ return shape_[idx];
+ }
+
+ /// Returns the stride of the layout
+ ACT_HOST_DEVICE
+ Stride stride() const
+ {
+ return stride_;
+ }
+
+ /// Returns the stride of the layout
+ ACT_HOST_DEVICE
+ Stride &stride()
+ {
+ return stride_;
+ }
+
+ /// Returns the stride of the layout
+ ACT_HOST_DEVICE
+ typename Stride::Index stride(int idx) const
+ {
+ return stride_[idx];
+ }
+
+ /// Returns the stride of the layout
+ ACT_HOST_DEVICE
+ typename Stride::Index &stride(int idx)
+ {
+ return stride_[idx];
+ }

private:
- /// Origin Shape data member
- OrgShape orgShape_;
+ /// Origin Shape data member
+ OrgShape orgShape_;

- /// Shape data member
- Shape shape_;
+ /// Shape data member
+ Shape shape_;

- /// Stride data member
- Stride stride_;
+ /// Stride data member
+ Stride stride_;
};

/// Mapping function for zZ matrices which is row-major inside fractal and
/// row-major between fractal
struct zZ {
public:
- /// Logical rank of tensor
- static constexpr int RANK = 4;
+ /// Logical rank of tensor
+ static constexpr int RANK = 4;

- /// Index type used for coordinates
- using Index = uint32_t;
+ /// Index type used for coordinates
+ using Index = uint32_t;

- /// Long index type used for offsets
- using LongIndex = int64_t;
+ /// Long index type used for offsets
+ using LongIndex = int64_t;

- /// Logical rank of orgshape
- static constexpr int ORG_SHAPE_RANK = 2;
+ /// Logical rank of orgshape
+ static constexpr int ORG_SHAPE_RANK = 2;

- /// Logical coordinate
- using OrgShape = Coord;
+ /// Logical coordinate
+ using OrgShape = Coord;

- /// Logical coordinate
- using Shape = Coord;
+ /// Logical coordinate
+ using Shape = Coord;

- /// Stride vector
- using Stride = Coord;
+ /// Stride vector
+ using Stride = Coord;

public:
- // Methods
-
- /// Constructor
- ACT_HOST_DEVICE constexpr zZ(
- Index orgRows = 0, /// Number of rows of origin matrices
- Index orgCols = 0, /// Number of cols of origin matrices
- Index rowsInFractal = 0, /// Number of rows inside the fractal
- Index rowsByFractal = 0, /// number of rows by the fractal
- Index colsInFractal = 0, /// number of cols inside
the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = - 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = - 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = - 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = - 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, - colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, - strideColsInFractal, strideColsByFractal)) {} - - /// Ctor - ACT_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) { - constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, - rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, - colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, - colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + - LongIndex(coord.column()) / shape_[2] * stride_[3]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + /// Constructor + ACT_HOST_DEVICE constexpr zZ( + Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows + LongIndex 
strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) + : orgShape_(orgShape), shape_(shape), stride_(stride) + {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) + { + constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, + colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - /// Origin Shape data member - OrgShape orgShape_; + /// Origin Shape data member + OrgShape orgShape_; - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /// Mapping function for padding rowmajor matrices @@ -624,111 +752,137 @@ struct zZ { /// blocks and also row-major between blocks. 
struct PaddingRowMajor { public: - /// Logical rank of tensor - static constexpr int RANK = 4; + /// Logical rank of tensor + static constexpr int RANK = 4; - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical coordinate - using OrgShape = Coord; + /// Logical coordinate + using OrgShape = Coord; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - /// Constructor - ACT_HOST_DEVICE - PaddingRowMajor(Index orgRows, Index orgCols, Index blockRows, - Index blockCols) - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, - CeilDiv(orgCols, blockCols))), - stride_(MakeCoord( - (LongIndex)blockCols, - (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), - (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols)) {} - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - LongIndex blockRows = (LongIndex)shape_[0]; - LongIndex blockCols = (LongIndex)shape_[2]; - return (LongIndex)coord.row() / blockRows * stride_[1] + - (LongIndex)coord.column() / blockCols * stride_[3] + - (LongIndex)coord.row() % blockRows * stride_[0] + - (LongIndex)coord.column() % blockCols; - } - - ACT_HOST_DEVICE - PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const { - return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], - shape_[2]); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + /// Constructor + ACT_HOST_DEVICE + PaddingRowMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)blockCols, (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), + (LongIndex)1, 
(LongIndex)blockRows * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows * stride_[0] + (LongIndex)coord.column() % blockCols; + } + + ACT_HOST_DEVICE + PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - // - // Data members - // + // + // Data members + // - /// Origin Shape data member - OrgShape orgShape_; + /// Origin Shape data member + OrgShape orgShape_; - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /// Mapping function for padding columnmajor matrices @@ -737,111 +891,137 @@ struct PaddingRowMajor { /// blocks and also column-major between blocks. 
struct PaddingColumnMajor { public: - /// Logical rank of tensor - static constexpr int RANK = 4; + /// Logical rank of tensor + static constexpr int RANK = 4; - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical coordinate - using OrgShape = Coord; + /// Logical coordinate + using OrgShape = Coord; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - /// Constructor - ACT_HOST_DEVICE - PaddingColumnMajor(Index orgRows, Index orgCols, Index blockRows, - Index blockCols) - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, - CeilDiv(orgCols, blockCols))), - stride_(MakeCoord( - (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, - (LongIndex)blockRows, - (LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) {} - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - LongIndex blockRows = (LongIndex)shape_[0]; - LongIndex blockCols = (LongIndex)shape_[2]; - return (LongIndex)coord.row() / blockRows * stride_[1] + - (LongIndex)coord.column() / blockCols * stride_[3] + - (LongIndex)coord.row() % blockRows + - (LongIndex)coord.column() % blockCols * stride_[2]; - } - - ACT_HOST_DEVICE - PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const { - return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], - shape_[2]); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + /// Constructor + ACT_HOST_DEVICE + PaddingColumnMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), + stride_(MakeCoord((LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, (LongIndex)blockRows, + 
(LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) + {} + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + LongIndex blockRows = (LongIndex)shape_[0]; + LongIndex blockCols = (LongIndex)shape_[2]; + return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + + (LongIndex)coord.row() % blockRows + (LongIndex)coord.column() % blockCols * stride_[2]; + } + + ACT_HOST_DEVICE + PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const + { + return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - // - // Data members - // + // + // Data members + // - /// Origin Shape data member - OrgShape orgShape_; + /// Origin Shape data member + OrgShape orgShape_; - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; /////////////////////// @@ -849,134 +1029,156 @@ struct PaddingColumnMajor { // nN layout struct nN { public: - /// Logical rank of tensor - static constexpr int RANK = 4; + /// Logical rank of tensor + static constexpr int RANK = 4; - /// Index type used for coordinates - using Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; + /// Logical rank of orgshape + static constexpr int ORG_SHAPE_RANK = 2; - /// Logical coordinate - using OrgShape = Coord; + /// Logical coordinate + using OrgShape = Coord; - /// Logical coordinate - using Shape = Coord; + /// Logical coordinate + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; public: - // Methods - - /// Constructor - ACT_HOST_DEVICE - nN(Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by 
the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - - LongIndex strideRowsInFractal = - 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = - 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = - 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = - 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, - colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, - strideColsInFractal, strideColsByFractal)) {} - - /// Ctor - ACT_HOST_DEVICE - nN(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = - BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return nN(orgRows, orgCols, - - ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, - colsRound / C0_NUM_PER_FRACTAL, - - 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, - rowsRound * C0_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + - LongIndex(coord.column()) / shape_[2] * stride_[3]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const { return orgShape_[idx]; } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) { return orgShape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + /// Constructor + ACT_HOST_DEVICE + nN(Index orgRows = 0, /// Number of rows of origin matrices + Index orgCols = 0, /// Number of cols of origin matrices + + Index rowsInFractal = 0, /// Number of rows inside the fractal + Index rowsByFractal = 0, /// number of rows by the fractal + Index colsInFractal = 0, /// number of cols inside the fractal + Index colsByFractal = 0, /// number of cols by the fractal + + LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal + LongIndex strideRowsByFractal = 0, 
/// number of elements between adjacent fractal rows + LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal + LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols + : orgShape_(MakeCoord(orgRows, orgCols)), + shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), + stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) + {} + + /// Ctor + ACT_HOST_DEVICE + nN(OrgShape orgShape, Shape shape, Stride stride) : orgShape_(orgShape), shape_(shape), stride_(stride) {} + + /// Make the layout of a coordinate (row, column) + template + ACT_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) + { + static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); + static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); + Index rowsRound = RoundUp(orgRows); + Index colsRound = RoundUp(orgCols); + return nN(orgRows, orgCols, + + ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, + + 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, rowsRound * C0_NUM_PER_FRACTAL); + } + + /// Returns the offset of a coordinate in linear memory. + /// Assumes coordinate has convention (row, column) + ACT_HOST_DEVICE + LongIndex GetOffset(MatrixCoord const &coord) const + { + return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index orgShape(int idx) const + { + return orgShape_[idx]; + } + + /// Returns the origin shape of the layout + ACT_HOST_DEVICE + typename OrgShape::Index &orgShape(int idx) + { + return orgShape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - /// Origin Shape data member - OrgShape orgShape_; + /// Origin Shape data member + OrgShape orgShape_; - /// Shape data member - Shape shape_; + /// Shape data member + Shape shape_; - /// Stride data member - Stride stride_; + /// Stride data member + Stride stride_; }; -} // namespace Act::layout +} // namespace Act::layout -#endif // ACT_LAYOUT_MATRIX_HPP +#endif // ACT_LAYOUT_MATRIX_HPP diff --git a/act/layout/vector.hpp b/act/layout/vector.hpp index 286d0648..8b62f92a 100644 --- a/act/layout/vector.hpp +++ b/act/layout/vector.hpp @@ -20,89 +20,114 @@ namespace Act::layout { struct VectorLayout { public: - /// Logical rank of tensor - static constexpr int RANK = 1; + /// Logical rank of tensor + static constexpr int RANK = 1; - /// Index type used for coordinates - using 
Index = uint32_t; + /// Index type used for coordinates + using Index = uint32_t; - /// Long index type used for offsets - using LongIndex = int64_t; + /// Long index type used for offsets + using LongIndex = int64_t; - /// Shape vector - using Shape = Coord; + /// Shape vector + using Shape = Coord; - /// Stride vector - using Stride = Coord; + /// Stride vector + using Stride = Coord; - /// Logical coordinate - using TensorCoord = Coord; + /// Logical coordinate + using TensorCoord = Coord; public: - // Methods - - ACT_HOST_DEVICE - VectorLayout(Index size = 0) - : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} - - ACT_HOST_DEVICE - VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - template - ACT_HOST_DEVICE static VectorLayout - MakeLayoutInUb(TensorCoord const &tileShape) { - return VectorLayout{RoundUp(tileShape[0])}; - } - - ACT_HOST_DEVICE - LongIndex GetOffset(TensorCoord const &coord) const { - return stride_[0] * coord[0]; - } - - /// Returns the layout of a tile. - ACT_HOST_DEVICE - VectorLayout GetTileLayout(TensorCoord const &tileShape) const { - return VectorLayout(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() { return shape_; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const { return shape_[idx]; } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) { return shape_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() { return stride_; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const { return stride_[idx]; } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) { return stride_[idx]; } + // Methods + + ACT_HOST_DEVICE + VectorLayout(Index size = 0) : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} + + ACT_HOST_DEVICE + VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} + + template + ACT_HOST_DEVICE static VectorLayout MakeLayoutInUb(TensorCoord const &tileShape) + { + return VectorLayout{RoundUp(tileShape[0])}; + } + + ACT_HOST_DEVICE + LongIndex GetOffset(TensorCoord const &coord) const + { + return stride_[0] * coord[0]; + } + + /// Returns the layout of a tile. 
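/// For instance, a minimal sketch (values invented; GetOffset is defined
/// above and GetTileLayout follows just below):
///
///     VectorLayout layout(1024);                      // shape 1024, stride 1
///     int64_t off = layout.GetOffset(MakeCoord(5u));  // off == 5
///     VectorLayout tile = layout.GetTileLayout(MakeCoord(256u));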
+ ACT_HOST_DEVICE + VectorLayout GetTileLayout(TensorCoord const &tileShape) const + { + return VectorLayout(tileShape, stride()); + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape shape() const + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + Shape &shape() + { + return shape_; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index shape(int idx) const + { + return shape_[idx]; + } + + /// Returns the shape of the layout + ACT_HOST_DEVICE + typename Shape::Index &shape(int idx) + { + return shape_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride stride() const + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + Stride &stride() + { + return stride_; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index stride(int idx) const + { + return stride_[idx]; + } + + /// Returns the stride of the layout + ACT_HOST_DEVICE + typename Stride::Index &stride(int idx) + { + return stride_[idx]; + } private: - /// Stride data member - Shape shape_; - Stride stride_; + /// Stride data member + Shape shape_; + Stride stride_; }; -} // namespace Act::layout +} // namespace Act::layout -#endif // ACT_LAYOUT_VECTOR_HPP +#endif // ACT_LAYOUT_VECTOR_HPP diff --git a/act/matrix_coord.hpp b/act/matrix_coord.hpp index bad9545b..a9018db4 100644 --- a/act/matrix_coord.hpp +++ b/act/matrix_coord.hpp @@ -17,82 +17,99 @@ namespace Act { -template struct MatrixShape { - static constexpr uint32_t ROW = ROW_; - static constexpr uint32_t COLUMN = COLUMN_; - - static constexpr int64_t COUNT = ROW * COLUMN; - - ACT_HOST_DEVICE - static Coord<2> ToCoord() { return MakeCoord(ROW, COLUMN); } +template +struct MatrixShape { + static constexpr uint32_t ROW = ROW_; + static constexpr uint32_t COLUMN = COLUMN_; + + static constexpr int64_t COUNT = ROW * COLUMN; + + ACT_HOST_DEVICE + static Coord<2> ToCoord() + { + return MakeCoord(ROW, COLUMN); + } }; /// MatrixCoord wraps Coord<2, uint32_t> to provide a helper for accessing named /// dimensions. Classes expecting a coordinate in the rank=2 index space of a /// matrix should use MatrixCoord. 
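/// For instance (values invented; the accessors below return references, so
/// a coordinate can be updated in place):
///
///     MatrixCoord coord(3, 5);       // (row, column)
///     coord += MatrixCoord(1, 0);    // element-wise: now (4, 5)
///     coord.column() = coord.row();  // now (4, 4)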
struct MatrixCoord : public Coord<2, uint32_t> { - /// Integer-valued index - using Index = uint32_t; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - - /// LongIndex type - using LongIndex = typename Base::LongIndex; - - /// Rows dimension - static constexpr uint32_t ROW_INDEX = 0; - - /// Columns dimension - static constexpr uint32_t COLUMN_INDEX = 1; - - /// Default ctor - ACT_HOST_DEVICE - MatrixCoord() {} - - /// Constructs from Coord<2> - ACT_HOST_DEVICE - MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} - - /// Helper to construct from a row and column - ACT_HOST_DEVICE - MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} - - /// Helper to construct from a row and column, which are LongIndex based - ACT_HOST_DEVICE - MatrixCoord(LongIndex row, LongIndex column) - : Base(MakeCoord(Index(row), Index(column))) {} - - /// Returns the row of the coordinate - ACT_HOST_DEVICE - Index const &row() const { return this->At(ROW_INDEX); } - - /// Returns the row of the coordinate - ACT_HOST_DEVICE - Index &row() { return this->At(ROW_INDEX); } - - /// Returns the column of the coordinate - ACT_HOST_DEVICE - Index const &column() const { return this->At(COLUMN_INDEX); } - - /// Returns the column of the coordinate - ACT_HOST_DEVICE - Index &column() { return this->At(COLUMN_INDEX); } - - /// Element-wise addition - ACT_HOST_DEVICE - MatrixCoord operator+(Base const &b) const { - return MatrixCoord(Base::operator+(b)); - } - - /// In-place addition - ACT_HOST_DEVICE - MatrixCoord &operator+=(Base const &b) { - Base::operator+=(b); - return *this; - } + /// Integer-valued index + using Index = uint32_t; + + /// Base type is a Coord of rank=2 + using Base = Coord<2, Index>; + + /// LongIndex type + using LongIndex = typename Base::LongIndex; + + /// Rows dimension + static constexpr uint32_t ROW_INDEX = 0; + + /// Columns dimension + static constexpr uint32_t COLUMN_INDEX = 1; + + /// Default ctor + ACT_HOST_DEVICE + MatrixCoord() {} + + /// Constructs from Coord<2> + ACT_HOST_DEVICE + MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} + + /// Helper to construct from a row and column + ACT_HOST_DEVICE + MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} + + /// Helper to construct from a row and column, which are LongIndex based + ACT_HOST_DEVICE + MatrixCoord(LongIndex row, LongIndex column) : Base(MakeCoord(Index(row), Index(column))) {} + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index const &row() const + { + return this->At(ROW_INDEX); + } + + /// Returns the row of the coordinate + ACT_HOST_DEVICE + Index &row() + { + return this->At(ROW_INDEX); + } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index const &column() const + { + return this->At(COLUMN_INDEX); + } + + /// Returns the column of the coordinate + ACT_HOST_DEVICE + Index &column() + { + return this->At(COLUMN_INDEX); + } + + /// Element-wise addition + ACT_HOST_DEVICE + MatrixCoord operator+(Base const &b) const + { + return MatrixCoord(Base::operator+(b)); + } + + /// In-place addition + ACT_HOST_DEVICE + MatrixCoord &operator+=(Base const &b) + { + Base::operator+=(b); + return *this; + } }; -} // namespace Act +} // namespace Act #endif From 3e4610f5c44024aa43328a94856547a7a9dce44b Mon Sep 17 00:00:00 2001 From: KanielZhou <36097092+kaniel-outis@users.noreply.github.com> Date: Fri, 26 Sep 2025 17:05:44 +0800 Subject: [PATCH 3/6] Delete act directory --- act/act.hpp | 37 - act/arch/arch.hpp | 54 - 
act/arch/cross_core_sync.hpp | 115 -- act/arch/local_tensor_buffer.hpp | 231 ---- act/arch/resource.hpp | 44 - act/coord.hpp | 311 ----- act/detail/alignment.hpp | 57 - act/detail/callback.hpp | 63 - act/detail/dependent_false.hpp | 22 - act/detail/macros.hpp | 20 - act/detail/tag_to_layout.hpp | 80 -- act/epilogue/block/block_epilogue.hpp | 29 - .../block_epilogue_per_token_dequant.hpp | 763 ----------- act/epilogue/dispatch_policy.hpp | 76 -- act/epilogue/tile/copy_gm_to_ub.hpp | 156 --- act/epilogue/tile/copy_ub_to_gm.hpp | 115 -- .../tile/tile_broadcast_inplace_by_column.hpp | 64 - .../tile/tile_broadcast_inplace_by_row.hpp | 57 - act/epilogue/tile/tile_broadcast_mul.hpp | 122 -- act/epilogue/tile/tile_broadcast_one_blk.hpp | 51 - act/epilogue/tile/tile_cast.hpp | 45 - act/epilogue/tile/tile_copy.hpp | 104 -- act/epilogue/tile/tile_elemwise_add.hpp | 48 - act/epilogue/tile/tile_elemwise_mul.hpp | 47 - act/epilogue/tile/tile_elemwise_muls.hpp | 38 - act/epilogue/tile/tile_swizzle.hpp | 92 -- act/gemm/block/block_mmad.hpp | 57 - ...block_mmad_preload_async_with_callback.hpp | 410 ------ act/gemm/block/block_swizzle.hpp | 243 ---- act/gemm/dispatch_policy.hpp | 88 -- act/gemm/gemm_type.hpp | 29 - act/gemm/helper.hpp | 280 ---- ...per_token_dequant_multistage_workspace.hpp | 362 ----- act/gemm/tile/copy_gm_to_l1.hpp | 798 ----------- act/gemm/tile/copy_gm_to_ub.hpp | 53 - act/gemm/tile/copy_l0c_to_gm.hpp | 219 --- act/gemm/tile/copy_l1_to_l0a.hpp | 392 ------ act/gemm/tile/copy_l1_to_l0b.hpp | 537 -------- act/gemm/tile/copy_ub_to_gm.hpp | 80 -- act/gemm/tile/tile_copy.hpp | 183 --- act/gemm/tile/tile_mmad.hpp | 110 -- act/gemm_coord.hpp | 159 --- act/gemv_coord.hpp | 107 -- act/layout/layout.hpp | 20 - act/layout/matrix.hpp | 1184 ----------------- act/layout/vector.hpp | 133 -- act/matrix_coord.hpp | 115 -- 47 files changed, 8400 deletions(-) delete mode 100644 act/act.hpp delete mode 100644 act/arch/arch.hpp delete mode 100644 act/arch/cross_core_sync.hpp delete mode 100644 act/arch/local_tensor_buffer.hpp delete mode 100644 act/arch/resource.hpp delete mode 100644 act/coord.hpp delete mode 100644 act/detail/alignment.hpp delete mode 100644 act/detail/callback.hpp delete mode 100644 act/detail/dependent_false.hpp delete mode 100644 act/detail/macros.hpp delete mode 100644 act/detail/tag_to_layout.hpp delete mode 100644 act/epilogue/block/block_epilogue.hpp delete mode 100644 act/epilogue/block/block_epilogue_per_token_dequant.hpp delete mode 100644 act/epilogue/dispatch_policy.hpp delete mode 100644 act/epilogue/tile/copy_gm_to_ub.hpp delete mode 100644 act/epilogue/tile/copy_ub_to_gm.hpp delete mode 100644 act/epilogue/tile/tile_broadcast_inplace_by_column.hpp delete mode 100644 act/epilogue/tile/tile_broadcast_inplace_by_row.hpp delete mode 100644 act/epilogue/tile/tile_broadcast_mul.hpp delete mode 100644 act/epilogue/tile/tile_broadcast_one_blk.hpp delete mode 100644 act/epilogue/tile/tile_cast.hpp delete mode 100644 act/epilogue/tile/tile_copy.hpp delete mode 100644 act/epilogue/tile/tile_elemwise_add.hpp delete mode 100644 act/epilogue/tile/tile_elemwise_mul.hpp delete mode 100644 act/epilogue/tile/tile_elemwise_muls.hpp delete mode 100644 act/epilogue/tile/tile_swizzle.hpp delete mode 100644 act/gemm/block/block_mmad.hpp delete mode 100644 act/gemm/block/block_mmad_preload_async_with_callback.hpp delete mode 100644 act/gemm/block/block_swizzle.hpp delete mode 100644 act/gemm/dispatch_policy.hpp delete mode 100644 act/gemm/gemm_type.hpp delete mode 100644 act/gemm/helper.hpp 
delete mode 100644 act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp delete mode 100644 act/gemm/tile/copy_gm_to_l1.hpp delete mode 100644 act/gemm/tile/copy_gm_to_ub.hpp delete mode 100644 act/gemm/tile/copy_l0c_to_gm.hpp delete mode 100644 act/gemm/tile/copy_l1_to_l0a.hpp delete mode 100644 act/gemm/tile/copy_l1_to_l0b.hpp delete mode 100644 act/gemm/tile/copy_ub_to_gm.hpp delete mode 100644 act/gemm/tile/tile_copy.hpp delete mode 100644 act/gemm/tile/tile_mmad.hpp delete mode 100644 act/gemm_coord.hpp delete mode 100644 act/gemv_coord.hpp delete mode 100644 act/layout/layout.hpp delete mode 100644 act/layout/matrix.hpp delete mode 100644 act/layout/vector.hpp delete mode 100644 act/matrix_coord.hpp diff --git a/act/act.hpp b/act/act.hpp deleted file mode 100644 index 2e5fab8b..00000000 --- a/act/act.hpp +++ /dev/null @@ -1,37 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_ACT_HPP -#define ACT_ACT_HPP - -#include - -#include "../act/detail/alignment.hpp" -#include "../act/detail/dependent_false.hpp" -#include "../act/detail/macros.hpp" - -namespace Act { - -constexpr uint32_t BYTE_PER_C0 = 32; -constexpr uint32_t C0_NUM_PER_FRACTAL = 16; -constexpr uint32_t BYTE_PER_FRACTAL = BYTE_PER_C0 * C0_NUM_PER_FRACTAL; - -constexpr uint32_t BYTE_PER_BLK = 32; -constexpr uint32_t BLK_NUM_PER_VECTOR_FRACTAL = 8; -constexpr uint32_t BYTE_PER_VECTOR_FRACTAL = BYTE_PER_BLK * BLK_NUM_PER_VECTOR_FRACTAL; - -constexpr uint64_t L2_OFFSET = 0; -constexpr uint32_t STRIDE_LIMIT = 65536; - -} // namespace Act - -#endif // ACT_ACT_HPP diff --git a/act/arch/arch.hpp b/act/arch/arch.hpp deleted file mode 100644 index f1bb8727..00000000 --- a/act/arch/arch.hpp +++ /dev/null @@ -1,54 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_ARCH_ARCH_HPP -#define ACT_ARCH_ARCH_HPP - -namespace Act::Arch { - -struct AtlasA2 { - static constexpr uint32_t BIAS_SIZE = 1024; - static constexpr uint32_t FIXBUF_SIZE = 7 * 1024; - static constexpr uint32_t UB_SIZE = 192 * 1024; - static constexpr uint32_t L1_SIZE = 512 * 1024; - static constexpr uint32_t L0A_SIZE = 64 * 1024; - static constexpr uint32_t L0B_SIZE = 64 * 1024; - static constexpr uint32_t L0C_SIZE = 128 * 1024; -}; - -struct PositionGM { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::GM; -}; - -struct PositionL1 { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A1; -}; - -struct PositionL0A { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::A2; -}; - -struct PositionL0B { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::B2; -}; - -struct PositionL0C { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::CO1; -}; - -struct PositionUB { - static constexpr AscendC::TPosition POSITION = AscendC::TPosition::VECCALC; -}; - -} // namespace Act::Arch - -#endif // ACT_ARCH_ARCH_HPP diff --git a/act/arch/cross_core_sync.hpp b/act/arch/cross_core_sync.hpp deleted file mode 100644 index 72099c4e..00000000 --- a/act/arch/cross_core_sync.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_ARCH_CROSS_CORE_SYNC_HPP -#define ACT_ARCH_CROSS_CORE_SYNC_HPP - -#include "../../act/act.hpp" - -namespace Act::Arch { - -constexpr uint32_t MAX_REVERSE_DEPTH = 16; - -using FlagID = uint16_t; -constexpr FlagID AIV_INTER_BLOCK_BARRIER = 8; -constexpr FlagID AIC_INTER_BLOCK_BARRIER = 9; -constexpr FlagID AIV_INTER_SUBBLOCK_BARRIER = 10; -constexpr FlagID FFTS_MAX_FLAG = 7; - -struct CrossCoreFlag { - ACT_DEVICE - CrossCoreFlag() : id(0) {} - - ACT_DEVICE - CrossCoreFlag(FlagID id) : id(id) {} - - FlagID id; -}; - -template <uint32_t REVERSE_DEPTH = MAX_REVERSE_DEPTH> -struct CrossCoreFlagWithReverse { - ACT_DEVICE - CrossCoreFlagWithReverse() : id(0), reverseId(0) {} - - ACT_DEVICE - CrossCoreFlagWithReverse(FlagID id, FlagID reverseId) : id(id), reverseId(reverseId) {} - - FlagID id; - FlagID reverseId; - uint32_t count{0}; -}; - -template <uint8_t MODE, class CoreType> -struct BarrierFlag { - static_assert(MODE != MODE, - "Unsupported cross core barrier flag, can not " - "find the specialization."); -}; - -template <> -struct BarrierFlag<0x0, AscendC::AIV> { - static constexpr FlagID ID = AIV_INTER_BLOCK_BARRIER; -}; - -template <> -struct BarrierFlag<0x0, AscendC::AIC> { - static constexpr FlagID ID = AIC_INTER_BLOCK_BARRIER; -}; - -template <> -struct BarrierFlag<0x1, AscendC::AIV> { - static constexpr FlagID ID = AIV_INTER_SUBBLOCK_BARRIER; -}; - -template <uint8_t MODE, pipe_t PIPE, class CoreType> -ACT_DEVICE void CrossCoreBarrier() -{ - constexpr FlagID flagId = BarrierFlag<MODE, CoreType>::ID; - AscendC::CrossCoreSetFlag<MODE, PIPE>(flagId); - AscendC::CrossCoreWaitFlag(flagId); -} - -template <uint8_t MODE, pipe_t PIPE> -ACT_DEVICE void CrossCoreSetFlag(CrossCoreFlag &flag) -{ - AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.id); -} - -ACT_DEVICE -void CrossCoreWaitFlag(CrossCoreFlag &flag) -{ - AscendC::CrossCoreWaitFlag(flag.id); -} - -template <uint8_t MODE, pipe_t PIPE, uint32_t REVERSE_DEPTH> -ACT_DEVICE void CrossCoreSetFlagWithReverse(CrossCoreFlagWithReverse<REVERSE_DEPTH> &flag) -{ - AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.id); - if (++flag.count >= REVERSE_DEPTH) { - AscendC::CrossCoreWaitFlag(flag.reverseId); - flag.count = 0; - } -} - -template <uint8_t MODE, pipe_t PIPE, uint32_t REVERSE_DEPTH> -ACT_DEVICE void CrossCoreWaitFlagWithReverse(CrossCoreFlagWithReverse<REVERSE_DEPTH> &flag) -{ - AscendC::CrossCoreWaitFlag(flag.id); - if (++flag.count >= REVERSE_DEPTH) { - AscendC::CrossCoreSetFlag<MODE, PIPE>(flag.reverseId); - flag.count = 0; - } -} - -} // namespace Act::Arch - -#endif // ACT_ARCH_CROSS_CORE_SYNC_HPP
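/*
 * A minimal sketch of the intended cross-core handshake. The MODE and PIPE
 * arguments below (0x2, PIPE_FIX) are illustrative assumptions, not values
 * prescribed by this file:
 *
 *     Act::Arch::CrossCoreFlag flag{0};
 *     Act::Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(flag);  // producer core
 *     Act::Arch::CrossCoreWaitFlag(flag);                // consumer core
 *
 * The WithReverse variants add back-pressure: after REVERSE_DEPTH sets, the
 * producer waits on the reverse flag until the consumer has caught up.
 */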
diff --git a/act/arch/local_tensor_buffer.hpp b/act/arch/local_tensor_buffer.hpp deleted file mode 100644 index 5208153f..00000000 --- a/act/arch/local_tensor_buffer.hpp +++ /dev/null @@ -1,231 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef INCLUDE_ACT_ARCH_MEMORY_H -#define INCLUDE_ACT_ARCH_MEMORY_H - -#include "../../act/act.hpp" -#include "../../act/arch/arch.hpp" - -namespace Act::Arch { - -struct LocalTensorBufferBase { -public: - template <class T> - ACT_DEVICE AscendC::LocalTensor<T> GetBufferByByte(const uint32_t offset) const - { - return tensor[offset].template ReinterpretCast<T>(); - } - -protected: - ACT_DEVICE - LocalTensorBufferBase() = default; - - AscendC::LocalTensor<uint8_t> tensor; -}; - -template <class ArchTag, AscendC::TPosition POSITION> -struct LocalTensorBuffer { - static_assert(DEPENDENT_FALSE<ArchTag>, "Unsupported local tensor buffer, can not find the specialization."); -}; - -/// Partial specialization for TPosition::A1 -template <class ArchTag> -struct LocalTensorBuffer<ArchTag, AscendC::TPosition::A1> : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::A1; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufA1; - GetTPipePtr()->InitBuffer(tbufA1, ArchTag::L1_SIZE); - tensor = tbufA1.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::A2 -template <class ArchTag> -struct LocalTensorBuffer<ArchTag, AscendC::TPosition::A2> : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::A2; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufA2; - GetTPipePtr()->InitBuffer(tbufA2, ArchTag::L0A_SIZE); - tensor = tbufA2.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::B1 -template <class ArchTag> -struct LocalTensorBuffer<ArchTag, AscendC::TPosition::B1> : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::B1; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufB1; - GetTPipePtr()->InitBuffer(tbufB1, ArchTag::L1_SIZE); - tensor = tbufB1.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::B2 -template <class ArchTag> -struct LocalTensorBuffer<ArchTag, AscendC::TPosition::B2> : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::B2; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufB2; - GetTPipePtr()->InitBuffer(tbufB2, ArchTag::L0B_SIZE); - tensor = tbufB2.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for AtlasA2, TPosition::C1 -template <> -struct LocalTensorBuffer<Arch::AtlasA2, AscendC::TPosition::C1> : LocalTensorBufferBase { -public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C1; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufC1; - GetTPipePtr()->InitBuffer(tbufC1, ArchTag::L1_SIZE); - tensor = tbufC1.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for AtlasA2, TPosition::C2 -template <> -struct LocalTensorBuffer<Arch::AtlasA2, AscendC::TPosition::C2> : LocalTensorBufferBase { -public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C2; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufC2; - GetTPipePtr()->InitBuffer(tbufC2, ArchTag::BIAS_SIZE); - tensor = tbufC2.Get<uint8_t>(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::CO1 -template <class ArchTag> -struct LocalTensorBuffer<ArchTag, AscendC::TPosition::CO1> : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::CO1; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf<Position> tbufCO1; - GetTPipePtr()->InitBuffer(tbufCO1, ArchTag::L0C_SIZE); - tensor = tbufCO1.Get<uint8_t>(); - } -}; -
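/*
 * Illustrative sketch of how these wrappers are consumed (element type and
 * byte offset are made up; GetBufferByByte comes from LocalTensorBufferBase
 * above, and the Resource aggregate further below bundles one buffer per
 * on-chip memory):
 *
 *     Act::Arch::LocalTensorBuffer<Act::Arch::AtlasA2, AscendC::TPosition::CO1> l0cBuf;
 *     AscendC::LocalTensor<float> acc = l0cBuf.GetBufferByByte<float>(32 * 1024);
 */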
-/////////////////////////////////////////////////////////// - -/// Partial specialization for AtlasA2, TPosition::C2PIPE2GM -template <> -struct LocalTensorBuffer : LocalTensorBufferBase { -public: - using ArchTag = Arch::AtlasA2; - static constexpr AscendC::TPosition Position = AscendC::TPosition::C2PIPE2GM; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf tbufC2PIPE2GM; - GetTPipePtr()->InitBuffer(tbufC2PIPE2GM, ArchTag::FIXBUF_SIZE); - tensor = tbufC2PIPE2GM.Get(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::VECIN -template -struct LocalTensorBuffer : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECIN; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf tbufVECIN; - GetTPipePtr()->InitBuffer(tbufVECIN, ArchTag::UB_SIZE); - tensor = tbufVECIN.Get(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::VECOUT -template -struct LocalTensorBuffer : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECOUT; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf tbufVECOUT; - GetTPipePtr()->InitBuffer(tbufVECOUT, ArchTag::UB_SIZE); - tensor = tbufVECOUT.Get(); - } -}; - -/////////////////////////////////////////////////////////// - -/// Partial specialization for TPosition::VECCALC -template -struct LocalTensorBuffer : LocalTensorBufferBase { -public: - static constexpr AscendC::TPosition Position = AscendC::TPosition::VECCALC; - - ACT_DEVICE - LocalTensorBuffer() - { - AscendC::TBuf tbufVECCALC; - GetTPipePtr()->InitBuffer(tbufVECCALC, ArchTag::UB_SIZE); - tensor = tbufVECCALC.Get(); - } -}; - -} // namespace Act::Arch - -#endif // INCLUDE_ACT_ARCH_MEMORY_H diff --git a/act/arch/resource.hpp b/act/arch/resource.hpp deleted file mode 100644 index 71367981..00000000 --- a/act/arch/resource.hpp +++ /dev/null @@ -1,44 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef INCLUDE_ACT_ARCH_RESOURCE_HPP -#define INCLUDE_ACT_ARCH_RESOURCE_HPP - -#include "../../act/act.hpp" -#include "../../act/arch/local_tensor_buffer.hpp" - -namespace Act::Arch { - -template -struct Resource { -public: - AscendC::TPipe pipe; - - LocalTensorBuffer l1Buf; - LocalTensorBuffer l0ABuf; - LocalTensorBuffer l0BBuf; - LocalTensorBuffer l0CBuf; - LocalTensorBuffer ubBuf; - - ACT_DEVICE - Resource() - { - // The initialization of AscendC::Tpipe will insert some synchronization - // interfaces, which may conflict with the usage by users. Therefore, the - // "destroy" interface is used for releasing. - pipe.Destroy(); - } -}; - -} // namespace Act::Arch - -#endif // INCLUDE_ACT_ARCH_RESOURCE_HPP diff --git a/act/coord.hpp b/act/coord.hpp deleted file mode 100644 index 5faf5be6..00000000 --- a/act/coord.hpp +++ /dev/null @@ -1,311 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. 
- * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_COORD_HPP -#define ACT_COORD_HPP - -#include "../act/act.hpp" - -namespace Act { - -/// Statically-sized array specifying Coords within a tensor -template <int RANK_, class Index_ = uint32_t, class LongIndex_ = int64_t> -struct Coord { -public: - // Number of elements in Coord - static const int RANK = RANK_; - - // Index type used to store elements - using Index = Index_; - - // Type used to represent linear offsets - using LongIndex = LongIndex_; - - // Default ctor initializes uniformly - ACT_HOST_DEVICE constexpr explicit Coord(Index value = Index(0)) - { - for (int i = 0; i < RANK; ++i) { - idx[i] = value; - } - } - - // Constructs from an array of integers - ACT_HOST_DEVICE constexpr Coord(Index const (&idx_)[RANK]) - { - for (int i = 0; i < RANK; ++i) { - idx[i] = idx_[i]; - } - } - - // Returns the index of the dimension with least value - ACT_HOST_DEVICE - int Argmin() const - { - int i = 0; - for (int j = 1; j < RANK; ++j) { - if (idx[j] < idx[i]) { - i = j; - } - } - return i; - } - - // Returns the index of the dimension with greatest value - ACT_HOST_DEVICE - int Argmax() const - { - int i = 0; - for (int j = 1; j < RANK; ++j) { - if (idx[j] > idx[i]) { - i = j; - } - } - return i; - } - - // Returns true if Coord is non-zero - ACT_HOST_DEVICE - explicit operator bool() const - { - for (int i = 0; i < RANK; ++i) { - if (idx[i]) { - return true; - } - } - return false; - } - - // Returns true if Coord is uniformly zero.
- ACT_HOST_DEVICE - bool operator!() const - { - for (int i = 0; i < RANK; ++i) { - if (idx[i]) { - return false; - } - } - return true; - } - - // Element-wise addition - ACT_HOST_DEVICE - Coord operator+(Coord const &b) const - { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] + b.idx[i]; - } - return c; - } - - // Add a scalar to each element - ACT_HOST_DEVICE - Coord operator+(const Index val) const - { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] + val; - } - return c; - } - - // Element-wise subtraction - ACT_HOST_DEVICE - Coord operator-(Coord const &b) const - { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] - b.idx[i]; - } - return c; - } - - // Subtract a scalar from each element - ACT_HOST_DEVICE - Coord operator-(Index const val) const - { - Coord c; - for (int i = 0; i < RANK; ++i) { - c.idx[i] = idx[i] - val; - } - return c; - } - - // Element-wise multiply - ACT_HOST_DEVICE - Coord operator*(Coord const &b) const - { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] * b.idx[i]; - } - return c; - } - - // Element-wise division - ACT_HOST_DEVICE - Coord operator/(Coord const &b) const - { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] / b.idx[i]; - } - return c; - } - - // Element-wise mod - ACT_HOST_DEVICE - Coord operator%(Coord const &b) const - { - Coord c; - for (int i = 0; i < RANK; i++) { - c.idx[i] = idx[i] % b.idx[i]; - } - return c; - } - - // In-place addition - ACT_HOST_DEVICE - Coord &operator+=(Coord const &b) - { - for (int i = 0; i < RANK; ++i) { - idx[i] += b.idx[i]; - } - return *this; - } - - // Equality comparison - ACT_HOST_DEVICE - bool operator==(Coord const &b) const - { - for (int i = 0; i < RANK; ++i) { - if (idx[i] != b.idx[i]) { - return false; - } - } - return true; - } - - // Equality comparison with a scalar - ACT_HOST_DEVICE - bool operator==(Index const val) const - { - for (int i = 0; i < RANK; ++i) { - if (idx[i] != val) { - return false; - } - } - return true; - } - - // Member access operator - ACT_HOST_DEVICE - Index &operator[](int dim) - { - return idx[dim]; - } - - // Member access operator - ACT_HOST_DEVICE - Index const &operator[](int dim) const - { - return idx[dim]; - } - - // Gets the index of a given Coord element - template <int DIM> - ACT_HOST_DEVICE Index &At() - { - return idx[DIM]; - } - - // Access via index; may limit unrolling potential - ACT_HOST_DEVICE - Index &At(int dim) - { - return idx[dim]; - } - - // Gets the index of a given Coord element - template <int DIM> - ACT_HOST_DEVICE Index const &At() const - { - return idx[DIM]; - } - - // Access via index; may limit unrolling potential - ACT_HOST_DEVICE - Index const &At(int dim) const - { - return idx[dim]; - } - - template <size_t... Is> - ACT_HOST_DEVICE auto GetCoordByAxis() const - { - Index idx_[sizeof...(Is)]{idx[Is]...}; - return Coord{idx_}; - } - - ACT_HOST_DEVICE - static Coord Min(Coord const &a, Coord const &b) - { - Coord res; - for (int i = 0; i < RANK; ++i) { - res[i] = a[i] < b[i] ? a[i] : b[i]; - } - return res; - } - -private: - // Indices - Index idx[RANK]; -}; - -/// Helper to make a 1-element coordinate -template <typename T> -ACT_HOST_DEVICE constexpr Coord<1, T> MakeCoord(T dim0) -{ - T values[1] = {dim0}; - return Coord<1, T>(values); -} - -/// Helper to make a 2-element coordinate -template <typename T> -ACT_HOST_DEVICE constexpr Coord<2, T> MakeCoord(T dim0, T dim1) -{ - T values[2] = {dim0, dim1}; - return Coord<2, T>(values); -} - -/// Helper to make a 3-element coordinate -template <typename T> -ACT_HOST_DEVICE constexpr Coord<3, T> MakeCoord(T dim0, T dim1, T dim2) -{ - T values[3] = {dim0, dim1, dim2}; - return Coord<3, T>(values); -} - -/// Helper to make a 4-element coordinate -template <typename T> -ACT_HOST_DEVICE constexpr Coord<4, T> MakeCoord(T dim0, T dim1, T dim2, T dim3) -{ - T values[4] = {dim0, dim1, dim2, dim3}; - return Coord<4, T>(values); -} - -} // namespace Act - -#endif // ACT_COORD_HPP
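/*
 * A small usage sketch of Coord and MakeCoord (all values invented):
 *
 *     auto shape = Act::MakeCoord(4u, 8u);    // Coord<2, uint32_t>
 *     auto tile = Act::MakeCoord(2u, 2u);
 *     auto loops = shape / tile;              // element-wise: (2, 4)
 *     int slowest = loops.Argmax();           // 1, the column dimension
 *     auto rows = shape.GetCoordByAxis<0>();  // Coord<1, uint32_t>
 */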
diff --git a/act/detail/alignment.hpp b/act/detail/alignment.hpp deleted file mode 100644 index db40e7ba..00000000 --- a/act/detail/alignment.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_ALIGNMENT_HPP -#define ACT_ALIGNMENT_HPP - -#include "../../act/detail/macros.hpp" - -template <uint32_t ALIGN, typename T> -ACT_HOST_DEVICE constexpr T RoundUp(const T &val) -{ - static_assert(ALIGN != 0, "ALIGN must not be 0"); - return (val + ALIGN - 1) / ALIGN * ALIGN; -} - -template <typename T> -ACT_HOST_DEVICE constexpr T RoundUp(const T &val, const T align) -{ - return (val + align - 1) / align * align; -} - -template <uint32_t ALIGN, typename T> -ACT_HOST_DEVICE constexpr T RoundDown(const T val) -{ - static_assert(ALIGN != 0, "ALIGN must not be 0"); - return val / ALIGN * ALIGN; -} - -template <typename T> -ACT_HOST_DEVICE constexpr T RoundDown(const T val, const T align) -{ - return val / align * align; -} - -template <uint32_t DIVISOR, typename T> -ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend) -{ - static_assert(DIVISOR != 0, "DIVISOR must not be 0"); - return (dividend + DIVISOR - 1) / DIVISOR; -} - -template <typename T> -ACT_HOST_DEVICE constexpr T CeilDiv(const T dividend, const T divisor) -{ - return (dividend + divisor - 1) / divisor; -} - -#endif // ACT_ALIGNMENT_HPP
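/*
 * A tiny worked example of the helpers above (numbers invented): with an
 * alignment of 16, RoundUp<16>(100u) == 112, RoundDown<16>(100u) == 96, and
 * CeilDiv<16>(100u) == 7; the runtime overloads compute the same results
 * from a non-constant alignment, e.g. CeilDiv(100u, 16u) == 7.
 */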
diff --git a/act/detail/callback.hpp b/act/detail/callback.hpp deleted file mode 100644 index 7475213c..00000000 --- a/act/detail/callback.hpp +++ /dev/null @@ -1,63 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_DETAIL_CALLBACK_HPP -#define ACT_DETAIL_CALLBACK_HPP - -#include "../../act/detail/macros.hpp" - -/// @brief Callback is an alternative to std::function, providing a -/// general carrier for a callable structure with no parameters and no return -/// value. Compared with function pointers of type void (*)(), Callback can -/// carry lambda expressions with captures, regardless of what is captured. -/// Note that, unlike std::function, Callback does not store the callable -/// structure it carries, so it must only be used within the lifetime of that -/// callable structure. -struct Callback { - void const *func{nullptr}; - void (*caller)(void const *){nullptr}; - - Callback() = default; - - ACT_DEVICE - void operator()() const - { - if (func) { - caller(func); - } - } - - ACT_DEVICE - operator bool() const - { - return func != nullptr; - } -}; - -template <class Func> -ACT_DEVICE static void FuncWrapper(void const *func) -{ - (*static_cast<Func const *>(func))(); -} - -// Use this to make a callback -template <class Func> -ACT_DEVICE Callback MakeCallback(Func *func) -{ - Callback callback; - callback.func = func; - callback.caller = &FuncWrapper<Func>; - return callback; -} - -#endif // ACT_DETAIL_CALLBACK_HPP diff --git a/act/detail/dependent_false.hpp b/act/detail/dependent_false.hpp deleted file mode 100644 index c9985a05..00000000 --- a/act/detail/dependent_false.hpp +++ /dev/null @@ -1,22 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_DETAIL_DEPENDENT_FALSE_HPP -#define ACT_DETAIL_DEPENDENT_FALSE_HPP - -template <bool VALUE, class... Args> -constexpr bool DEPENDENT_BOOL_VALUE = VALUE; - -template <class... Args> -constexpr bool DEPENDENT_FALSE = DEPENDENT_BOOL_VALUE<false, Args...>; - -#endif // ACT_DETAIL_DEPENDENT_FALSE_HPP diff --git a/act/detail/macros.hpp b/act/detail/macros.hpp deleted file mode 100644 index a2825344..00000000 --- a/act/detail/macros.hpp +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_DETAIL_MACROS_HPP -#define ACT_DETAIL_MACROS_HPP - -#define ACT_DEVICE __forceinline__[aicore] -#define ACT_HOST_DEVICE __forceinline__[host, aicore] -#define ACT_GLOBAL __global__[aicore] - -#endif // ACT_DETAIL_MACROS_HPP diff --git a/act/detail/tag_to_layout.hpp b/act/detail/tag_to_layout.hpp deleted file mode 100644 index 033a4ee4..00000000 --- a/act/detail/tag_to_layout.hpp +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_DETAIL_TAG_TO_LAYOUT_HPP -#define ACT_DETAIL_TAG_TO_LAYOUT_HPP - -#include "../../act/layout/layout.hpp" -#include "../../tla/layout.hpp" - -using namespace tla; -//////////////////////////////////////////////////////////////////////////////////////////////////// - -namespace Act::detail { -//////////////////////////////////////////////////////////////////////////////////////////////////// -// For each Act::layout, provides its corresponding tla layout types -template -struct TagToLayout { - using type = LayoutTag; -}; - -template -struct TagToLayout { - using type = Layout, Stride>, Shape>; -}; - -template -struct TagToLayout { - using type = Layout, Stride, int64_t>, Shape>; -}; - -template -struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - using type = Layout, uint32_t>, Shape, uint32_t>>, - Stride, Int>, Stride, int64_t>>, - Shape>; -}; - -template -struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - using type = Layout, uint32_t>, Shape, uint32_t>>, - Stride, int64_t>, Stride, Int>>, - Shape>; -}; - -template -struct TagToLayout { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - using type = Layout, uint32_t>, Shape, uint32_t>>, - Stride, int64_t>, Stride, Int>>, - Shape>; -}; - -// Convenience aliases -template -using TagToLayout_t = typename TagToLayout::type; - -constexpr uint32_t ELE_NUM_PER_FRACTAL_L0C = 256; -using LayoutL0C = Layout, uint32_t>, Shape, uint32_t>>, - Stride, Int>, Stride, int64_t>>, - Shape>; - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace Act::detail - -#endif // ACT_DETAIL_TAG_TO_LAYOUT_HPP diff --git a/act/epilogue/block/block_epilogue.hpp b/act/epilogue/block/block_epilogue.hpp deleted file mode 100644 index bb7a6ac6..00000000 --- a/act/epilogue/block/block_epilogue.hpp +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. 
You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP -#define ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Block { - -template -class BlockEpilogue -{ - static_assert(DEPENDENT_FALSE, "Could not find an epilogue specialization"); -}; - -} // namespace Act::Epilogue::Block - -#include "../../../act/epilogue/block/block_epilogue_per_token_dequant.hpp" -#endif // ACT_EPILOGUE_BLOCK_BLOCK_EPILOGUE_HPP diff --git a/act/epilogue/block/block_epilogue_per_token_dequant.hpp b/act/epilogue/block/block_epilogue_per_token_dequant.hpp deleted file mode 100644 index dee41a8c..00000000 --- a/act/epilogue/block/block_epilogue_per_token_dequant.hpp +++ /dev/null @@ -1,763 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP -#define ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP - -#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" -#include "../../../act/act.hpp" -#include "../../../act/arch/resource.hpp" -#include "../../../act/detail/callback.hpp" -#include "../../../act/epilogue/dispatch_policy.hpp" -#include "../../../act/gemm_coord.hpp" -#include "../../../act/layout/layout.hpp" -#include "../../../act/matrix_coord.hpp" - -#define ENABLE_EP_SEND_COUNT_HASH 0 - -namespace Act::Epilogue::Block { - -template -class BlockEpilogue, CType_, ScaleType_, PerTokenScaleType_, - DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, TileOneBlkColumnBroadcastMul_, TileCopy_, - EpilogueTileSwizzle_> -{ -public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = typename ScaleType_::Element; - using LayoutScale = typename ScaleType_::Layout; - using ElementPerTokenScale = typename PerTokenScaleType_::Element; - using LayoutPerTokenScale = typename PerTokenScaleType_::Layout; - using ElementD = typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert(std::is_same_v && - (std::is_same_v || std::is_same_v) && - std::is_same_v && std::is_same_v, - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - 
using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COLUMN + TileShape::COUNT + TileShape::ROW) * sizeof(float) + - TileShape::ROW * BYTE_PER_BLK) <= ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; - - ACT_DEVICE - Params() {}; - - ACT_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), - layoutScale(layoutScale_), - ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(ptrD_), - layoutD(layoutD_) - {} - }; - - ACT_DEVICE - BlockEpilogue(Arch::Resource const &resource, Params const ¶ms = Params{}) : params(params) - { - size_t ubOffset = 0; - int32_t eventVMTE2 = 0; - int32_t eventMTE2V = 0; - int32_t eventMTE3V = 0; - int32_t eventVMTE3 = 0; - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); - } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(float); - ubPerTokenScaleFp32Brcb = 
resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubMul; - } - - ACT_DEVICE - ~BlockEpilogue() - { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); - } - } - - ACT_DEVICE - void UpdateParams(Params const ¶ms_) - { - params = params_; - } - - ACT_DEVICE - void operator()(GemmCoord const &blockShapeMNK, GemmCoord const &blockCoordMNK, - GemmCoord const &actualBlockShapeMNK, AscendC::GlobalTensor const &gmBlockC, - LayoutC const &layoutBlockC, Callback &&callback = Callback{}) - { - if (actualBlockShapeMNK.k() == 0) { - return; - } - callback(); - - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); - - 
AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, - layoutGmTilePerTokenScale); - AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - AscendC::Cast(ubScaleFp32, ubScale, AscendC::RoundMode::CAST_NONE, TileShape::COLUMN); - AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - AscendC::Cast(ubPerTokenScaleFp32, ubPerTokenScale, AscendC::RoundMode::CAST_NONE, TileShape::ROW); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - - tileRowBroadcastMul(ubMul, ubCFp32, ubScaleFp32); - tileBroadcastOneBlk(ubPerTokenScaleFp32Brcb, ubPerTokenScaleFp32); - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleFp32Brcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); - AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto gmTileD = gmD[params.layoutD.GetOffset(tileOffset)]; - auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? (ubListId + 1) : 0; - } - } - -private: - Params params; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - uint32_t ubListId{0}; - - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubScaleFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleFp32; - AscendC::LocalTensor ubPerTokenScaleFp32Brcb; - AscendC::LocalTensor ubPerTokenMul; - - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; -}; - -template -class BlockEpilogue, CType_, Gemm::GemmType, - Gemm::GemmType, DType_, TileRowBroadcastMul_, TileBroadcastOneBlk_, - TileOneBlkColumnBroadcastMul_, TileCopy_, EpilogueTileSwizzle_> -{ -public: - using DispatchPolicy = EpilogueAtlasA2PerTokenDequant; - using ArchTag = typename DispatchPolicy::ArchTag; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; - - // Data infos - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using ElementScale = float; - using LayoutScale = LayoutScale_; - using ElementPerTokenScale = float; - using LayoutPerTokenScale = LayoutPerTokenScale_; - using ElementD 
= typename DType_::Element; - using LayoutD = typename DType_::Layout; - - // Check data infos - static_assert(std::is_same_v && - (std::is_same_v || std::is_same_v), - "The element type template parameters of BlockEpilogue are wrong"); - static_assert(std::is_same_v && std::is_same_v && - std::is_same_v && - std::is_same_v, - "The layout template parameters of BlockEpilogue are wrong"); - - // Tile compute ops - using TileRowBroadcastMul = TileRowBroadcastMul_; - using TileBroadcastOneBlk = TileBroadcastOneBlk_; - using TileOneBlkColumnBroadcastMul = TileOneBlkColumnBroadcastMul_; - - // Tile copy - using CopyGmToUbC = typename TileCopy_::CopyGmToUbC; - using CopyGmToUbScale = typename TileCopy_::CopyGmToUbX; - using CopyGmToUbPerTokenScale = typename TileCopy_::CopyGmToUbY; - using CopyUbToGmD = typename TileCopy_::CopyUbToGmD; - - using EpilogueTileSwizzle = EpilogueTileSwizzle_; - - using TileShape = typename TileRowBroadcastMul::TileShape; - - static_assert(TileShape::ROW == TileBroadcastOneBlk::COMPUTE_LENGTH && - std::is_same_v, - "TileShape must be consistent for all tile compute ops"); - - static_assert((UB_STAGES * (TileShape::COUNT * sizeof(ElementC) + TileShape::COLUMN * sizeof(ElementScale) + - TileShape::ROW * sizeof(ElementPerTokenScale) + TileShape::COUNT * sizeof(ElementD)) + - (TileShape::COUNT + TileShape::COUNT) * sizeof(float) + TileShape::ROW * BYTE_PER_BLK) <= - ArchTag::UB_SIZE, - "TileShape is too large to fit in UB"); - - struct Params { - __gm__ ElementScale *ptrScale{nullptr}; - LayoutScale layoutScale{}; - __gm__ ElementPerTokenScale *ptrPerTokenScale{nullptr}; - LayoutPerTokenScale layoutPerTokenScale{}; - __gm__ ElementD *ptrD{nullptr}; - LayoutD layoutD{}; - - ACT_DEVICE - Params() {}; - - ACT_DEVICE - Params(__gm__ ElementScale *ptrScale_, LayoutScale const &layoutScale_, - __gm__ ElementPerTokenScale *ptrPerTokenScale_, LayoutPerTokenScale const &layoutPerTokenScale_, - __gm__ ElementD *ptrD_, LayoutD const &layoutD_) - : ptrScale(ptrScale_), - layoutScale(layoutScale_), - ptrPerTokenScale(ptrPerTokenScale_), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(ptrD_), - layoutD(layoutD_) - {} - }; - - ACT_DEVICE void AlignUbOffset() - { - size_t ubMask = ubOffset & (MoeDistributeCombineImpl::UB_ALIGN - 1); - if (ubMask != 0) { - ubOffset += MoeDistributeCombineImpl::UB_ALIGN - ubMask; - } - } - - ACT_DEVICE - BlockEpilogue(Arch::Resource &resource, MoeDistributeCombineImpl::CombineCalcInfo &calcInfo, - Params const ¶ms = Params{}) - : resource(resource), calcInfo(calcInfo), params(params) - { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - ubCList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementC); - ubScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COLUMN * sizeof(ElementScale); - ubPerTokenScaleList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * sizeof(ElementPerTokenScale); - ubDList[i] = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(ElementD); - - eventUbCVMTE2List[i] = eventVMTE2++; - eventUbCMTE2VList[i] = eventMTE2V++; - eventUbScaleVMTE2List[i] = eventVMTE2++; - eventUbScaleMTE2VList[i] = eventMTE2V++; - eventUbPerTokenScaleVMTE2List[i] = eventVMTE2++; - eventUbPerTokenScaleMTE2VList[i] = eventMTE2V++; - eventUbDMTE3VList[i] = eventMTE3V++; - eventUbDVMTE3List[i] = eventVMTE3++; - - AscendC::SetFlag(eventUbCVMTE2List[i]); - AscendC::SetFlag(eventUbScaleVMTE2List[i]); - 
AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::SetFlag(eventUbDMTE3VList[i]); - } - ubCFp32 = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubMul = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::COUNT * sizeof(float); - ubPerTokenScaleBrcb = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += TileShape::ROW * BYTE_PER_BLK; - ubPerTokenMul = ubCFp32; - - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - AlignUbOffset(); - epSendCountLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); - ubOffset += calcInfo.moeSendNum_ * sizeof(int32_t); - AlignUbOffset(); - AscendC::GlobalTensor epSendCountGM; - epSendCountGM.SetGlobalBuffer((__gm__ int32_t *)calcInfo.epSendCount_); - uint32_t epSendCountSize = calcInfo.isSharedExpert_ ? calcInfo.epWorldSize_ : calcInfo.moeSendNum_; - AscendC::DataCopyExtParams epSendCntParams = {1U, static_cast(epSendCountSize * sizeof(uint32_t)), - 0U, 0U, 0U}; - AscendC::DataCopyPadExtParams copyPadParams{false, 0U, 0U, 0U}; - AscendC::DataCopyPad(epSendCountLocal_, epSendCountGM, epSendCntParams, copyPadParams); - AscendC::SetFlag(eventMTE2S); - AscendC::WaitFlag(eventMTE2S); -#if ENABLE_EP_SEND_COUNT_HASH - tokenToEpRankHashLocal_ = resource.ubBuf.template GetBufferByByte(ubOffset); - uint32_t maxGroupSendCount = 0; - uint32_t groupSendCount = 0; - for (uint32_t expertIdx = 0; expertIdx < calcInfo.moeExpertPerRankNum_; ++expertIdx) { - uint32_t prevGroupSendCount = groupSendCount; - groupSendCount = epSendCountLocal_.GetValue((expertIdx + 1) * calcInfo.epWorldSize_ - 1); - if (maxGroupSendCount < groupSendCount - prevGroupSendCount) { - maxGroupSendCount = groupSendCount - prevGroupSendCount; - } - } - ubOffset += maxGroupSendCount * sizeof(int32_t); - AlignUbOffset(); - // assert: ubOffset <= AscendC::TOTAL_UB_SIZE or - // AscendC::TOTAL_VEC_LOCAL_SIZE -#endif - } - } - - ACT_DEVICE - ~BlockEpilogue() - { - for (uint32_t i = 0; i < UB_STAGES; ++i) { - AscendC::WaitFlag(eventUbCVMTE2List[i]); - AscendC::WaitFlag(eventUbScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[i]); - AscendC::WaitFlag(eventUbDMTE3VList[i]); - } - } - - ACT_DEVICE - void UpdateParams(Params const ¶ms_) - { - params = params_; - } - - ACT_DEVICE GM_ADDR GetWinAddrByRankId(const int32_t rankId, const uint8_t expertLocalId = 0U) - { - return (GM_ADDR)((calcInfo.epRankId_ == rankId) - ? 
calcInfo.epWinContext_->localWindowsIn - : ((HcclRankRelationResV2 *)(calcInfo.epWinContext_->remoteRes[rankId].nextDevicePtr)) - ->windowsIn) + - calcInfo.winDataSizeOffset_ + expertLocalId * calcInfo.expertPerSizeOnWin_ + rankId * OPT_RANK_OFFSET; - } -#if ENABLE_EP_SEND_COUNT_HASH - ACT_DEVICE void InitTokenToEpRankHashLocalForEpRank(uint32_t &hashOffset, uint32_t epRank, uint32_t copyLen) - { - constexpr uint32_t DUPLICATE_MASK_COUNT = 8; - uint32_t hashOffsetMask = (((uint32_t)hashOffset) & (DUPLICATE_MASK_COUNT - 1)); - if (hashOffsetMask != 0) { - uint32_t remainMaskCount = DUPLICATE_MASK_COUNT - hashOffsetMask; - if (copyLen < remainMaskCount) { - remainMaskCount = copyLen; - } - uint64_t copyMask = ((1UL << remainMaskCount) - 1) << hashOffsetMask; - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset - hashOffsetMask], epRank, ©Mask, 1, 1, - DUPLICATE_MASK_COUNT); - hashOffset += remainMaskCount; - copyLen -= remainMaskCount; - } - if (copyLen > 0) { - AscendC::Duplicate(tokenToEpRankHashLocal_[hashOffset], epRank, copyLen); - hashOffset += copyLen; - } - } -#endif - - ACT_DEVICE void SetCombineSendEpRank(uint32_t epRank, uint32_t &remoteEpRank, uint32_t &localEpRank) - { - if ((calcInfo.isSharedExpert_) && (epRank < calcInfo.sharedExpertRankNum_)) { - remoteEpRank = calcInfo.epRankId_; - localEpRank = epRank; - } else { - remoteEpRank = epRank; - localEpRank = calcInfo.epRankId_; - } - } - - ACT_DEVICE void DoCombineSend(AscendC::LocalTensor &ubD, layout::RowMajor &layoutGmTileD, - LayoutD &layoutUbD, int64_t groupOffsetD, uint32_t expertIdx, uint32_t tileOffsetD) - { - const uint32_t copyTokenLen = layoutGmTileD.shape(1) * sizeof(ElementD); - const uint32_t copyTokenSrcStride = - (layoutUbD.stride(0) - layoutUbD.shape(1)) / (BYTE_PER_C0 / sizeof(ElementD)); - const uint32_t copyTokenDstStride = (layoutGmTileD.stride(0) - layoutGmTileD.shape(1)) * sizeof(ElementD); - - int64_t offsetD = groupOffsetD + tileOffsetD; - uint32_t startToken = offsetD / calcInfo.axisH_; - uint32_t tokenOffset = offsetD - startToken * calcInfo.axisH_; - uint32_t itToken = startToken; - uint32_t endToken = startToken + layoutGmTileD.shape(0); -#if ENABLE_EP_SEND_COUNT_HASH - uint32_t epRankStart = tokenToEpRankHashLocal_(itToken - startToken); -#else - constexpr uint32_t epRankStart = 0; -#endif - uint32_t sendCount = - expertIdx == 0 && epRankStart == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset + epRankStart - 1); - for (uint32_t epRank = epRankStart; epRank < calcInfo.epWorldSize_ && itToken < endToken; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - if (prevSendCount <= itToken && itToken < sendCount) { - uint32_t copyTokenCount = (sendCount < endToken ? 
sendCount : endToken) - itToken; - AscendC::DataCopyExtParams dataCopyParams(copyTokenCount, copyTokenLen, copyTokenSrcStride, - copyTokenDstStride, 0); - uint32_t remoteEpRank; - uint32_t localEpRank; - SetCombineSendEpRank(epRank, remoteEpRank, localEpRank); - GM_ADDR rankGM = GetWinAddrByRankId(remoteEpRank, expertIdx) + - localEpRank * calcInfo.moeExpertPerRankNum_ * calcInfo.expertPerSizeOnWin_; - AscendC::GlobalTensor rankWindow; - rankWindow.SetGlobalBuffer((__gm__ ElementD *)rankGM); - AscendC::DataCopyPad(rankWindow[(itToken - prevSendCount) * calcInfo.axisH_ + tokenOffset], - ubD[(itToken - startToken) * layoutUbD.stride(0)], dataCopyParams); - itToken += copyTokenCount; - } - } - } - - ACT_DEVICE - void operator()(int64_t groupOffsetD, uint32_t expertIdx, GemmCoord const &blockShapeMNK, - GemmCoord const &blockCoordMNK, GemmCoord const &actualBlockShapeMNK, - AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutBlockC, - Callback &&callback = Callback{}) - { - if (actualBlockShapeMNK.k() == 0) { - return; - } - - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - expertOffset = expertIdx * calcInfo.epWorldSize_; -#if ENABLE_EP_SEND_COUNT_HASH - if (currentExpertIdx_ != expertIdx) { - uint32_t hashOffset = 0; - uint32_t sendCount = expertIdx == 0 ? 0 : epSendCountLocal_.GetValue(expertOffset - 1); - for (uint32_t epRank = 0; epRank < calcInfo.epWorldSize_; ++epRank) { - uint32_t prevSendCount = sendCount; - sendCount = epSendCountLocal_.GetValue(expertOffset + epRank); - InitTokenToEpRankHashLocalForEpRank(hashOffset, epRank, sendCount - prevSendCount); - } - AscendC::SetFlag(eventVS); - AscendC::WaitFlag(eventVS); - currentExpertIdx_ = expertIdx; - } -#endif - } - - callback(); - // Calculate the offset of the current block - MatrixCoord blockShape = blockShapeMNK.GetCoordMN(); - MatrixCoord blockCoord = blockCoordMNK.GetCoordMN(); - MatrixCoord actualBlockShape = actualBlockShapeMNK.GetCoordMN(); - MatrixCoord blockOffset = blockCoord * blockShape; - - AscendC::GlobalTensor gmScale; - gmScale.SetGlobalBuffer(params.ptrScale); - AscendC::GlobalTensor gmPerTokenScale; - gmPerTokenScale.SetGlobalBuffer(params.ptrPerTokenScale); - AscendC::GlobalTensor gmD; - gmD.SetGlobalBuffer(params.ptrD); - - auto ubTileStride = MakeCoord(static_cast(TileShape::COLUMN), 1L); - auto tileShape = TileShape::ToCoord(); - EpilogueTileSwizzle epilogueTileSwizzle(actualBlockShape, tileShape); - uint32_t tileLoops = epilogueTileSwizzle.GetLoops(); - uint32_t subblockIdx = AscendC::GetSubBlockIdx(); - uint32_t subblockNum = AscendC::GetSubBlockNum(); - for (uint32_t loopIdx = subblockIdx; loopIdx < tileLoops; loopIdx += subblockNum) { - auto tileCoord = epilogueTileSwizzle.GetTileCoord(loopIdx); - auto actualTileShape = epilogueTileSwizzle.GetActualTileShape(tileCoord); - auto tileOffsetInBlock = tileCoord * tileShape; - auto tileOffset = blockOffset + tileOffsetInBlock; - - auto gmTileC = gmBlockC[layoutBlockC.GetOffset(tileOffsetInBlock)]; - auto layoutGmTileC = layoutBlockC.GetTileLayout(actualTileShape); - - auto &ubC = ubCList[ubListId]; - LayoutC layoutUbC{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbCVMTE2List[ubListId]); - copyGmToUbC(ubC, gmTileC, layoutUbC, layoutGmTileC); - AscendC::SetFlag(eventUbCMTE2VList[ubListId]); - - auto scaleTileOffset = tileOffset.template GetCoordByAxis<1>(); - auto scaleTileShape = actualTileShape.template GetCoordByAxis<1>(); - - auto gmTileScale = gmScale[params.layoutScale.GetOffset(scaleTileOffset)]; - auto layoutGmTileScale = 
params.layoutScale.GetTileLayout(scaleTileShape); - - auto &ubScale = ubScaleList[ubListId]; - auto layoutUbScale = LayoutScale::template MakeLayoutInUb(scaleTileShape); - - AscendC::WaitFlag(eventUbScaleVMTE2List[ubListId]); - copyGmToUbScale(ubScale, gmTileScale, layoutUbScale, layoutGmTileScale); - AscendC::SetFlag(eventUbScaleMTE2VList[ubListId]); - - auto perTokenScaleTileOffset = tileOffset.template GetCoordByAxis<0>(); - auto perTokenScaleTileShape = actualTileShape.template GetCoordByAxis<0>(); - - auto gmTilePerTokenScale = gmPerTokenScale[params.layoutPerTokenScale.GetOffset(perTokenScaleTileOffset)]; - auto layoutGmTilePerTokenScale = params.layoutPerTokenScale.GetTileLayout(perTokenScaleTileShape); - - auto &ubPerTokenScale = ubPerTokenScaleList[ubListId]; - auto layoutUbPerTokenScale = - LayoutScale::template MakeLayoutInUb(perTokenScaleTileShape); - - AscendC::WaitFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - copyGmToUbPerTokenScale(ubPerTokenScale, gmTilePerTokenScale, layoutUbPerTokenScale, - layoutGmTilePerTokenScale); - AscendC::SetFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - - AscendC::WaitFlag(eventUbCMTE2VList[ubListId]); - AscendC::Cast(ubCFp32, ubC, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbCVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbScaleMTE2VList[ubListId]); - tileRowBroadcastMul(ubMul, ubCFp32, ubScale); - AscendC::SetFlag(eventUbScaleVMTE2List[ubListId]); - - AscendC::WaitFlag(eventUbPerTokenScaleMTE2VList[ubListId]); - tileBroadcastOneBlk(ubPerTokenScaleBrcb, ubPerTokenScale); - AscendC::SetFlag(eventUbPerTokenScaleVMTE2List[ubListId]); - - AscendC::PipeBarrier(); - tileOneBlkColumnBroadcastMul(ubPerTokenMul, ubMul, ubPerTokenScaleBrcb); - AscendC::PipeBarrier(); - - auto &ubD = ubDList[ubListId]; - LayoutD layoutUbD{actualTileShape, ubTileStride}; - - AscendC::WaitFlag(eventUbDMTE3VList[ubListId]); - AscendC::Cast(ubD, ubPerTokenMul, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - AscendC::SetFlag(eventUbDVMTE3List[ubListId]); - - auto tileOffsetD = params.layoutD.GetOffset(tileOffset); - auto layoutGmTileD = params.layoutD.GetTileLayout(actualTileShape); - - AscendC::WaitFlag(eventUbDVMTE3List[ubListId]); - - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - DoCombineSend(ubD, layoutGmTileD, layoutUbD, groupOffsetD, expertIdx, tileOffsetD); - } else { - auto gmTileD = gmD[tileOffsetD]; - copyUbToGmD(gmTileD, ubD, layoutGmTileD, layoutUbD); - } - - AscendC::SetFlag(eventUbDMTE3VList[ubListId]); - - ubListId = (ubListId + 1 < UB_STAGES) ? 
(ubListId + 1) : 0; - } - } - -private: - Params params; - Arch::Resource &resource; - MoeDistributeCombineImpl::CombineCalcInfo calcInfo; - - AscendC::LocalTensor ubCList[UB_STAGES]; - AscendC::LocalTensor ubScaleList[UB_STAGES]; - AscendC::LocalTensor ubPerTokenScaleList[UB_STAGES]; - AscendC::LocalTensor ubDList[UB_STAGES]; - - int32_t eventUbCVMTE2List[UB_STAGES]; - int32_t eventUbCMTE2VList[UB_STAGES]; - int32_t eventUbScaleVMTE2List[UB_STAGES]; - int32_t eventUbScaleMTE2VList[UB_STAGES]; - int32_t eventUbPerTokenScaleVMTE2List[UB_STAGES]; - int32_t eventUbPerTokenScaleMTE2VList[UB_STAGES]; - int32_t eventUbDMTE3VList[UB_STAGES]; - int32_t eventUbDVMTE3List[UB_STAGES]; - - AscendC::LocalTensor epSendCountLocal_; -#if ENABLE_EP_SEND_COUNT_HASH - AscendC::LocalTensor tokenToEpRankHashLocal_; - uint32_t currentExpertIdx_{static_cast(-1)}; -#endif - - size_t ubOffset{0}; - int32_t eventVMTE2{0}; - int32_t eventMTE2V{0}; - int32_t eventMTE3V{0}; - int32_t eventVMTE3{0}; - int32_t eventVS{0}; - int32_t eventMTE2S{0}; - - uint32_t expertOffset; - - uint32_t ubListId{0}; - - AscendC::LocalTensor ubCFp32; - AscendC::LocalTensor ubMul; - AscendC::LocalTensor ubPerTokenScaleBrcb; - AscendC::LocalTensor ubPerTokenMul; - - TileRowBroadcastMul tileRowBroadcastMul; - TileBroadcastOneBlk tileBroadcastOneBlk; - TileOneBlkColumnBroadcastMul tileOneBlkColumnBroadcastMul; - - CopyGmToUbC copyGmToUbC; - CopyGmToUbScale copyGmToUbScale; - CopyGmToUbPerTokenScale copyGmToUbPerTokenScale; - CopyUbToGmD copyUbToGmD; -}; - -} // namespace Act::Epilogue::Block - -#endif // ACT_EPILOGUE_BLOCK_EPILOGUE_PER_TOKEN_DEQUANT_HPP diff --git a/act/epilogue/dispatch_policy.hpp b/act/epilogue/dispatch_policy.hpp deleted file mode 100644 index 8d93192d..00000000 --- a/act/epilogue/dispatch_policy.hpp +++ /dev/null @@ -1,76 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_DISPATCH_POLICY_HPP -#define ACT_EPILOGUE_DISPATCH_POLICY_HPP - -#include "../../act/arch/arch.hpp" - -namespace Act::Epilogue { - -// For AtlasA2, an element wise epilogue of the form D = C + X, where X is an -// additional source -struct EpilogueAtlasA2ElemWiseOneSource { - using ArchTag = Arch::AtlasA2; - // Number of operands. 
Including the C, X, and D operands, 3 in total - static constexpr uint32_t OPERANDS_NUM = 3; -}; - -// For AtlasA2, FA Softmax -struct EpilogueAtlasA2FASoftmax { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, FA RescaleO -struct EpilogueAtlasA2FARescaleO { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, MLA Softmax -struct EpilogueAtlasA2MLASoftmax { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, MLA RescaleO -struct EpilogueAtlasA2MLARescaleO { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, MLA FD RescaleO -template -struct EpilogueAtlasA2MLAFDRescaleO { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t KV_SPLIT_MAX = 64; - static constexpr uint32_t HEADS_PROCESS_MAX = 16; - static constexpr uint32_t COMPUTE_ELE_NUM = COMPUTE_ELE_NUM_; -}; - -// For AtlasA2, MLA TP1 Softmax -struct EpilogueAtlasA2MLATP1Softmax { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, MLA TP1 RescaleO -struct EpilogueAtlasA2MLATP1RescaleO { - using ArchTag = Arch::AtlasA2; -}; - -// For AtlasA2, per token dequant -template -struct EpilogueAtlasA2PerTokenDequant { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t UB_STAGES = UB_STAGES_; - static constexpr uint32_t EXEC_FLAG = EXEC_FLAG_; -}; -} // namespace Act::Epilogue - -#endif // ACT_EPILOGUE_DISPATCH_POLICY_HPP diff --git a/act/epilogue/tile/copy_gm_to_ub.hpp b/act/epilogue/tile/copy_gm_to_ub.hpp deleted file mode 100644 index 1a9d3b40..00000000 --- a/act/epilogue/tile/copy_gm_to_ub.hpp +++ /dev/null @@ -1,156 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ -#ifndef ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP -#define ACT_EPILOGUE_TILE_TILE_COPY_GM_TO_UB_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/gemm_type.hpp" -#include "../../../act/layout/layout.hpp" - -namespace Act::Epilogue::Tile { - -template -struct CopyGm2Ub { - static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub, cannot find the specialization."); -}; - -template -struct CopyGm2Ub> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyGm2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) - { - AscendC::DataCopyExtParams dataCopyParams(layoutSrc.shape(0), layoutSrc.shape(1) * sizeof(Element), - (layoutSrc.stride(0) - layoutSrc.shape(1)) * sizeof(Element), - (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK, 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - }; -}; - -template -struct CopyGm2Ub> { - using LayoutSrc = layout::VectorLayout; - using LayoutDst = layout::VectorLayout; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyGm2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) - { - AscendC::DataCopyExtParams dataCopyParams(1, layoutSrc.shape(0) * sizeof(Element), 0, 0, 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - }; -}; - -/// @brief This copy instruction is used to copy the per token scale from GM to UB. -/// Copy the scale of shape (m,1) on GM to the first column of shape (m,n) on -/// UB, and pad the first block of each row (i.e. pad to shape (m,8) when -/// element type is float). -/// @tparam ArchTag: Architecture tag. -/// @tparam GmType: Type of data on GM. 
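-/// A minimal usage sketch (illustrative only, not part of the original header:
-/// the float element type, the shapes m and n, and the tensor variable names
-/// below are assumptions):
-/// @code
-///   // gmScale holds the (m, 1) per token scale; the UB tile is (m, n) row-major
-///   CopyPerTokenScale2Ub<Arch::AtlasA2, Gemm::GemmType<float, layout::RowMajor>> copyPerTokenScale;
-///   copyPerTokenScale(ubScale, gmScale, layout::RowMajor(m, n), layout::RowMajor(m, 1));
-///   // afterwards each row's first 32-byte block is padded, i.e. (m, 8) for float data
-/// @endcode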
-template -struct CopyPerTokenScale2Ub { - static_assert(std::is_same_v, - "Unsupported layout for CopyPerTokenScale2Ub."); - - using Element = typename GmType::Element; - using LayoutSrc = typename GmType::Layout; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyPerTokenScale2Ub() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::DataCopyExtParams dataCopyParams; - AscendC::DataCopyPadExtParams padParams; - - dataCopyParams.blockCount = layoutSrc.shape(0); - dataCopyParams.blockLen = layoutSrc.shape(1) * sizeof(Element); // per token scale has only one column - dataCopyParams.srcStride = 0; - dataCopyParams.dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - // Pad the data to the complete block - padParams.isPad = true; - padParams.leftPadding = 0; - padParams.rightPadding = 0; - - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams, padParams); - } -}; - -template -struct CopyGm2UbAligned { - static_assert(DEPENDENT_FALSE, "Unsupported copy gm to ub aligned, cannot find the specialization."); -}; - -template -struct CopyGm2UbAligned> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - static constexpr uint32_t STRIDE_LIMIT = 65536; - - ACT_DEVICE - CopyGm2UbAligned() = default; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) - { - uint32_t rows = layoutSrc.shape(0); - uint32_t cols = layoutSrc.shape(1); - uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); - } - } - }; -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/copy_ub_to_gm.hpp b/act/epilogue/tile/copy_ub_to_gm.hpp deleted file mode 100644 index 651f4342..00000000 --- a/act/epilogue/tile/copy_ub_to_gm.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP -#define ACT_EPILOGUE_TILE_TILE_COPY_UB_TO_GM_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/gemm_type.hpp" -#include "../../../act/layout/layout.hpp" - -namespace Act::Epilogue::Tile { - -template -struct CopyUb2Gm { - static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm, cannot find the specialization."); -}; - -template -struct CopyUb2Gm> { - using LayoutDst = layout::RowMajor; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyUb2Gm() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) - { - AscendC::DataCopyExtParams dataCopyParams(layoutDst.shape(0), layoutDst.shape(1) * sizeof(Element), - (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_C0, - (layoutDst.stride(0) - layoutDst.shape(1)) * sizeof(Element), 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); - } -}; - -// Newly added VectorLayout version -template -struct CopyUb2Gm> { - using LayoutSrc = layout::VectorLayout; - using LayoutDst = layout::VectorLayout; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - - ACT_DEVICE - CopyUb2Gm() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - layout::VectorLayout const &layoutDst, layout::VectorLayout const &layoutSrc) - { - AscendC::DataCopyExtParams dataCopyParams(1, layoutDst.shape(0) * sizeof(Element), 0, 0, 0); - AscendC::DataCopyPad(dstTensor, srcTensor, dataCopyParams); - }; -}; - -template -struct CopyUb2GmAligned { - static_assert(DEPENDENT_FALSE, "Unsupported copy ub to gm aligned, cannot find the specialization."); -}; - -template -struct CopyUb2GmAligned> { - using LayoutSrc = layout::RowMajor; - using LayoutDst = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - static constexpr uint32_t STRIDE_LIMIT = 65536; - - ACT_DEVICE - CopyUb2GmAligned() = default; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - layout::RowMajor const &layoutDst, layout::RowMajor const &layoutSrc) - { - uint32_t rows = layoutDst.shape(0); - uint32_t cols = layoutDst.shape(1); - uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = (i < rLoops - 1) ? 
MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); - } - } - }; -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp deleted file mode 100644 index a4a9d8d6..00000000 --- a/act/epilogue/tile/tile_broadcast_inplace_by_column.hpp +++ /dev/null @@ -1,64 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP -#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_COLUMN_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag_, - /// Compute data type - class ComputeType_, - /// Length of the compute buffer - class TileShape_> -struct TileBroadcastInplaceByColumn { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileBroadcastInplaceByColumn() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubInOut) - { - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - constexpr uint32_t blkNumPerRow = TileShape::COLUMN / eleNumPerBlk; - - constexpr uint64_t defaultMask = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - constexpr uint64_t tailMask = (TileShape::ROW % BLK_NUM_PER_VECTOR_FRACTAL) * eleNumPerBlk; - - constexpr uint8_t repeatTimes = 1; - - AscendC::CopyRepeatParams repeatParams; - repeatParams.dstStride = blkNumPerRow; - repeatParams.srcStride = blkNumPerRow; - repeatParams.dstRepeatSize = 1; - repeatParams.srcRepeatSize = 1; - - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += BLK_NUM_PER_VECTOR_FRACTAL) { - uint64_t mask = ((TileShape::ROW - rowOffset) >= BLK_NUM_PER_VECTOR_FRACTAL) ? defaultMask : tailMask; - for (uint32_t colOffset = eleNumPerBlk; colOffset < TileShape::COLUMN; colOffset += eleNumPerBlk) { - AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN + colOffset], - ubInOut[rowOffset * TileShape::COLUMN], mask, 1, repeatParams); - } - } - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp b/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp deleted file mode 100644 index 7ea15659..00000000 --- a/act/epilogue/tile/tile_broadcast_inplace_by_row.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). 
Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP -#define ACT_EPILOGUE_TILE_TILE_BROADCAST_INPLACE_BY_ROW_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag_, - /// Compute data type - class ComputeType_, - /// Length of the compute buffer - class TileShape_> -struct TileBroadcastInplaceByRow { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileBroadcastInplaceByRow() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubInOut) - { - constexpr uint32_t eleNumPerVectorFractal = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - - constexpr uint64_t mask = eleNumPerVectorFractal; - constexpr uint8_t repeatTimes = TileShape::COLUMN / eleNumPerVectorFractal; - - AscendC::CopyRepeatParams repeatParams; - repeatParams.dstStride = 1; - repeatParams.srcStride = 1; - repeatParams.dstRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; - repeatParams.srcRepeatSize = BLK_NUM_PER_VECTOR_FRACTAL; - - for (uint32_t rowOffset = 1; rowOffset < TileShape::ROW; ++rowOffset) { - AscendC::Copy(ubInOut[rowOffset * TileShape::COLUMN], ubInOut, mask, repeatTimes, repeatParams); - } - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_broadcast_mul.hpp b/act/epilogue/tile/tile_broadcast_mul.hpp deleted file mode 100644 index 93b6125f..00000000 --- a/act/epilogue/tile/tile_broadcast_mul.hpp +++ /dev/null @@ -1,122 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP -#define ACT_EPILOGUE_TILE_TILE_BROADCAST_MUL_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -/// BroadcastMul computes the elementwise multiplication of a tensor of shape -/// (m, n) and a tensor of shape (m, n) after broadcasting. There are two -/// broadcast modes: row-broadcast and column-broadcast. - -/// @brief Computes the elementwise multiplication of a tensor with shape (m, n) -/// and a tensor with original shape (1, n) broadcast to (m, n). -/// @tparam ArchTag_ is the architecture tag. -/// @tparam ComputeType_ includes the element type and layout information. -/// @tparam TileShape_ is the shape (m, n). 
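-/// Stated in index form (a restatement of the brief above; ubIn1 holds the
-/// length-n row vector that is broadcast across the m rows):
-///   ubOut[i][j] = ubIn0[i][j] * ubIn1[j]   for 0 <= i < m, 0 <= j < n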
-template -struct TileRowBroadcastMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileRowBroadcastMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) - { - constexpr uint32_t maxRepeatTimes = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; - AscendC::BinaryRepeatParams repeatParams; - repeatParams.dstBlkStride = 1; - repeatParams.src0BlkStride = 1; - repeatParams.src1BlkStride = 1; - repeatParams.dstRepStride = blkNumPerColumn; - repeatParams.src0RepStride = blkNumPerColumn; - repeatParams.src1RepStride = 0; - - constexpr uint32_t rowNumPerCompute = maxRepeatTimes; - constexpr uint32_t colNumPerCompute = BYTE_PER_VECTOR_FRACTAL / sizeof(ElementCompute); - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { - uint32_t residueM = TileShape::ROW - rowOffset; - uint8_t repeatTimes = static_cast((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM); - for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { - uint32_t residueN = TileShape::COLUMN - colOffset; - uint64_t mask = (residueN > colNumPerCompute) ? colNumPerCompute : residueN; - AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], - ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[colOffset], mask, repeatTimes, - repeatParams); - } - } - } -}; - -/// @brief Compute the elementwise multiplication of a tensor of shape (m, n) -/// and a tensor of shape (m, eleNumPerBlk), which is broadcast from a tensor of -/// shape (m, 1), broadcast to (m, n). -/// @tparam ArchTag_ is the architecture tag. -/// @tparam ComputeType_ includes the element type and layout information. -/// @tparam TileShape_ is the shape (m, n). -template -struct TileOneBlkColumnBroadcastMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileOneBlkColumnBroadcastMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) - { - constexpr uint32_t maxRepeatNum = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - constexpr uint32_t blkNumPerColumn = TileShape::COLUMN / eleNumPerBlk; - AscendC::BinaryRepeatParams repeatParams; - repeatParams.dstBlkStride = blkNumPerColumn; - repeatParams.src0BlkStride = blkNumPerColumn; - repeatParams.src1BlkStride = 1; - repeatParams.dstRepStride = 1; - repeatParams.src0RepStride = 1; - repeatParams.src1RepStride = 0; - - constexpr uint32_t rowNumPerCompute = BLK_NUM_PER_VECTOR_FRACTAL; - constexpr uint32_t colNumPerCompute = eleNumPerBlk * maxRepeatNum; - for (uint32_t rowOffset = 0; rowOffset < TileShape::ROW; rowOffset += rowNumPerCompute) { - uint32_t residueM = TileShape::ROW - rowOffset; - uint64_t mask = ((residueM > rowNumPerCompute) ? rowNumPerCompute : residueM) * eleNumPerBlk; - for (uint32_t colOffset = 0; colOffset < TileShape::COLUMN; colOffset += colNumPerCompute) { - uint32_t residueN = TileShape::COLUMN - colOffset; - uint8_t repeatTimes = - static_cast(((residueN > colNumPerCompute) ? 
colNumPerCompute : residueN) / eleNumPerBlk); - AscendC::Mul(ubOut[rowOffset * TileShape::COLUMN + colOffset], - ubIn0[rowOffset * TileShape::COLUMN + colOffset], ubIn1[rowOffset * eleNumPerBlk], mask, - repeatTimes, repeatParams); - } - } - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_broadcast_one_blk.hpp b/act/epilogue/tile/tile_broadcast_one_blk.hpp deleted file mode 100644 index d8f7d79d..00000000 --- a/act/epilogue/tile/tile_broadcast_one_blk.hpp +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP -#define ACT_EPILOGUE_TILE_TILE_BROADCAST_ONE_BLK_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template -struct TileBroadcastOneBlk { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - - ACT_DEVICE - TileBroadcastOneBlk() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) - { - constexpr uint32_t maxRepeatNum = 255; - constexpr uint32_t eleNumPerBlk = BYTE_PER_BLK / sizeof(ElementCompute); - - AscendC::BrcbRepeatParams repeatParams; - repeatParams.dstBlkStride = 1; - repeatParams.dstRepStride = BLK_NUM_PER_VECTOR_FRACTAL; - - constexpr uint32_t eleNumPerCompute = RoundDown(maxRepeatNum * BLK_NUM_PER_VECTOR_FRACTAL); - for (uint32_t offset = 0; offset < COMPUTE_LENGTH; offset += eleNumPerCompute) { - uint32_t residueM = COMPUTE_LENGTH - offset; - uint32_t computeM = (residueM > eleNumPerCompute) ? eleNumPerCompute : residueM; - uint8_t repeatTimes = static_cast(CeilDiv(computeM)); - AscendC::Brcb(ubOut[offset * eleNumPerBlk], ubIn[offset], repeatTimes, repeatParams); - } - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_cast.hpp b/act/epilogue/tile/tile_cast.hpp deleted file mode 100644 index 50162516..00000000 --- a/act/epilogue/tile/tile_cast.hpp +++ /dev/null @@ -1,45 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_EPILOGUE_TILE_TILE_CAST_HPP -#define ACT_EPILOGUE_TILE_TILE_CAST_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag_, - /// Compute data type - class DstType_, class SrcType_, - /// Length of the compute buffer - class TileShape_> -struct TileCast { - using ArchTag = ArchTag_; - using ElementDst = typename DstType_::Element; - using ElementSrc = typename SrcType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileCast() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, AscendC::LocalTensor const &ubIn) - { - AscendC::Cast(ubOut, ubIn, AscendC::RoundMode::CAST_RINT, TileShape::COUNT); - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_copy.hpp b/act/epilogue/tile/tile_copy.hpp deleted file mode 100644 index 2ed7c9c7..00000000 --- a/act/epilogue/tile/tile_copy.hpp +++ /dev/null @@ -1,104 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_COPY_HPP -#define ACT_EPILOGUE_TILE_TILE_COPY_HPP - -#include "../../../act/epilogue/tile/copy_gm_to_ub.hpp" -#include "../../../act/epilogue/tile/copy_ub_to_gm.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag, class... 
Args> -struct TileCopy { - static_assert(DEPENDENT_FALSE, "Unsupported tile copy, cannot find the specialization."); -}; - -template -struct TileCopy { - using ElementC = typename CType::Element; - using ElementX = typename XType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = CopyGm2Ub; - using CopyUbToGmD = CopyUb2Gm; - using CopyGmToUbY = CopyGm2Ub; - using CopyGmToUbTemp = CopyGm2Ub; - using CopyUbToGmZ = CopyUb2Gm; -}; - -template -struct TileCopy { - using ElementC = typename CType::Element; - using ElementX = typename XType::Element; - using ElementY = typename YType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = CopyGm2Ub; - using CopyGmToUbY = CopyGm2Ub; - using CopyUbToGmD = CopyUb2Gm; -}; - -template -struct TileCopyBf16 { - using ElementC = typename CType::Element; - using ElementX = bfloat16_t; - using ElementY = bfloat16_t; - using ElementD = bfloat16_t; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbX = CopyGm2Ub>; - using CopyGmToUbY = CopyGm2Ub>; - using CopyUbToGmD = CopyUb2Gm>; -}; - -template -struct TileCopyPerTokenDequant { - using ElementC = typename CType::Element; - using ElementScale = typename ScaleType::Element; - using ElementPerTokenScale = typename PerTokenScaleType::Element; - using ElementD = typename DType::Element; - - using CopyGmToUbC = CopyGm2Ub; - using CopyGmToUbScale = CopyGm2Ub; - using CopyGmToUbPerTokenScale = CopyPerTokenScale2Ub; - using CopyUbToGmD = CopyUb2Gm; -}; - -template -struct TileCopyPerTokenDequantGemm { - using ElementX = typename XType::Element; - using ElementScale = typename ScaleType::Element; - using ElementPerTokenScale = typename PerTokenScaleType::Element; - using ElementBias = typename BiasType::Element; - using ElementC = typename CType::Element; - - using CopyGmToUbX = CopyGm2Ub; - using CopyGmToUbScale = CopyGm2Ub; - using CopyGmToUbPerTokenScale = CopyGm2Ub; - using CopyGmToUbBias = CopyGm2Ub; - using CopyUbToGmC = CopyUb2Gm; -}; - -} // namespace Act::Epilogue::Tile - -#endif // ACT_EPILOGUE_TILE_TILE_COPY_HPP diff --git a/act/epilogue/tile/tile_elemwise_add.hpp b/act/epilogue/tile/tile_elemwise_add.hpp deleted file mode 100644 index 8edcc1f9..00000000 --- a/act/epilogue/tile/tile_elemwise_add.hpp +++ /dev/null @@ -1,48 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP -#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_ADD_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag_, - /// Compute data type - class ComputeType_, - /// Length of the compute buffer - uint32_t COMPUTE_LENGTH_> -struct TileElemWiseAdd { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - - ACT_DEVICE - TileElemWiseAdd() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) - { - // Do the calculation - AscendC::Add(ubOut, ubIn0, ubIn1, COMPUTE_LENGTH); - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_elemwise_mul.hpp b/act/epilogue/tile/tile_elemwise_mul.hpp deleted file mode 100644 index cfc45739..00000000 --- a/act/epilogue/tile/tile_elemwise_mul.hpp +++ /dev/null @@ -1,47 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP -#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MUL_HPP - -#include "../../../act/act.hpp" - -namespace Act::Epilogue::Tile { - -template < - /// Tag indicating architecture - class ArchTag_, - /// Compute data type - class ComputeType_, - /// Length of the compute buffer - class TileShape_> -struct TileElemwiseMul { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - using TileShape = TileShape_; - - ACT_DEVICE - TileElemwiseMul() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &ubOut, - AscendC::LocalTensor const &ubIn0, - AscendC::LocalTensor const &ubIn1) - { - // Do the calculation - AscendC::Mul(ubOut, ubIn0, ubIn1, TileShape::COUNT); - } -}; - -} // namespace Act::Epilogue::Tile - -#endif diff --git a/act/epilogue/tile/tile_elemwise_muls.hpp b/act/epilogue/tile/tile_elemwise_muls.hpp deleted file mode 100644 index 9bf10fa9..00000000 --- a/act/epilogue/tile/tile_elemwise_muls.hpp +++ /dev/null @@ -1,38 +0,0 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP -#define ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP - -#include "../../../act/gemm/helper.hpp" - -namespace Act::Epilogue::Tile { -template -struct TileElemWiseMuls { - using ArchTag = ArchTag_; - using ElementCompute = typename ComputeType_::Element; - - static constexpr uint32_t COMPUTE_LENGTH = COMPUTE_LENGTH_; - - ACT_DEVICE - TileElemWiseMuls() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstLocal, AscendC::LocalTensor srcTensor, - ElementCompute scalar) - { - AscendC::Muls(dstLocal, srcTensor, scalar, COMPUTE_LENGTH); - } -}; -} // namespace Act::Epilogue::Tile - -#endif // ACT_EPILOGUE_TILE_TILE_ELEMWISE_MULS_HPP diff --git a/act/epilogue/tile/tile_swizzle.hpp b/act/epilogue/tile/tile_swizzle.hpp deleted file mode 100644 index 490a2a5a..00000000 --- a/act/epilogue/tile/tile_swizzle.hpp +++ /dev/null @@ -1,92 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP -#define ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP - -#include "../../../act/act.hpp" -#include "../../../act/detail/alignment.hpp" -#include "../../../act/matrix_coord.hpp" - -namespace Act::Epilogue::Tile { - -struct EpilogueIdentityTileSwizzle { - MatrixCoord blockShape; - MatrixCoord tileShape; - MatrixCoord loopsMN; - - ACT_DEVICE - EpilogueIdentityTileSwizzle() = default; - - ACT_DEVICE - EpilogueIdentityTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) - : blockShape(blockShape), tileShape(tileShape) - { - loopsMN = CeilDiv(blockShape, tileShape); - } - - ACT_DEVICE - uint32_t GetLoops() const - { - return loopsMN.row() * loopsMN.column(); - } - - ACT_DEVICE - MatrixCoord GetTileCoord(uint32_t loopIdx) const - { - return MatrixCoord{loopIdx / loopsMN.column(), loopIdx % loopsMN.column()}; - } - - ACT_DEVICE - MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const - { - return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); - } -}; - -struct EpilogueHorizontalTileSwizzle { - MatrixCoord blockShape; - MatrixCoord tileShape; - MatrixCoord loopsMN; - - ACT_DEVICE - EpilogueHorizontalTileSwizzle() = default; - - ACT_DEVICE - EpilogueHorizontalTileSwizzle(MatrixCoord const &blockShape, MatrixCoord const &tileShape) - : blockShape(blockShape), tileShape(tileShape) - { - loopsMN = CeilDiv(blockShape, tileShape); - } - - ACT_DEVICE - uint32_t GetLoops() const - { - return loopsMN.row() * loopsMN.column(); - } - - ACT_DEVICE - MatrixCoord GetTileCoord(uint32_t loopIdx) const - { - return MatrixCoord{loopIdx % loopsMN.row(), loopIdx / loopsMN.row()}; - } - - ACT_DEVICE - MatrixCoord GetActualTileShape(MatrixCoord const &tileCoord) const - { - return MatrixCoord::Min(tileShape, blockShape - tileCoord * tileShape); - } -}; - -} // namespace Act::Epilogue::Tile - -#endif // ACT_EPILOGUE_TILE_TILE_SWIZZLE_HPP diff --git a/act/gemm/block/block_mmad.hpp b/act/gemm/block/block_mmad.hpp deleted file mode 100644 index 
8da81c80..00000000 --- a/act/gemm/block/block_mmad.hpp +++ /dev/null @@ -1,57 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_HPP -#define ACT_GEMM_BLOCK_BLOCK_MMAD_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/tile/tile_copy.hpp" -#include "../../../act/gemm/tile/tile_mmad.hpp" - -namespace Act::Gemm::Block { - -template , - class TileMmad = Gemm::Tile::TileMmad> -struct BlockMmad { - static_assert(DEPENDENT_FALSE, "BlockMmad is not implemented for this DispatchPolicy"); -}; - -template , - class TileMmad = Gemm::Tile::TileMmadTla> -struct BlockMmadTla { - static_assert(DEPENDENT_FALSE, "BlockMmadTla is not implemented for this DispatchPolicy"); -}; - -/// Newly added because this kernel uses the same dispatch policy as -/// optimized_matmul; a separate class is introduced to avoid the -/// specialization conflict. -template , // renamed from BlockMmad - class TileMmad = Gemm::Tile::TileMmad> -struct BlockGemm { - static_assert(DEPENDENT_FALSE, "BlockGemm is not implemented for this DispatchPolicy"); -}; - -} // namespace Act::Gemm::Block - -#include "../../../act/gemm/block/block_mmad_preload_async_with_callback.hpp" - -#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_HPP diff --git a/act/gemm/block/block_mmad_preload_async_with_callback.hpp b/act/gemm/block/block_mmad_preload_async_with_callback.hpp deleted file mode 100644 index 324f9799..00000000 --- a/act/gemm/block/block_mmad_preload_async_with_callback.hpp +++ /dev/null @@ -1,410 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP -#define ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP - -#include "../../../act/act.hpp" -#include "../../../act/arch/resource.hpp" -#include "../../../act/coord.hpp" -#include "../../../act/detail/callback.hpp" -#include "../../../act/gemm/dispatch_policy.hpp" -#include "../../../act/gemm/helper.hpp" -#include "../../../act/gemm_coord.hpp" - -namespace Act::Gemm::Block { - -template -struct BlockMmad, - L1TileShape_, L0TileShape_, AType_, BType_, CType_, BiasType_, TileCopy_, TileMmad_> { -public: - // Type Aliases - using DispatchPolicy = MmadAtlasA2PreloadAsyncWithCallback; - using ArchTag = typename DispatchPolicy::ArchTag; - using L1TileShape = L1TileShape_; - using L0TileShape = L0TileShape_; - using ElementA = typename AType_::Element; - using LayoutA = typename AType_::Layout; - using ElementB = typename BType_::Element; - using LayoutB = typename BType_::Layout; - using ElementC = typename CType_::Element; - using LayoutC = typename CType_::Layout; - using TileMmad = TileMmad_; - using CopyGmToL1A = typename TileCopy_::CopyGmToL1A; - using CopyGmToL1B = typename TileCopy_::CopyGmToL1B; - using CopyL1ToL0A = typename TileCopy_::CopyL1ToL0A; - using CopyL1ToL0B = typename TileCopy_::CopyL1ToL0B; - using CopyL0CToGm = typename TileCopy_::CopyL0CToGm; - using ElementAccumulator = - typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; - using LayoutAInL1 = typename CopyL1ToL0A::LayoutSrc; - using LayoutBInL1 = typename CopyL1ToL0B::LayoutSrc; - using LayoutAInL0 = typename CopyL1ToL0A::LayoutDst; - using LayoutBInL0 = typename CopyL1ToL0B::LayoutDst; - using LayoutCInL0 = layout::zN; - - using L1AAlignHelper = Gemm::helper::L1AlignHelper; - using L1BAlignHelper = Gemm::helper::L1AlignHelper; - - static constexpr uint32_t PRELOAD_STAGES = DispatchPolicy::PRELOAD_STAGES; - static constexpr uint32_t L1_STAGES = DispatchPolicy::L1_STAGES; - static constexpr uint32_t L0A_STAGES = DispatchPolicy::L0A_STAGES; - static constexpr uint32_t L0B_STAGES = DispatchPolicy::L0B_STAGES; - static constexpr uint32_t L0C_STAGES = DispatchPolicy::L0C_STAGES; - - static constexpr bool ENABLE_UNIT_FLAG = DispatchPolicy::ENABLE_UNIT_FLAG; - static constexpr bool ENABLE_SHUFFLE_K = DispatchPolicy::ENABLE_SHUFFLE_K; - - // L1 tile size - static constexpr uint32_t L1A_TILE_SIZE = L1TileShape::M * L1TileShape::K * sizeof(ElementA); - static constexpr uint32_t L1B_TILE_SIZE = L1TileShape::N * L1TileShape::K * sizeof(ElementB); - // L0 tile size - static constexpr uint32_t L0A_TILE_SIZE = L0TileShape::M * L0TileShape::K * sizeof(ElementA); - static constexpr uint32_t L0B_TILE_SIZE = L0TileShape::K * L0TileShape::N * sizeof(ElementB); - static constexpr uint32_t L0C_TILE_SIZE = L1TileShape::M * L1TileShape::N * sizeof(ElementAccumulator); - - // Check LayoutC - static_assert(std::is_same_v, "LayoutC only support RowMajor yet!"); - - // Check L1TileShape - static_assert((L1A_TILE_SIZE + L1B_TILE_SIZE) * L1_STAGES <= ArchTag::L1_SIZE, - "L1TileShape exceeding the L1 space!"); - - // Check L0TileShape - static_assert(L0A_TILE_SIZE * L0A_STAGES <= ArchTag::L0A_SIZE, "L0TileShape exceeding the L0A space!"); - static_assert(L0B_TILE_SIZE * L0B_STAGES <= ArchTag::L0B_SIZE, "L0TileShape exceeding the L0B space!"); - static_assert(L0C_TILE_SIZE * L0C_STAGES <= ArchTag::L0C_SIZE, "L0TileShape exceeding the L0C space!"); - - static_assert(L1TileShape::M == L0TileShape::M && L1TileShape::N == L0TileShape::N, - "The situation 
where the basic blocks of L1 and L0 differ on " - "the m and n axes is not supported yet"); - - static constexpr auto L1A_LAYOUT = LayoutAInL1::template MakeLayout(L1TileShape::M, L1TileShape::K); - static constexpr auto L1B_LAYOUT = LayoutBInL1::template MakeLayout(L1TileShape::K, L1TileShape::N); - - ACT_DEVICE - BlockMmad(Arch::Resource &resource, uint32_t l1BufAddrStart = 0) - { - InitL1(resource, l1BufAddrStart); - InitL0A(resource); - InitL0B(resource); - InitL0C(resource); - } - - ACT_DEVICE - ~BlockMmad() - { - SynchronizeBlock(); - for (uint32_t i = 0; i < L1_STAGES; ++i) { - AscendC::WaitFlag(l1AEventList[i]); - AscendC::WaitFlag(l1BEventList[i]); - } - for (uint32_t i = 0; i < L0A_STAGES; ++i) { - AscendC::WaitFlag(l0AEventList[i]); - } - for (uint32_t i = 0; i < L0B_STAGES; ++i) { - AscendC::WaitFlag(l0BEventList[i]); - } - for (uint32_t i = 0; i < L0C_STAGES; ++i) { - AscendC::WaitFlag(l0CEventList[i]); - } - } - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &gmBlockA, LayoutA const &layoutA, - AscendC::GlobalTensor const &gmBlockB, LayoutB const &layoutB, - AscendC::GlobalTensor const &gmBlockC, LayoutC const &layoutC, - GemmCoord const &actualShape, Callback const &callbackBeforeFixpipe, - Callback const &callbackAfterFixpipe) - { - uint32_t kTileCount = CeilDiv(actualShape.k()); - - uint32_t mRound = RoundUp(actualShape.m()); - uint32_t nRound = RoundUp(actualShape.n()); - - uint32_t startTileIdx = 0; - if constexpr (ENABLE_SHUFFLE_K) { - startTileIdx = AscendC::GetBlockIdx() % kTileCount; - } - - for (uint32_t kLoopIdx = 0; kLoopIdx < kTileCount; ++kLoopIdx) { - uint32_t kTileIdx = (startTileIdx + kLoopIdx < kTileCount) ? (startTileIdx + kLoopIdx) - : (startTileIdx + kLoopIdx - kTileCount); - - uint32_t kActual = - (kTileIdx < kTileCount - 1) ? L1TileShape::K : (actualShape.k() - kTileIdx * L1TileShape::K); - - // Emission load instruction from GM to L1 - MatrixCoord gmTileAOffset{0, kTileIdx * L1TileShape::K}; - MatrixCoord gmTileBOffset{kTileIdx * L1TileShape::K, 0}; - auto gmTileA = gmBlockA[layoutA.GetOffset(gmTileAOffset)]; - auto gmTileB = gmBlockB[layoutB.GetOffset(gmTileBOffset)]; - // Load first matrix A tile from GM to L1 - AscendC::WaitFlag(l1AEventList[l1ListId]); - auto layoutTileA = layoutA.GetTileLayout(MakeCoord(actualShape.m(), kActual)); - copyGmToL1A(l1ATensorList[l1ListId], gmTileA, L1A_LAYOUT, layoutTileA); - AscendC::SetFlag(l1AEventList[l1ListId]); - // Load first matrix B tile from GM to L1 - AscendC::WaitFlag(l1BEventList[l1ListId]); - auto layoutTileB = layoutB.GetTileLayout(MakeCoord(kActual, actualShape.n())); - copyGmToL1B(l1BTensorList[l1ListId], gmTileB, L1B_LAYOUT, layoutTileB); - AscendC::SetFlag(l1BEventList[l1ListId]); - - // If the number of preload instructions reaches the upper limit, perform - // an mmad calculation on L1 tile - if (preloadCount == PRELOAD_STAGES) { - L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); - } - - // Store the current load status - uint32_t preloadL1TileMmadParamsId = (l1TileMmadParamsId + preloadCount < PRELOAD_STAGES) - ? 
(l1TileMmadParamsId + preloadCount) - : (l1TileMmadParamsId + preloadCount - PRELOAD_STAGES); - auto &l1TileMmadParams = l1TileMmadParamsList[preloadL1TileMmadParamsId]; - l1TileMmadParams.l1ListId = l1ListId; - l1TileMmadParams.mRound = mRound; - l1TileMmadParams.nRound = nRound; - l1TileMmadParams.kActual = kActual; - l1TileMmadParams.isKLoopFirst = (kLoopIdx == 0); - l1TileMmadParams.isKLoopLast = (kLoopIdx == kTileCount - 1); - if (kLoopIdx == kTileCount - 1) { - l1TileMmadParams.gmBlockC = gmBlockC; - l1TileMmadParams.layoutCInGm = layoutC.GetTileLayout(actualShape.GetCoordMN()); - l1TileMmadParams.callbackBeforeFixpipe = callbackBeforeFixpipe; - l1TileMmadParams.callbackAfterFixpipe = callbackAfterFixpipe; - } - - if (preloadCount < PRELOAD_STAGES) { - ++preloadCount; - } else { - l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; - } - l1ListId = (l1ListId + 1 < L1_STAGES) ? (l1ListId + 1) : 0; - } - } - - ACT_DEVICE - void SynchronizeBlock() - { - while (preloadCount > 0) { - L1TileMmad(l1TileMmadParamsList[l1TileMmadParamsId]); - l1TileMmadParamsId = (l1TileMmadParamsId + 1 < PRELOAD_STAGES) ? (l1TileMmadParamsId + 1) : 0; - --preloadCount; - } - } - -private: - struct L1TileMmadParams { - uint32_t l1ListId; - uint32_t mRound; - uint32_t nRound; - uint32_t kActual; - bool isKLoopFirst; - bool isKLoopLast; - AscendC::GlobalTensor gmBlockC; - LayoutC layoutCInGm; - Callback callbackBeforeFixpipe; - Callback callbackAfterFixpipe; - - ACT_DEVICE - L1TileMmadParams() = default; - }; - - ACT_DEVICE - void InitL1(Arch::Resource &resource, uint32_t l1BufAddrStart) - { - uint32_t l1AOffset = l1BufAddrStart; - uint32_t l1BOffset = l1BufAddrStart + L1A_TILE_SIZE * L1_STAGES; - for (uint32_t i = 0; i < L1_STAGES; ++i) { - l1ATensorList[i] = resource.l1Buf.template GetBufferByByte(l1AOffset + L1A_TILE_SIZE * i); - l1BTensorList[i] = resource.l1Buf.template GetBufferByByte(l1BOffset + L1B_TILE_SIZE * i); - l1AEventList[i] = i; - l1BEventList[i] = i + L1_STAGES; - AscendC::SetFlag(l1AEventList[i]); - AscendC::SetFlag(l1BEventList[i]); - } - } - - ACT_DEVICE - void InitL0A(Arch::Resource &resource) - { - for (uint32_t i = 0; i < L0A_STAGES; ++i) { - l0ATensorList[i] = resource.l0ABuf.template GetBufferByByte(L0A_TILE_SIZE * i); - l0AEventList[i] = i; - AscendC::SetFlag(l0AEventList[i]); - } - } - - ACT_DEVICE - void InitL0B(Arch::Resource &resource) - { - for (uint32_t i = 0; i < L0B_STAGES; ++i) { - l0BTensorList[i] = resource.l0BBuf.template GetBufferByByte(L0B_TILE_SIZE * i); - l0BEventList[i] = i + L0A_STAGES; - AscendC::SetFlag(l0BEventList[i]); - } - } - - ACT_DEVICE - void InitL0C(Arch::Resource &resource) - { - for (uint32_t i = 0; i < L0C_STAGES; ++i) { - l0CTensorList[i] = resource.l0CBuf.template GetBufferByByte(L0C_TILE_SIZE * i); - l0CEventList[i] = i; - AscendC::SetFlag(l0CEventList[i]); - } - } - - ACT_DEVICE - void L1TileMmad(L1TileMmadParams const ¶ms) - { - uint32_t mPartLoop = CeilDiv(params.mRound); - uint32_t nPartLoop = CeilDiv(params.nRound); - uint32_t kPartLoop = CeilDiv(params.kActual); - auto &l1ATensor = l1ATensorList[params.l1ListId]; - auto &l1BTensor = l1BTensorList[params.l1ListId]; - - auto &l0CTensor = l0CTensorList[l0CListId]; - LayoutCInL0 layoutCInL0 = LayoutCInL0::MakeLayoutInL0C(MakeCoord(params.mRound, params.nRound)); - - if constexpr (!ENABLE_UNIT_FLAG) { - if (params.isKLoopFirst) { - AscendC::WaitFlag(l0CEventList[l0CListId]); - } - } - - for (uint32_t mPartIdx = 0; mPartIdx < mPartLoop; ++mPartIdx) { - 
uint32_t mPartActual = - (mPartIdx < mPartLoop - 1) ? L0TileShape::M : (params.mRound - mPartIdx * L0TileShape::M); - - for (uint32_t kPartIdx = 0; kPartIdx < kPartLoop; ++kPartIdx) { - uint32_t kPartActual = - (kPartIdx < kPartLoop - 1) ? L0TileShape::K : (params.kActual - kPartIdx * L0TileShape::K); - - auto &l0ATile = l0ATensorList[l0AListId]; - auto layoutAInL0 = LayoutAInL0::template MakeLayout(mPartActual, kPartActual); - auto l1AOffset = MakeCoord(mPartIdx, kPartIdx) * L0TileShape::ToCoordMK(); - auto l1ATile = l1ATensor[L1A_LAYOUT.GetOffset(l1AOffset)]; - - AscendC::WaitFlag(l0AEventList[l0AListId]); - if ((mPartIdx == 0) && (kPartIdx == 0)) { - AscendC::WaitFlag(l1AEventList[params.l1ListId]); - } - copyL1ToL0A(l0ATile, l1ATile, layoutAInL0, L1A_LAYOUT); - if ((mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1)) { - AscendC::SetFlag(l1AEventList[params.l1ListId]); - } - - for (uint32_t nPartIdx = 0; nPartIdx < nPartLoop; ++nPartIdx) { - uint32_t nPartActual = - (nPartIdx < nPartLoop - 1) ? L0TileShape::N : (params.nRound - nPartIdx * L0TileShape::N); - - auto &l0BTile = l0BTensorList[l0BListId]; - auto layoutBInL0 = LayoutBInL0::template MakeLayout(kPartActual, nPartActual); - auto l1BOffset = MakeCoord(kPartIdx, nPartIdx) * L0TileShape::ToCoordKN(); - auto l1BTile = l1BTensor[L1B_LAYOUT.GetOffset(l1BOffset)]; - - AscendC::WaitFlag(l0BEventList[l0BListId]); - if ((kPartIdx == 0) && (nPartIdx == 0)) { - AscendC::WaitFlag(l1BEventList[params.l1ListId]); - } - copyL1ToL0B(l0BTile, l1BTile, layoutBInL0, L1B_LAYOUT); - if ((kPartIdx == kPartLoop - 1) && (nPartIdx == nPartLoop - 1)) { - AscendC::SetFlag(l1BEventList[params.l1ListId]); - } - - AscendC::SetFlag(EVENT_ID0); - - auto l0COffset = MakeCoord(mPartIdx, nPartIdx) * L0TileShape::ToCoordMN(); - auto l0CTile = l0CTensor[layoutCInL0.GetOffset(l0COffset)]; - - AscendC::WaitFlag(EVENT_ID0); - // If the current tile is the first tile on the k axis, the - // accumulator needs to be reset to 0 - bool initC = (params.isKLoopFirst && (kPartIdx == 0)); - // If the unit flag is enabled, the unit flag is set according to the - // calculation progress - uint8_t unitFlag = 0b00; - if constexpr (ENABLE_UNIT_FLAG) { - if (params.isKLoopLast && (mPartIdx == mPartLoop - 1) && (kPartIdx == kPartLoop - 1) && - (nPartIdx == nPartLoop - 1)) { - unitFlag = 0b11; - } else { - unitFlag = 0b10; - } - } - tileMmad(l0CTile, l0ATile, l0BTile, mPartActual, nPartActual, kPartActual, initC, unitFlag); - - AscendC::SetFlag(l0BEventList[l0BListId]); - l0BListId = (l0BListId + 1 < L0B_STAGES) ? (l0BListId + 1) : 0; - } - AscendC::SetFlag(l0AEventList[l0AListId]); - l0AListId = (l0AListId + 1 < L0A_STAGES) ? (l0AListId + 1) : 0; - } - } - - if (params.isKLoopLast) { - auto layoutCInGm = params.layoutCInGm; - - params.callbackBeforeFixpipe(); - - if constexpr (!ENABLE_UNIT_FLAG) { - AscendC::SetFlag(l0CEventList[l0CListId]); - AscendC::WaitFlag(l0CEventList[l0CListId]); - copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0); - AscendC::SetFlag(l0CEventList[l0CListId]); - } else { - copyL0CToGm(params.gmBlockC, l0CTensor, layoutCInGm, layoutCInL0, 0b11); - } - l0CListId = (l0CListId + 1 < L0C_STAGES) ? 
(l0CListId + 1) : 0; - - params.callbackAfterFixpipe(); - } - } - - AscendC::LocalTensor l1ATensorList[L1_STAGES]; - AscendC::LocalTensor l1BTensorList[L1_STAGES]; - int32_t l1AEventList[L1_STAGES]; - int32_t l1BEventList[L1_STAGES]; - uint32_t l1ListId{0}; - - AscendC::LocalTensor l0ATensorList[L0A_STAGES]; - int32_t l0AEventList[L0A_STAGES]; - uint32_t l0AListId{0}; - - AscendC::LocalTensor l0BTensorList[L0B_STAGES]; - int32_t l0BEventList[L0B_STAGES]; - uint32_t l0BListId{0}; - - AscendC::LocalTensor l0CTensorList[L0C_STAGES_]; - int32_t l0CEventList[L0C_STAGES_]; - uint32_t l0CListId{0}; - - L1TileMmadParams l1TileMmadParamsList[PRELOAD_STAGES]; - uint32_t l1TileMmadParamsId{0}; - uint32_t preloadCount{0}; - - TileMmad tileMmad; - CopyGmToL1A copyGmToL1A; - CopyGmToL1B copyGmToL1B; - CopyL1ToL0A copyL1ToL0A; - CopyL1ToL0B copyL1ToL0B; - CopyL0CToGm copyL0CToGm; -}; - -} // namespace Act::Gemm::Block - -#endif // ACT_GEMM_BLOCK_BLOCK_MMAD_PRELOAD_ASYNC_WITH_CALLBACK_HPP diff --git a/act/gemm/block/block_swizzle.hpp b/act/gemm/block/block_swizzle.hpp deleted file mode 100644 index 36662d2a..00000000 --- a/act/gemm/block/block_swizzle.hpp +++ /dev/null @@ -1,243 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
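The members deleted just above implement a small FIFO of preload slots: operator() queues up to PRELOAD_STAGES tile descriptors before issuing the first mmad, each later load retires the oldest slot, and SynchronizeBlock drains whatever remains. A host-side model of that bookkeeping, a sketch assuming PRELOAD_STAGES = 2 and five k-tiles (slot indices are illustrative):

    #include <cstdint>
    #include <cstdio>

    // Mirrors the kernel's index arithmetic: the wrap-around is written with a
    // compare instead of %, exactly as in the deleted source.
    constexpr uint32_t PRELOAD_STAGES = 2;

    int main() {
        uint32_t paramsId = 0;      // index of the oldest queued tile
        uint32_t preloadCount = 0;  // tiles loaded but not yet computed
        const uint32_t kTiles = 5;

        for (uint32_t k = 0; k < kTiles; ++k) {
            if (preloadCount == PRELOAD_STAGES) {
                std::printf("compute tile in slot %u\n", paramsId);
            }
            // slot that the tile loaded this iteration will occupy
            uint32_t slot = (paramsId + preloadCount < PRELOAD_STAGES)
                                ? paramsId + preloadCount
                                : paramsId + preloadCount - PRELOAD_STAGES;
            std::printf("load tile %u into slot %u\n", k, slot);
            if (preloadCount < PRELOAD_STAGES) {
                ++preloadCount;
            } else {
                paramsId = (paramsId + 1 < PRELOAD_STAGES) ? paramsId + 1 : 0;
            }
        }
        while (preloadCount > 0) {  // SynchronizeBlock equivalent: drain the FIFO
            std::printf("drain: compute tile in slot %u\n", paramsId);
            paramsId = (paramsId + 1 < PRELOAD_STAGES) ? paramsId + 1 : 0;
            --preloadCount;
        }
        return 0;
    }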
- */ - -#ifndef ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP -#define ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP - -#include "../../../act/act.hpp" -#include "../../../act/detail/alignment.hpp" -#include "../../../act/gemm_coord.hpp" -#include "../../../act/matrix_coord.hpp" - -namespace Act::Gemm::Block { - -///////////////////////////////////////////////////////////////////////////////////////////////// - -/// Block swizzling function for Gemms -template -struct GemmIdentityBlockSwizzle { - /// Data members - - GemmCoord problemShape; - MatrixCoord tileMN; - MatrixCoord loopsMN; - - /// Methods - - ACT_DEVICE - GemmIdentityBlockSwizzle() {} - - ACT_DEVICE - GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) - : problemShape(problemShape_), tileMN(tileMN_) - { - loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); - } - - ACT_DEVICE - GemmIdentityBlockSwizzle(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) - : problemShape(problemShape_), tileMN(tileMN_), loopsMN(loopsMN_) - {} - - ACT_DEVICE - void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_) - { - problemShape = problemShape_; - tileMN = tileMN_; - - loopsMN = CeilDiv(MatrixCoord(problemShape.GetCoordMN()), tileMN); - } - - ACT_DEVICE - void Update(GemmCoord const &problemShape_, MatrixCoord const &tileMN_, MatrixCoord const &loopsMN_) - { - problemShape = problemShape_; - tileMN = tileMN_; - loopsMN = loopsMN_; - } - - ACT_DEVICE - uint32_t GetCoreLoops() const - { - return loopsMN.row() * loopsMN.column(); - } - - ACT_DEVICE - uint32_t GetBatchIdx(uint32_t taskIdx) - { - return taskIdx / (GetCoreLoops()); - } - - ACT_DEVICE - GemmCoord GetBlockCoord(uint32_t taskIdx) - { - uint32_t innerIdx = taskIdx % GetCoreLoops(); - if constexpr (SwizzleDirection == 0) { // Zn - uint32_t tileBlockLoop = CeilDiv(loopsMN.row(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.column()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.column()); - - uint32_t nRow = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nRow = loopsMN.row() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; - uint32_t nIdx = inTileBlockIdx / nRow; - if (tileBlockIdx % 2 == 1) { - nIdx = loopsMN.column() - nIdx - 1; - } - return GemmCoord{mIdx, nIdx, 0}; - } else if constexpr (SwizzleDirection == 1) { // Nz - uint32_t tileBlockLoop = CeilDiv(loopsMN.column(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMN.row()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMN.row()); - - uint32_t nCol = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nCol = loopsMN.column() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = inTileBlockIdx / nCol; - uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; - if (tileBlockIdx % 2 == 1) { - mIdx = loopsMN.row() - mIdx - 1; - } - return GemmCoord{mIdx, nIdx, 0}; - } - } - - ACT_DEVICE - GemmCoord GetActualBlockShape(GemmCoord blockCoord) - { - uint32_t mActual = - (blockCoord.m() == (loopsMN.row() - 1)) ? (problemShape.m() - blockCoord.m() * tileMN.row()) : tileMN.row(); - uint32_t nActual = (blockCoord.n() == (loopsMN.column() - 1)) - ? 
(problemShape.n() - blockCoord.n() * tileMN.column()) - : tileMN.column(); - uint32_t kActual = problemShape.k(); - return GemmCoord{mActual, nActual, kActual}; - } -}; - -/// Block swizzling function for Splitk Gemms -template -struct SplitkGemmIdentityBlockSwizzle { - /// Data members - - GemmCoord problemShape; - GemmCoord tileShape; - GemmCoord loopsMNK; - uint32_t splitkFactor = 1; // split k dim into virtual cores - - /// Methods - - ACT_DEVICE - SplitkGemmIdentityBlockSwizzle() {} - - ACT_DEVICE - SplitkGemmIdentityBlockSwizzle(GemmCoord const &problemShape_, GemmCoord const &tileShape_, - uint32_t splitkFactor_ = 1) - : problemShape(problemShape_), tileShape(tileShape_), splitkFactor(splitkFactor_) - { - loopsMNK = CeilDiv(problemShape, tileShape); - } - - ACT_DEVICE - uint32_t GetKIdxBySplitkSliceIdx(uint32_t splitkSliceIdx) const - { - if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { - return (loopsMNK.k() / splitkFactor + 1) * splitkSliceIdx; - } else { - return splitkSliceIdx * (loopsMNK.k() / splitkFactor) + loopsMNK.k() % splitkFactor; - } - } - - ACT_DEVICE - uint32_t GetSplitkSliceIdx(uint32_t taskIdx) const - { - uint32_t mnLoops = loopsMNK.m() * loopsMNK.n(); - return taskIdx % GetCoreLoops() / mnLoops; - } - - ACT_DEVICE - uint32_t GetCoreLoops() const - { - return loopsMNK.m() * loopsMNK.n() * splitkFactor; - } - - ACT_DEVICE - uint32_t GetBatchIdx(uint32_t taskIdx) - { - return taskIdx / GetCoreLoops(); - } - - ACT_DEVICE - GemmCoord GetBlockCoord(uint32_t taskIdx) - { - uint32_t splitkSliceIdx = GetSplitkSliceIdx(taskIdx); - uint32_t kIdx = GetKIdxBySplitkSliceIdx(splitkSliceIdx); - - uint32_t innerIdx = taskIdx % (loopsMNK.m() * loopsMNK.n()); - if constexpr (SwizzleDirection == 0) { // Zn - uint32_t tileBlockLoop = CeilDiv(loopsMNK.m(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.n()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.n()); - - uint32_t nRow = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nRow = loopsMNK.m() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nRow; - uint32_t nIdx = inTileBlockIdx / nRow; - if (tileBlockIdx % 2 == 1) { - nIdx = loopsMNK.n() - nIdx - 1; - } - return GemmCoord{mIdx, nIdx, kIdx}; - } else if constexpr (SwizzleDirection == 1) { // Nz - uint32_t tileBlockLoop = CeilDiv(loopsMNK.n(), SwizzleOffset); - uint32_t tileBlockIdx = innerIdx / (SwizzleOffset * loopsMNK.m()); - uint32_t inTileBlockIdx = innerIdx % (SwizzleOffset * loopsMNK.m()); - - uint32_t nCol = SwizzleOffset; - if (tileBlockIdx == tileBlockLoop - 1) { - nCol = loopsMNK.n() - SwizzleOffset * tileBlockIdx; - } - uint32_t mIdx = inTileBlockIdx / nCol; - uint32_t nIdx = tileBlockIdx * SwizzleOffset + inTileBlockIdx % nCol; - if (tileBlockIdx % 2 == 1) { - mIdx = loopsMNK.m() - mIdx - 1; - } - return GemmCoord{mIdx, nIdx, kIdx}; - } - } - - ACT_DEVICE - GemmCoord GetActualBlockShape(GemmCoord blockCoord, uint32_t splitkSliceIdx) - { - uint32_t splitkSliceLen; - if (splitkSliceIdx < loopsMNK.k() % splitkFactor) { - splitkSliceLen = (loopsMNK.k() / splitkFactor + 1) * tileShape.k(); - } else { - splitkSliceLen = (loopsMNK.k() / splitkFactor) * tileShape.k(); - } - uint32_t mActual = (blockCoord.m() == (loopsMNK.m() - 1)) ? (problemShape.m() - blockCoord.m() * tileShape.m()) - : tileShape.m(); - uint32_t nActual = (blockCoord.n() == (loopsMNK.n() - 1)) ? 
(problemShape.n() - blockCoord.n() * tileShape.n()) - : tileShape.n(); - uint32_t kActual = (splitkSliceIdx == (splitkFactor - 1)) ? (problemShape.k() - blockCoord.k() * tileShape.k()) - : splitkSliceLen; - return GemmCoord{mActual, nActual, kActual}; - } -}; - -} // namespace Act::Gemm::Block - -#endif // ACT_GEMM_BLOCK_BLOCK_SWIZZLE_HPP diff --git a/act/gemm/dispatch_policy.hpp b/act/gemm/dispatch_policy.hpp deleted file mode 100644 index 4ec7433f..00000000 --- a/act/gemm/dispatch_policy.hpp +++ /dev/null @@ -1,88 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_DISPATCH_POLICY_HPP -#define ACT_GEMM_DISPATCH_POLICY_HPP - -#include "../../act/act.hpp" - -namespace Act::Gemm { - -// Block Mmad Policies - -template -struct MmadAtlasA2Base { - using ArchTag = Arch::AtlasA2; - static constexpr uint32_t ASYNC = ASYNC_; -}; - -using MmadAtlasA2 = MmadAtlasA2Base; -using MmadAtlasA2Async = MmadAtlasA2Base; - -// Now ENABLE_UNIT_FLAG_ must be false when input element is int8 -template -struct MmadAtlasA2Pingpong : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; -}; - -template -struct MmadAtlasA2Preload : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; - static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; -}; - -struct MmadAtlasA2FAQK : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -struct MmadAtlasA2FAPV : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -struct MmadAtlasA2MLAQK : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -struct MmadAtlasA2MLAPV : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -struct MmadAtlasA2MLAQKTp1Spec : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -struct MmadAtlasA2MLAPVTp1Spec : public MmadAtlasA2 { - static constexpr uint32_t STAGES = 2; -}; - -template -struct MmadAtlasA2PreloadAsync : public MmadAtlasA2Async { - static constexpr uint32_t PRELOAD_STAGES = PRELOAD_STAGES_; // Stages of emitting load instruction in advance - static constexpr uint32_t L1_STAGES = L1_STAGES_; - static constexpr uint32_t L0A_STAGES = L0A_STAGES_; - static constexpr uint32_t L0B_STAGES = L0B_STAGES_; - static constexpr uint32_t L0C_STAGES = L0C_STAGES_; - static constexpr bool ENABLE_UNIT_FLAG = ENABLE_UNIT_FLAG_; - static constexpr bool ENABLE_SHUFFLE_K = ENABLE_SHUFFLE_K_; -}; - -template -struct MmadAtlasA2PreloadAsyncWithCallback - : public MmadAtlasA2PreloadAsync {}; -} // namespace Act::Gemm - -#endif // ACT_GEMM_DISPATCH_POLICY_HPP diff --git a/act/gemm/gemm_type.hpp b/act/gemm/gemm_type.hpp deleted file mode 100644 index 145c3964..00000000 --- a/act/gemm/gemm_type.hpp +++ /dev/null @@ -1,29 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. 
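GemmIdentityBlockSwizzle, deleted a few hunks above, maps a linear task index to a tile coordinate by folding indices into bands of SwizzleOffset rows (the Zn direction) and reversing the column walk on every odd band, so that consecutively scheduled blocks land on neighbouring tiles, the usual locality motivation for swizzling. The mapping can be checked on the host; a sketch with illustrative sizes (SwizzleOffset = 2 over a 5 x 3 tile grid):

    #include <cstdint>
    #include <cstdio>

    // Direct transcription of the Zn branch of GetBlockCoord.
    constexpr uint32_t SWIZZLE_OFFSET = 2;
    constexpr uint32_t loopsM = 5, loopsN = 3;

    int main() {
        uint32_t tileBlockLoop = (loopsM + SWIZZLE_OFFSET - 1) / SWIZZLE_OFFSET;
        for (uint32_t idx = 0; idx < loopsM * loopsN; ++idx) {
            uint32_t tileBlockIdx = idx / (SWIZZLE_OFFSET * loopsN);
            uint32_t inTileBlockIdx = idx % (SWIZZLE_OFFSET * loopsN);
            // the last band may be shorter than SwizzleOffset rows
            uint32_t nRow = (tileBlockIdx == tileBlockLoop - 1)
                                ? loopsM - SWIZZLE_OFFSET * tileBlockIdx
                                : SWIZZLE_OFFSET;
            uint32_t mIdx = tileBlockIdx * SWIZZLE_OFFSET + inTileBlockIdx % nRow;
            uint32_t nIdx = inTileBlockIdx / nRow;
            if (tileBlockIdx % 2 == 1) {  // odd bands walk n backwards
                nIdx = loopsN - nIdx - 1;
            }
            std::printf("task %2u -> tile (%u, %u)\n", idx, mIdx, nIdx);
        }
        return 0;
    }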
- * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_GEMM_TYPE_HPP -#define ACT_GEMM_GEMM_TYPE_HPP - -namespace Act::Gemm { - -//////////////////////////////////////////////////////////////////// - -template -struct GemmType { - using Element = Element_; - using Layout = Layout_; - static constexpr AscendC::TPosition POSITION = POSITION_; -}; - -} // namespace Act::Gemm - -#endif // ACT_GEMM_GEMM_TYPE_HPP diff --git a/act/gemm/helper.hpp b/act/gemm/helper.hpp deleted file mode 100644 index bb634f9b..00000000 --- a/act/gemm/helper.hpp +++ /dev/null @@ -1,280 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_HELPER_HPP -#define ACT_GEMM_HELPER_HPP - -#include "../../act/act.hpp" -#include "../../act/layout/layout.hpp" -#include "../../tla/layout.hpp" - -namespace Act::Gemm::helper { - -template -struct L1AlignHelper { - static_assert(DEPENDENT_FALSE, "Unsupporteded align helper, can not find the specialization."); -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; -}; - -template -struct L1AlignHelper { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / 
sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; -}; - -template -struct ElementAccumulatorSelector { - static_assert(DEPENDENT_FALSE, - "Unsupporteded element accumulator selector, can not find the " - "specialization."); -}; - -template <> -struct ElementAccumulatorSelector { - using ElementAccumulator = float; -}; - -template <> -struct ElementAccumulatorSelector { - using ElementAccumulator = float; -}; - -template <> -struct ElementAccumulatorSelector { - using ElementAccumulator = int32_t; -}; - -template <> -struct ElementAccumulatorSelector { - using ElementAccumulator = float; -}; - -template -struct L1ATypeSelector { - static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); -}; - -template -struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1ATypeSelector> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector { - static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelector> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1AlignHelperTla { - static_assert(DEPENDENT_FALSE, "Unsupporteded align helper tla, can not find the specialization."); -}; - -template -struct L1AlignHelperTla::value>> { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = C0_NUM_PER_FRACTAL; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = ELE_NUM_PER_C0; -}; - -template -struct L1AlignHelperTla::value>> { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t M_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t K_ALIGNED = ELE_NUM_PER_C0; - static constexpr uint32_t N_ALIGNED = C0_NUM_PER_FRACTAL; -}; - -/////////////////////////////////////// -// new add -template -struct L1ATypeSelectorGemm { - static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); -}; - -template -struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; -}; - -template <> -struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1ATypeSelectorGemm> { - using L1AType = Gemm::GemmType; -}; - -template -struct L1BTypeSelectorGemm { - static_assert(DEPENDENT_FALSE, "Unsupporteded layout selector, can not find the specialization."); -}; - -template -struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; -}; - -template <> -struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; -}; - -template -struct L1BTypeSelectorGemm> { - using L1BType = Gemm::GemmType; -}; - -template -struct L0ATypeSelector {}; - -template -struct L0ATypeSelector> { - using L0AType = Gemm::GemmType; -}; - -template -struct L0ATypeSelector> { - using L0AType = 
Gemm::GemmType; -}; - -template <> -struct L0ATypeSelector> { - using L0AType = Gemm::GemmType; -}; - -template -struct L0BTypeSelectorGemm {}; - -template -struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; -}; - -template <> -struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; -}; - -template -struct L0BTypeSelectorGemm> { - using L0BType = Gemm::GemmType; -}; - -template -struct L0BTypeSelectorGemv {}; - -template -struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; -}; - -template -struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; -}; - -template <> -struct L0BTypeSelectorGemv> { - using L0BType = Gemm::GemmType; -}; -} // namespace Act::Gemm::helper - -#endif // ACT_GEMM_HELPER_HPP diff --git a/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp b/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp deleted file mode 100644 index 4a59ac9b..00000000 --- a/act/gemm/kernel/grouped_matmul_slice_m_per_token_dequant_multistage_workspace.hpp +++ /dev/null @@ -1,362 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP -#define ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP - -#include "../../../../cam_moe_distribute_combine/op_kernel/a3/cam_moe_distribute_combine.h" -#include "../../../act/act.hpp" -#include "../../../act/arch/cross_core_sync.hpp" -#include "../../../act/arch/resource.hpp" -#include "../../../act/coord.hpp" -#include "../../../act/detail/callback.hpp" -#include "../../../act/gemm_coord.hpp" -#include "../../../act/matrix_coord.hpp" - -namespace Act::Gemm::Kernel { - -template -class GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace -{ -public: - using BlockMmad = BlockMmad_; - using ArchTag = typename BlockMmad::ArchTag; - using L1TileShape = typename BlockMmad::L1TileShape; - using ElementA = typename BlockMmad::ElementA; - using LayoutA = typename BlockMmad::LayoutA; - using ElementB = typename BlockMmad::ElementB; - using LayoutB = typename BlockMmad::LayoutB; - using ElementC = typename BlockMmad::ElementC; - using LayoutC = typename BlockMmad::LayoutC; - using ElementAccumulator = typename BlockMmad::ElementAccumulator; - - using BlockEpilogue = BlockEpilogue_; - using ElementScale = typename BlockEpilogue::ElementScale; - using LayoutScale = typename BlockEpilogue::LayoutScale; - using ElementPerTokenScale = typename BlockEpilogue::ElementPerTokenScale; - using LayoutPerTokenScale = typename BlockEpilogue::LayoutPerTokenScale; - using ElementD = typename BlockEpilogue::ElementD; - using LayoutD = typename BlockEpilogue::LayoutD; - using EpilogueParams = typename BlockEpilogue::Params; - - using BlockScheduler = BlockScheduler_; - static constexpr uint32_t WORKSPACE_STAGES = WORKSPACE_STAGES_; - using ElementGroupList = ElementGroupList_; - - /// Parameters structure - struct 
Params { - // Data members - GemmCoord problemShape; - uint32_t problemCount; - __gm__ ElementGroupList_ *ptrGroupList; - __gm__ ElementA *ptrA; - LayoutA layoutA; - __gm__ ElementB *ptrB; - LayoutB layoutB; - __gm__ ElementScale *ptrScale; - LayoutScale layoutScale; - __gm__ ElementPerTokenScale *ptrPerTokenScale; - LayoutPerTokenScale layoutPerTokenScale; - __gm__ ElementD *ptrD; - LayoutD layoutD; - GM_ADDR ptrWorkspace; - void *combiner; - - // Methods - ACT_DEVICE - Params() {} - - ACT_DEVICE - Params(GemmCoord problemShape_, uint32_t problemCount_, GM_ADDR ptrGroupList_, GM_ADDR ptrA_, LayoutA layoutA_, - GM_ADDR ptrB_, LayoutB layoutB_, GM_ADDR ptrScale_, LayoutScale layoutScale_, GM_ADDR ptrPerTokenScale_, - LayoutPerTokenScale layoutPerTokenScale_, GM_ADDR ptrD_, LayoutD layoutD_, GM_ADDR ptrWorkspace_, - void *combiner_) - : problemShape(problemShape_), - problemCount(problemCount_), - ptrGroupList(reinterpret_cast<__gm__ ElementGroupList *>(ptrGroupList_)), - ptrA(reinterpret_cast<__gm__ ElementA *>(ptrA_)), - layoutA(layoutA_), - ptrB(reinterpret_cast<__gm__ ElementB *>(ptrB_)), - layoutB(layoutB_), - ptrScale(reinterpret_cast<__gm__ ElementScale *>(ptrScale_)), - layoutScale(layoutScale_), - ptrPerTokenScale(reinterpret_cast<__gm__ ElementPerTokenScale *>(ptrPerTokenScale_)), - layoutPerTokenScale(layoutPerTokenScale_), - ptrD(reinterpret_cast<__gm__ ElementD *>(ptrD_)), - layoutD(layoutD_), - ptrWorkspace(ptrWorkspace_), - combiner(combiner_) - {} - }; - - // Methods - ACT_DEVICE - GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace() - { - Arch::FlagID flagId = 0; - for (uint32_t stageId = 0; stageId < WORKSPACE_STAGES; ++stageId) { - flagAicFinishStoreList[stageId] = Arch::CrossCoreFlag(flagId++); - flagAivFinishComputeList[stageId] = Arch::CrossCoreFlag(flagId++); - aicWaitFuncList[stageId] = {this, stageId}; - aicSetFuncList[stageId] = {this, stageId}; - } - } - - template - ACT_DEVICE void operator()(Params const ¶ms); - - template <> - ACT_DEVICE void operator()(Params const ¶ms) - { - BlockScheduler blockScheduler; - BlockMmad blockMmad(resource); - - // Represent the full gm - AscendC::GlobalTensor gmA; - gmA.SetGlobalBuffer(params.ptrA); - AscendC::GlobalTensor gmB; - gmB.SetGlobalBuffer(params.ptrB); - AscendC::GlobalTensor groupList; - groupList.SetGlobalBuffer(params.ptrGroupList); - - uint32_t coreIdx = AscendC::GetBlockIdx(); - uint32_t coreNum = AscendC::GetBlockNum(); - int64_t gmGroupOffsetA = 0; - int64_t gmGroupOffsetB = 0; - - AscendC::GlobalTensor gmC; - gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); - auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; - - uint32_t stageId = 0; - uint32_t stageUsed = 0; - uint32_t startCoreIdx = 0; - for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { - uint32_t currentM = (groupIdx == 0) ? groupList.GetValue(groupIdx) - : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); - GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; - - LayoutA layoutA = params.layoutA.GetTileLayout(inGroupProblemShape.GetCoordMK()); - LayoutB layoutB = params.layoutB; - - blockScheduler.Update(inGroupProblemShape, MakeCoord(L1TileShape::M, L1TileShape::N)); - uint32_t coreLoops = blockScheduler.GetCoreLoops(); - - // Determine the starting loopIdx of the current core under the current - // groupIdx - uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? 
(coreIdx + coreNum) : coreIdx) - startCoreIdx; - // Loop through the matmul of each groupIdx - for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { - // Compute block location - GemmCoord blockCoord = blockScheduler.GetBlockCoord(loopIdx); - GemmCoord actualBlockShape = blockScheduler.GetActualBlockShape(blockCoord); - - Callback callbackBeforeFixpipe{}; - if (stageUsed == WORKSPACE_STAGES) { - callbackBeforeFixpipe = MakeCallback(&aicWaitFuncList[stageId]); - } else { - ++stageUsed; - } - Callback callbackAfterFixpipe = MakeCallback(&aicSetFuncList[stageId]); - - // Compute initial location in logical coordinates - MatrixCoord offsetA{blockCoord.m() * L1TileShape::M, blockCoord.k() * L1TileShape::K}; - MatrixCoord offsetB{blockCoord.k() * L1TileShape::K, blockCoord.n() * L1TileShape::N}; - MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; - int64_t gmOffsetA = layoutA.GetOffset(offsetA); - int64_t gmOffsetB = layoutB.GetOffset(offsetB); - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - - // Compute block-scoped matrix multiply-add - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { - blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, - gmC[gmOffsetC], layoutC, actualBlockShape, callbackBeforeFixpipe, callbackAfterFixpipe); - } else { - callbackBeforeFixpipe(); - blockMmad(gmA[gmGroupOffsetA + gmOffsetA], layoutA, gmB[gmGroupOffsetB + gmOffsetB], layoutB, - gmC[gmOffsetC], layoutC, actualBlockShape); - callbackAfterFixpipe(); - } - - stageId = (stageId + 1 < WORKSPACE_STAGES) ? (stageId + 1) : 0; - } - - gmGroupOffsetA += inGroupProblemShape.m() * inGroupProblemShape.k(); - gmGroupOffsetB += inGroupProblemShape.k() * inGroupProblemShape.n(); - - startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } - - if constexpr (BlockMmad::DispatchPolicy::ASYNC) { - blockMmad.SynchronizeBlock(); - } - - while (stageUsed > 0) { - uint32_t aivComputeStageId = - (stageId >= stageUsed) ? (stageId - stageUsed) : (stageId + WORKSPACE_STAGES - stageUsed); - Arch::CrossCoreWaitFlag(flagAivFinishComputeList[aivComputeStageId]); - --stageUsed; - } - } - - template <> - ACT_DEVICE void operator()(Params const ¶ms) - { - auto *combiner = (MoeDistributeCombineImpl::CamMoeDistributeCombine *)params.combiner; - { - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - if (get_subblockid() == 0) { - AscendC::CrossCoreSetFlag<0x0, PIPE_MTE3>(MoeDistributeCombineImpl::RECV_SYNC_EVENT_ID); - } - } - BlockScheduler blockScheduler; - BlockEpilogue blockEpilogue(resource, combiner->GetCalcInfo()); - - uint32_t coreIdx = AscendC::GetBlockIdx() / AscendC::GetSubBlockNum(); - uint32_t coreNum = AscendC::GetBlockNum(); - int64_t gmGroupOffsetScale = 0; - int64_t gmGroupOffsetPerTokenScale = 0; - int64_t gmGroupOffsetD = 0; - - AscendC::GlobalTensor groupList; - groupList.SetGlobalBuffer(params.ptrGroupList); - - AscendC::GlobalTensor gmC; - gmC.SetGlobalBuffer(reinterpret_cast<__gm__ ElementC *>(params.ptrWorkspace)); - auto layoutC = layout::RowMajor{L1TileShape::M * coreNum * WORKSPACE_STAGES, L1TileShape::N}; - - uint32_t stageId = 0; - uint32_t startCoreIdx = 0; - for (uint32_t groupIdx = 0; groupIdx < params.problemCount; ++groupIdx) { - uint32_t currentM = (groupIdx == 0) ? 
groupList.GetValue(groupIdx) - : (groupList.GetValue(groupIdx) - groupList.GetValue(groupIdx - 1)); - GemmCoord inGroupProblemShape{currentM, params.problemShape.n(), params.problemShape.k()}; - - LayoutScale layoutScale = params.layoutScale; - LayoutPerTokenScale layoutPerTokenScale = - params.layoutPerTokenScale.GetTileLayout(inGroupProblemShape.template GetCoordByAxis<0>()); - LayoutD layoutD = params.layoutD.GetTileLayout(inGroupProblemShape.GetCoordMN()); - - EpilogueParams epilogueParams{params.ptrScale + gmGroupOffsetScale, - layoutScale, - params.ptrPerTokenScale + gmGroupOffsetPerTokenScale, - layoutPerTokenScale, - params.ptrD + gmGroupOffsetD, - layoutD}; - - blockScheduler.Update(inGroupProblemShape, L1TileShape::ToCoordMN()); - blockEpilogue.UpdateParams(epilogueParams); - uint32_t coreLoops = blockScheduler.GetCoreLoops(); - - GemmCoord blockShapeMNK = L1TileShape::ToCoord(); - uint32_t startLoopIdx = ((coreIdx < startCoreIdx) ? (coreIdx + coreNum) : coreIdx) - startCoreIdx; - for (uint32_t loopIdx = startLoopIdx; loopIdx < coreLoops; loopIdx += coreNum) { - GemmCoord blockCoordMNK = blockScheduler.GetBlockCoord(loopIdx); - GemmCoord actualBlockShapeMNK = blockScheduler.GetActualBlockShape(blockCoordMNK); - - MatrixCoord offsetC{(stageId * coreNum + coreIdx) * L1TileShape::M, 0}; - int64_t gmOffsetC = layoutC.GetOffset(offsetC); - auto gmBlockC = gmC[gmOffsetC]; - auto layoutBlockC = layoutC.GetTileLayout(actualBlockShapeMNK.GetCoordMN()); - - Arch::CrossCoreWaitFlag(flagAicFinishStoreList[stageId]); - blockEpilogue(gmGroupOffsetD, groupIdx, blockShapeMNK, blockCoordMNK, actualBlockShapeMNK, gmBlockC, - layoutBlockC); - Arch::CrossCoreSetFlag<0x2, PIPE_MTE3>(flagAivFinishComputeList[stageId]); - - stageId = (stageId + 1 < WORKSPACE_STAGES) ? 
(stageId + 1) : 0; - } - - gmGroupOffsetScale += inGroupProblemShape.n(); - gmGroupOffsetPerTokenScale += inGroupProblemShape.m(); - gmGroupOffsetD += inGroupProblemShape.m() * inGroupProblemShape.n(); - - startCoreIdx = (startCoreIdx + coreLoops) % coreNum; - } - } - - icache_preload(4); - if constexpr (EXEC_FLAG & EXEC_FLAG_DEEP_FUSE) { - if (get_subblockid() == 0) { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->AllToAllSend(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); - } else { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->ReducePermute(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); - } - } else { - resource.pipe.Init(); - combiner->TPipeSet(&resource.pipe); - combiner->Process(); - combiner->TPipeSet(nullptr); - resource.pipe.Destroy(); - } - } - -private: - friend struct AicWaitFunc; - friend struct AicSetFunc; - - struct AicWaitFunc { - using MatmulKernel = - GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; - - ACT_DEVICE - AicWaitFunc() = default; - - ACT_DEVICE - void operator()() const - { - Arch::CrossCoreWaitFlag(ptr->flagAivFinishComputeList[stageId]); - } - - MatmulKernel *ptr{nullptr}; - uint32_t stageId; - }; - - struct AicSetFunc { - using MatmulKernel = - GroupedMatmulSliceMPerTokenDequantMultiStageWorkspace; - - ACT_DEVICE - AicSetFunc() = default; - - ACT_DEVICE - void operator()() const - { - Arch::CrossCoreSetFlag<0x2, PIPE_FIX>(ptr->flagAicFinishStoreList[stageId]); - } - - MatmulKernel *ptr{nullptr}; - uint32_t stageId; - }; - - Arch::CrossCoreFlag flagAicFinishStoreList[WORKSPACE_STAGES]; - Arch::CrossCoreFlag flagAivFinishComputeList[WORKSPACE_STAGES]; - - AicWaitFunc aicWaitFuncList[WORKSPACE_STAGES]; - AicSetFunc aicSetFuncList[WORKSPACE_STAGES]; - Arch::Resource resource; -}; - -} // namespace Act::Gemm::Kernel - -#endif // ACT_GEMM_KERNEL_GROUPED_MATMUL_M_PER_TOKEN_DEQUANT_MULTISTAGE_WORKSPACE_HPP diff --git a/act/gemm/tile/copy_gm_to_l1.hpp b/act/gemm/tile/copy_gm_to_l1.hpp deleted file mode 100644 index 5100d46f..00000000 --- a/act/gemm/tile/copy_gm_to_l1.hpp +++ /dev/null @@ -1,798 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_TILE_COPY_GM_TO_L1_HPP -#define ACT_GEMM_TILE_COPY_GM_TO_L1_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/gemm_type.hpp" -#include "../../../act/layout/layout.hpp" -#include "../../../tla/tensor.hpp" - -using namespace tla; - -namespace Act::Gemm::Tile { - -template -struct CopyGmToL1 { - static_assert(DEPENDENT_FALSE, "Unsupported copy gm to l1, can not find the specialization."); -}; - -/// Partial specialization for AtlasA2, half, RowMajor in and zN out. 
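In the grouped-matmul kernel deleted above, cube cores (AIC) and vector cores (AIV) hand results through WORKSPACE_STAGES ring slots in the global workspace: the AIC raises flagAicFinishStore after each fixpipe, may run at most WORKSPACE_STAGES blocks ahead, and waits on flagAivFinishCompute before reusing a slot. A sequential host-side model of that handshake (flag waits shown as prints; WORKSPACE_STAGES = 2 and six blocks are illustrative, and the flags are assumed to behave as binary semaphores):

    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t WORKSPACE_STAGES = 2;

    int main() {
        uint32_t stageId = 0, stageUsed = 0;
        for (uint32_t block = 0; block < 6; ++block) {
            if (stageUsed == WORKSPACE_STAGES) {
                // all slots in flight: block until the AIV frees this one
                std::printf("AIC waits for AIV on slot %u\n", stageId);
            } else {
                ++stageUsed;
            }
            std::printf("AIC stores block %u into slot %u, flags AIV\n",
                        block, stageId);
            stageId = (stageId + 1 < WORKSPACE_STAGES) ? stageId + 1 : 0;
        }
        while (stageUsed > 0) {  // final drain, as in the kernel epilogue
            uint32_t oldest = (stageId >= stageUsed)
                                  ? stageId - stageUsed
                                  : stageId + WORKSPACE_STAGES - stageUsed;
            std::printf("AIC waits for AIV on slot %u\n", oldest);
            --stageUsed;
        }
        return 0;
    }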
-/// Matrix A confirm -template -struct CopyGmToL1, Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(0) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); - } - } - } -}; - -template -struct CopyGmToL1, Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(0); - uint32_t ndNum = layoutSrc.shape(0) / C0_NUM_PER_FRACTAL; - uint32_t remains = layoutSrc.shape(0) % C0_NUM_PER_FRACTAL; - if (srcNdStride < STRIDE_LIMIT) { - if (ndNum) { - intriParams.ndNum = ndNum; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = srcNdStride; - intriParams.srcDValue = layoutSrc.stride(0); - - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - - intriParams.dstNzMatrixStride = layoutDst.stride(1); - - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } - - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(1); - tailParams.srcNdMatrixStride = srcNdStride; - tailParams.srcDValue = layoutSrc.stride(0); - - tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; //` - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); - } - } else if (layoutSrc.stride(0) < STRIDE_LIMIT) { - for (uint32_t i = 0; i < ndNum; i++) { - AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = layoutSrc.stride(0); - - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[i * layoutDst.stride(1)], srcTensor[i * srcNdStride], intriParams); - } - if (remains) { - AscendC::Nd2NzParams 
tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(1); - tailParams.srcNdMatrixStride = 0; - tailParams.srcDValue = layoutSrc.stride(0); - - tailParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(1)], srcTensor[ndNum * srcNdStride], tailParams); - } - } else { - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; - uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; - - AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = 0; - - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = 0; - intriParams.dstNzMatrixStride = 0; - - uint32_t offsetDst = i * idxR0 * layoutDst.stride(1) + idxInR0 * ELE_NUM_PER_C0; - uint32_t offsetSrc = i * layoutSrc.stride(0); - AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); - } - } - } -}; - -template -struct CopyGmToL1, Gemm::GemmType> { - using LayoutDst = layout::nN; - using LayoutSrc = layout::ColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - uint32_t srcNdStride = C0_NUM_PER_FRACTAL * layoutSrc.stride(1); - uint32_t ndNum = layoutSrc.shape(1) / C0_NUM_PER_FRACTAL; - uint32_t remains = layoutSrc.shape(1) % C0_NUM_PER_FRACTAL; - if (srcNdStride < STRIDE_LIMIT) { - if (ndNum) { - intriParams.ndNum = ndNum; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = srcNdStride; - intriParams.srcDValue = layoutSrc.stride(1); - - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - - intriParams.dstNzMatrixStride = layoutDst.stride(3); - - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } - - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(0); - tailParams.srcNdMatrixStride = srcNdStride; - tailParams.srcDValue = layoutSrc.stride(1); - - tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); - } - } else if (layoutSrc.stride(1) < STRIDE_LIMIT) { - for (uint32_t i = 0; i < ndNum; i++) { - AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = C0_NUM_PER_FRACTAL; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = layoutSrc.stride(1); - - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[i * layoutDst.stride(3)], srcTensor[i * srcNdStride], intriParams); - } - if (remains) { - AscendC::Nd2NzParams tailParams; - tailParams.ndNum = 
1; - tailParams.nValue = remains; - tailParams.dValue = layoutSrc.shape(0); - tailParams.srcNdMatrixStride = 0; - tailParams.srcDValue = layoutSrc.stride(1); - - tailParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - tailParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - tailParams.dstNzMatrixStride = 0; - - AscendC::DataCopy(dstTensor[ndNum * layoutDst.stride(3)], srcTensor[ndNum * srcNdStride], tailParams); - } - } else { - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - uint32_t idxR0 = i / C0_NUM_PER_FRACTAL; - uint32_t idxInR0 = i % C0_NUM_PER_FRACTAL; - - AscendC::Nd2NzParams intriParams; - intriParams.ndNum = 1; - intriParams.nValue = 1; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.srcDValue = 0; - - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzNStride = 0; - intriParams.dstNzMatrixStride = 0; - - uint32_t offsetDst = i * idxR0 * layoutDst.stride(3) + idxInR0 * ELE_NUM_PER_C0; - uint32_t offsetSrc = i * layoutSrc.stride(1); - AscendC::DataCopy(dstTensor[offsetDst], srcTensor[offsetSrc], intriParams); - } - } - } -}; - -template -struct CopyGmToL1, Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::ColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(1) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(1); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); - } - } - } -}; - -/// Partial specialization for AtlasA2, RowMajor in and zN out. 
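Every CopyGmToL1 specialization in this file guards its fast path with STRIDE_LIMIT: the code only passes a source row stride to the copy instruction when it is below STRIDE_LIMIT (65536 in this codebase), moving the whole tile in one call; a wider source degrades to one copy per row (or per fractal group, in the zZ/nN cases above). A sketch of just that dispatch decision:

    #include <cstdint>
    #include <cstdio>

    constexpr uint32_t STRIDE_LIMIT = 65536;

    // copyTile stands in for a CopyGmToL1 operator(): one bulk copy when the
    // stride is encodable, otherwise a per-row fallback loop.
    void copyTile(uint32_t rows, uint32_t rowStride) {
        if (rowStride < STRIDE_LIMIT) {
            std::printf("stride %u: single bulk copy of %u rows\n",
                        rowStride, rows);
        } else {
            for (uint32_t i = 0; i < rows; ++i) {
                std::printf("stride %u: row %u copied individually\n",
                            rowStride, i);
            }
        }
    }

    int main() {
        copyTile(4, 8192);    // fits: one instruction
        copyTile(4, 131072);  // too wide: degenerates to 4 copies
        return 0;
    }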
-template -struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(0) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(0); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(0)], intriParams); - } - } - } - - // layoutSrc must be the layout of one of the src matrices - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc, uint32_t ndNum, uint32_t srcNdMatrixStride, - uint32_t dstNzNStride, uint32_t dstNzMatrixStride, uint32_t dstNzC0Stride) - { - AscendC::Nd2NzParams intriParams; - - intriParams.nValue = layoutSrc.shape(0); - intriParams.dValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = dstNzNStride; - intriParams.dstNzC0Stride = dstNzC0Stride; - if (srcNdMatrixStride < STRIDE_LIMIT) { - intriParams.ndNum = ndNum; - intriParams.srcNdMatrixStride = srcNdMatrixStride; - intriParams.dstNzMatrixStride = dstNzMatrixStride; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.ndNum = 1; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzMatrixStride = 0; - for (uint32_t i = 0; i < ndNum; i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * srcNdMatrixStride], intriParams); - } - } - } -}; - -/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. 
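[Reviewer note] For the ColumnMajor-to-nZ specialization that follows, the Nd2Nz field roles swap relative to the row-major case: "n" walks source columns and "d" walks down one column. A sketch of that mapping under the shapes used below (`Nd2NzPlan` is a hypothetical stand-in, not the AscendC struct):

#include <cstdint>

struct Nd2NzPlan {
    uint32_t nValue;    // number of source columns copied
    uint32_t dValue;    // elements per column
    uint32_t srcDValue; // element pitch between consecutive columns
};

// Mirrors the specialization below: shape(0)=rows feeds dValue,
// shape(1)=cols feeds nValue, stride(1)=column pitch feeds srcDValue.
inline Nd2NzPlan PlanColumnMajor(uint32_t rows, uint32_t cols, uint32_t colPitch)
{
    return Nd2NzPlan{cols, rows, colPitch};
}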
-template -struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::ColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.shape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (layoutSrc.stride(1) < STRIDE_LIMIT) { - intriParams.nValue = layoutSrc.shape(1); - intriParams.srcDValue = layoutSrc.stride(1); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < layoutSrc.shape(1); i++) { - AscendC::DataCopy(dstTensor[i * ELE_NUM_PER_C0], srcTensor[i * layoutSrc.stride(1)], intriParams); - } - } - } -}; - -/// Partial specialization for zN in and zN out. -template -struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - uint32_t blockCount = CeilDiv(layoutSrc.orgShape(1)); - uint32_t blockLen = RoundUp(layoutSrc.orgShape(0)); - - AscendC::DataCopyParams repeatParams; - - if (layoutSrc.stride(3) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { - repeatParams.blockCount = blockCount; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_C0 - blockLen; - repeatParams.dstStride = layoutDst.stride(3) / ELE_NUM_PER_C0 - blockLen; - AscendC::DataCopy(dstTensor, srcTensor, repeatParams); - } else { - repeatParams.blockCount = 1; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = 0; - repeatParams.dstStride = 0; - for (uint32_t i = 0; i < blockCount; i++) { - uint64_t dstOffset = i * layoutDst.stride(3); - uint64_t srcOffset = i * layoutSrc.stride(3); - AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); - } - } - } -}; - -/// Partial specialization for nZ in and nZ out. 
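[Reviewer note] The fractal pass-through copies introduced above (zN to zN, and nZ to nZ below) use gap-based DataCopyParams: lengths and strides are measured in 32-byte C0 blocks, and the stride field is the gap *between* bursts, hence the `pitch - blockLen` subtraction in the code. Host-side mirror of that arithmetic, with our own names:

#include <cstdint>

struct BurstPlan {
    uint16_t blockCount; // number of fractal columns (or rows, for nZ)
    uint16_t blockLen;   // C0 blocks per burst
    uint16_t srcGap;     // C0 blocks skipped between source bursts
    uint16_t dstGap;     // C0 blocks skipped between destination bursts
};

// pitches are the full strides in C0 blocks, e.g. stride(3) / ELE_NUM_PER_C0.
inline BurstPlan PlanFractalCopy(uint32_t count, uint32_t len,
                                 uint32_t srcPitch, uint32_t dstPitch)
{
    return BurstPlan{static_cast<uint16_t>(count), static_cast<uint16_t>(len),
                     static_cast<uint16_t>(srcPitch - len),
                     static_cast<uint16_t>(dstPitch - len)};
}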
-template -struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - uint32_t blockCount = CeilDiv(layoutSrc.orgShape(0)); - uint32_t blockLen = RoundUp(layoutSrc.orgShape(1)); - - AscendC::DataCopyParams repeatParams; - - if (layoutSrc.stride(1) / ELE_NUM_PER_C0 < STRIDE_LIMIT) { - repeatParams.blockCount = blockCount; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_C0 - blockLen; - repeatParams.dstStride = layoutDst.stride(1) / ELE_NUM_PER_C0 - blockLen; - AscendC::DataCopy(dstTensor, srcTensor, repeatParams); - } else { - repeatParams.blockCount = 1; - repeatParams.blockLen = blockLen; - repeatParams.srcStride = 0; - repeatParams.dstStride = 0; - for (uint32_t i = 0; i < blockCount; i++) { - uint64_t dstOffset = i * layoutDst.stride(1); - uint64_t srcOffset = i * layoutSrc.stride(1); - AscendC::DataCopy(dstTensor[dstOffset], srcTensor[srcOffset], repeatParams); - } - } - } -}; - -/// Partial specialization for AtlasA2, PaddingRowMajor in and zN out. -template -struct CopyGmToL1> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::PaddingRowMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.orgShape(1); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(3) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = layoutSrc.orgShape(0); - intriParams.srcDValue = layoutSrc.stride(0); - intriParams.dstNzNStride = layoutDst.stride(0) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } -}; - -/// Partial specialization for AtlasA2, ColumnMajor in and nZ out. -template -struct CopyGmToL1> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::PaddingColumnMajor; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = layoutSrc.orgShape(0); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = layoutDst.stride(1) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = layoutSrc.orgShape(1); - intriParams.srcDValue = layoutSrc.stride(2); - intriParams.dstNzNStride = layoutDst.stride(2) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor, srcTensor, intriParams); - } -}; - -/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
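[Reviewer note] The RowMajor-to-RowMajor copy that follows picks one of three paths: a single flat DataCopy when both tiles are dense, a strided burst copy while the 16-bit gap fields and the 4095-repeat limit hold, and a per-row loop otherwise. A decision mirror using the constants from this file (STRIDE_LIMIT, BLOCK_LEN_LIMIT, MAX_REPEAT); the helper itself is illustrative only:

#include <cstdint>

enum class CopyPath { Contiguous, Strided, RowByRow };

inline CopyPath ChooseRowMajorPath(uint32_t srcGapBlocks, uint32_t dstGapBlocks,
                                   uint32_t blockLenBlocks, bool bothDense)
{
    constexpr uint32_t kStrideLimit = 65536;   // 16-bit stride fields
    constexpr uint32_t kBlockLenLimit = 65536; // 16-bit burst length
    if (bothDense) {
        return CopyPath::Contiguous;
    }
    if (srcGapBlocks < kStrideLimit && dstGapBlocks < kStrideLimit &&
        blockLenBlocks < kBlockLenLimit) {
        return CopyPath::Strided; // issued in chunks of at most 4095 rows
    }
    return CopyPath::RowByRow;
}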
-template -struct CopyGmToL1, - Gemm::GemmType> { - using LayoutDst = layout::RowMajor; - using LayoutSrc = layout::RowMajor; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(Element); - static constexpr uint32_t BLOCK_LEN_LIMIT = 65536; - static constexpr uint32_t MAX_REPEAT = 4095; - - // Methods - - ACT_DEVICE - CopyGmToL1() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::GlobalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - uint32_t rows = layoutSrc.shape(0); - uint32_t cols = layoutSrc.shape(1); - uint32_t srcStride = (layoutSrc.stride(0) - layoutSrc.shape(1)) / ELE_NUM_PER_BLK; - uint32_t dstStride = (layoutDst.stride(0) - layoutDst.shape(1)) / ELE_NUM_PER_BLK; - - if ((layoutSrc.shape(1) == layoutSrc.stride(0)) && (layoutDst.shape(1) == layoutDst.stride(0))) { - DataCopy(dstTensor, srcTensor, rows * cols); - } else if (srcStride < STRIDE_LIMIT && dstStride < STRIDE_LIMIT && (cols / ELE_NUM_PER_BLK) < BLOCK_LEN_LIMIT) { - uint32_t rLoops = CeilDiv(rows, MAX_REPEAT); - for (uint32_t i = 0; i < rLoops; ++i) { - uint32_t rActual = (i < rLoops - 1) ? MAX_REPEAT : rows - i * MAX_REPEAT; - AscendC::DataCopyParams dataCopyParams(rActual, cols / ELE_NUM_PER_BLK, srcStride, dstStride); - DataCopy(dstTensor[i * MAX_REPEAT * layoutDst.stride(0)], - srcTensor[i * MAX_REPEAT * layoutSrc.stride(0)], dataCopyParams); - } - } else { - for (uint32_t i = 0; i < rows; ++i) { - DataCopy(dstTensor[i * layoutDst.stride(0)], srcTensor[i * layoutSrc.stride(0)], cols); - } - } - } -}; - -///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// -/// Partial specialization for CopyGmToL1, AtlasA2, RowMajor in and zN out. -template -struct TileCopyTla< - Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, - Tensor, LayoutDst_, AscendC::TPosition::A1>, - std::enable_if_t::value && tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t nValue = get<0>(srcTensor.shape()); - const uint32_t dValue = get<1>(srcTensor.shape()); - const uint32_t srcDValue = get<0>(srcTensor.stride()); - const uint32_t dstInnerStrideRow = get<0, 0>(dstTensor.stride()); - const uint32_t dstOuterStrideCol = get<1, 1>(dstTensor.stride()); - - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = dValue; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (srcDValue < STRIDE_LIMIT) { - intriParams.nValue = nValue; - intriParams.srcDValue = srcDValue; - intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < nValue; i++) { - AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); - } - } - } -}; - -/// Partial specialization for CopyGmToL1, AtlasA2, ColumnMajor in and nZ out. 
-template -struct TileCopyTla, LayoutSrc_, AscendC::TPosition::GM>, - Tensor, LayoutDst_, AscendC::TPosition::A1>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t nValue = get<1>(srcTensor.shape()); - const uint32_t dValue = get<0>(srcTensor.shape()); - const uint32_t srcDValue = get<1>(srcTensor.stride()); - const uint32_t dstInnerStrideRow = get<1, 0>(dstTensor.stride()); - const uint32_t dstOuterStrideCol = get<0, 1>(dstTensor.stride()); - - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = dValue; - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = dstOuterStrideCol / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - if (srcDValue < STRIDE_LIMIT) { - intriParams.nValue = nValue; - intriParams.srcDValue = srcDValue; - intriParams.dstNzNStride = dstInnerStrideRow / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } else { - intriParams.nValue = 1; - intriParams.srcDValue = 0; - intriParams.dstNzNStride = 0; - for (uint32_t i = 0; i < nValue; i++) { - AscendC::DataCopy(dstTensor.data()[i * ELE_NUM_PER_C0], srcTensor.data()[i * srcDValue], intriParams); - } - } - } -}; - -/// Partial specialization for CopyGmToL1, AtlasA2, PaddingRowMajor in and zN -/// out. -template -struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::GM>, - Tensor, LayoutDst_, AscendC::TPosition::A1>, - layout::PaddingRowMajor, layout::zN> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A1>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTlaExt() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - AscendC::Nd2NzParams intriParams; - - intriParams.ndNum = 1; - intriParams.dValue = get<1>(srcTensor.orgShape()); - intriParams.srcNdMatrixStride = 0; - intriParams.dstNzC0Stride = get<1, 1>(dstTensor.stride()) / ELE_NUM_PER_C0; - intriParams.dstNzMatrixStride = 0; - - intriParams.nValue = get<0>(srcTensor.orgShape()); - intriParams.srcDValue = get<0, 0>(srcTensor.stride()); - intriParams.dstNzNStride = get<0, 0>(dstTensor.stride()) / ELE_NUM_PER_C0; - AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams); - } -}; - -/// Partial specialization for TileCopyTlaExt, CopyGmToL1, AtlasA2, -/// PaddingColumnMajor in and nZ out. 
-template <class ElementSrc, class LayoutSrc_, class ElementDst, class LayoutDst_>
-struct TileCopyTlaExt<Arch::AtlasA2, Tensor<AscendC::GlobalTensor<ElementSrc>, LayoutSrc_, AscendC::TPosition::GM>,
-                      Tensor<AscendC::LocalTensor<ElementDst>, LayoutDst_, AscendC::TPosition::A1>,
-                      layout::PaddingColumnMajor, layout::nZ> {
-    using LayoutDst = LayoutDst_;
-    using LayoutSrc = LayoutSrc_;
-    using TensorDst = Tensor<AscendC::LocalTensor<ElementDst>, LayoutDst, AscendC::TPosition::A1>;
-    using TensorSrc = Tensor<AscendC::GlobalTensor<ElementSrc>, LayoutSrc, AscendC::TPosition::GM>;
-
-    static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc);
-
-    // Methods
-
-    ACT_DEVICE
-    TileCopyTlaExt() {};
-
-    ACT_DEVICE
-    void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor)
-    {
-        AscendC::Nd2NzParams intriParams;
-
-        intriParams.ndNum = 1;
-        intriParams.dValue = get<0>(srcTensor.orgShape());
-        intriParams.srcNdMatrixStride = 0;
-        intriParams.dstNzC0Stride = get<0, 1>(dstTensor.stride()) / ELE_NUM_PER_C0;
-        intriParams.dstNzMatrixStride = 0;
-
-        intriParams.nValue = get<1>(srcTensor.orgShape());
-        intriParams.srcDValue = get<1, 0>(srcTensor.stride());
-        intriParams.dstNzNStride = get<1, 0>(dstTensor.stride()) / ELE_NUM_PER_C0;
-        AscendC::DataCopy(dstTensor.data(), srcTensor.data(), intriParams);
-    }
-};
-
-/////////////////////////////////////////////////////////////////////////////////////////////////////////
-
-} // namespace Act::Gemm::Tile
-
-#endif // ACT_GEMM_TILE_COPY_GM_TO_L1_HPP
diff --git a/act/gemm/tile/copy_gm_to_ub.hpp b/act/gemm/tile/copy_gm_to_ub.hpp
deleted file mode 100644
index d5065005..00000000
--- a/act/gemm/tile/copy_gm_to_ub.hpp
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Copyright (c) 2025 Huawei Technologies Co., Ltd.
- * This file is a part of the CANN Open Software.
- * Licensed under CANN Open Software License Agreement Version 1.0 (the
- * "License"). Please refer to the License for details. You may not use this
- * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
- * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
- * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
- * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
- * for the full text of the License.
- */
-
-#ifndef ACT_GEMM_TILE_COPY_GM_TO_UB_HPP
-#define ACT_GEMM_TILE_COPY_GM_TO_UB_HPP
-
-#include "../../../act/act.hpp"
-#include "../../../tla/tensor.hpp"
-
-namespace Act::Gemm::Tile {
-
-/// Partial specialization for AtlasA2, RowMajor in and RowMajor out.
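[Reviewer note] In the GM-to-UB copy that follows, the DataCopyExtParams units are mixed on purpose: the burst length and the GM-side gap are in bytes, while the UB-side gap is in 32-byte blocks. That is an easy detail to get wrong, so here is the bookkeeping spelled out as a host-side helper (names are ours; pitches are in elements and assumed block-aligned):

#include <cstdint>

struct PadPlan {
    uint32_t blockCount; // rows
    uint32_t blockLenB;  // bytes per row actually transferred
    uint32_t srcGapB;    // bytes skipped between GM rows
    uint32_t dstGapBlk;  // 32-byte blocks skipped between UB rows
};

inline PadPlan PlanGmToUb(uint32_t rows, uint32_t cols, uint32_t gmPitch,
                          uint32_t ubPitch, uint32_t elemSize)
{
    constexpr uint32_t kBytePerBlk = 32;
    return PadPlan{rows, cols * elemSize, (gmPitch - cols) * elemSize,
                   (ubPitch - cols) * elemSize / kBytePerBlk};
}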
-template -struct TileCopyTla< - Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::GM>, - Tensor, LayoutDst_, AscendC::TPosition::VECCALC>, - std::enable_if_t::value && tla::detail::isRowMajor::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::VECCALC>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::GM>; - - static constexpr uint32_t ELE_NUM_PER_BLK = BYTE_PER_BLK / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - AscendC::DataCopyExtParams dataCopyParams( - get<0>(srcTensor.shape()), get<1>(srcTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) * sizeof(ElementSrc), - (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) / ELE_NUM_PER_BLK, 0); - AscendC::DataCopyPadExtParams padParams(false, 0, 0, 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams, padParams); - }; -}; - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_COPY_GM_TO_UB_HPP diff --git a/act/gemm/tile/copy_l0c_to_gm.hpp b/act/gemm/tile/copy_l0c_to_gm.hpp deleted file mode 100644 index b25e28b0..00000000 --- a/act/gemm/tile/copy_l0c_to_gm.hpp +++ /dev/null @@ -1,219 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP -#define ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP - -#include "../../../act/gemm/gemm_type.hpp" - -namespace Act::Gemm::Tile { - -enum class ScaleGranularity { UNDEFINED = -1, NO_QUANT = 0, PER_TENSOR, PER_CHANNEL, PER_GROUP }; - -template -struct CopyL0CToGmQuantMode { - static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); -}; - -// CopyL0CToGm cast fp32 to fp16 -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::F322F16; -}; - -// CopyL0CToGm cast fp32 to bf16 -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::F322BF16; -}; - -// CopyL0CToGm output fp32 -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::NoQuant; -}; - -// CopyL0CToGm output int32 -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::NoQuant; -}; - -// CopyL0CToGm cast int32_t to fp16 -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::DEQF16; -}; - -template <> -struct CopyL0CToGmQuantMode { - static constexpr auto VALUE = QuantMode_t::VDEQF16; -}; - -template -struct CopyL0CToGm { - static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); -}; - -template -struct CopyL0CToGm, - ScaleGranularity::NO_QUANT, ReluEnable_> { - using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::RowMajor; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) - { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = dstLayout.shape(1); - intriParams.mSize = dstLayout.shape(0); - intriParams.srcStride = srcLayout.stride(3) / srcLayout.stride(0); - intriParams.dstStride = dstLayout.stride(0); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe(dst, src, intriParams); - } -}; - -template -struct CopyL0CToGm, - ScaleGranularity::NO_QUANT, ReluEnable_> { - using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::ColumnMajor; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementDst); - - ACT_DEVICE - CopyL0CToGm() {} - - ACT_DEVICE - void operator()(AscendC::GlobalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) - { - AscendC::DataCopyCO12DstParams params; - - params.nSize = dstLayout.shape(0); - params.mSize = dstLayout.shape(1); - params.dstStride = dstLayout.stride(1); - params.srcStride = srcLayout.shape(2) * srcLayout.shape(3); - params.quantPre = quantPre; - params.reluPre = 0; - params.channelSplit = false; - params.nz2ndEn = true; - AscendC::DataCopy(dstTensor, srcTensor, params); - } -}; - -template -struct CopyL0CToGm, - ScaleGranularity::NO_QUANT, ReluEnable_> { - 
using ArchTag = Act::Arch::AtlasA2; - using ElementDst = ElementDst_; - using ElementSrc = ElementAccumulator_; - using LayoutSrc = Act::layout::zN; - using LayoutDst = Act::layout::zN; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(AscendC::GlobalTensor const &dst, AscendC::LocalTensor const &src, - LayoutDst const &dstLayout, LayoutSrc const &srcLayout, uint8_t unitFlag = 0) - { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = dstLayout.shape(2) * dstLayout.shape(3); - intriParams.mSize = dstLayout.shape(0) * dstLayout.shape(1); - intriParams.srcStride = srcLayout.stride(3) / srcLayout.shape(2); - intriParams.dstStride = dstLayout.stride(3) / (BYTE_PER_C0 / sizeof(ElementDst)); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe(dst, src, intriParams); - } -}; - -///////////////////////////////////////////CopyL0CToGmTla///////////////////////////////////////////////// -template -struct CopyL0CToGmTla { - static_assert(DEPENDENT_FALSE, "Unsupporteded copy l0c to gm, can not find the specialization."); -}; - -template -struct CopyL0CToGmTla< - Act::Arch::AtlasA2, TensorSrc_, Tensor, LayoutDst_, AscendC::TPosition::GM>, - ScaleGranularity::NO_QUANT, ReluEnable_, std::enable_if_t::value>> { - using ArchTag = Act::Arch::AtlasA2; - using TensorDst = Tensor, LayoutDst_, AscendC::TPosition::GM>; - using ElementDst = ElementDst_; - using TensorSrc = TensorSrc_; - using ElementSrc = typename TensorSrc::Element; - static constexpr auto quantPre = - CopyL0CToGmQuantMode::VALUE; - static constexpr auto reluEn = ReluEnable_; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor, uint8_t unitFlag = 0) - { - AscendC::FixpipeParamsV220 intriParams; - - // Fixpipe layout information - intriParams.nSize = get<1>(dstTensor.shape()); - intriParams.mSize = get<0>(dstTensor.shape()); - intriParams.srcStride = get<1, 1>(srcTensor.stride()) / get<0, 0>(srcTensor.stride()); - intriParams.dstStride = get<0>(dstTensor.stride()); - - // Fixpipe auxiliary arguments - intriParams.quantPre = quantPre; - intriParams.reluEn = reluEn; - intriParams.unitFlag = unitFlag; - - // Call AscendC Fixpipe - AscendC::Fixpipe(dstTensor.data(), srcTensor.data(), - intriParams); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_COPY_L0C_TO_GM_HPP diff --git a/act/gemm/tile/copy_l1_to_l0a.hpp b/act/gemm/tile/copy_l1_to_l0a.hpp deleted file mode 100644 index 14639773..00000000 --- a/act/gemm/tile/copy_l1_to_l0a.hpp +++ /dev/null @@ -1,392 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP -#define ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/gemm_type.hpp" -#include "../../../act/layout/layout.hpp" -#include "../../../tla/tensor.hpp" - -using namespace tla; - -namespace Act::Gemm::Tile { - -template -struct CopyL1ToL0A { - static_assert(DEPENDENT_FALSE, "Unsupporteded copy l1 to l0, can not find the specialization."); -}; - -//////////////////////////////// -/// new add gemm -template -struct CopyL1ToL0A, Act::Gemm::GemmType> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0A, Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0A, Act::Gemm::GemmType> { - using Element = float; - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(1) / 2); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = static_cast(layoutSrc.shape(1) / 2) - 1; - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(3)], srcTensor[i * layoutSrc.stride(3)], - loadDataParams); - } - } -}; - -template -struct CopyL1ToL0A, Act::Gemm::GemmType> { - using Element = int8_t; - using LayoutDst = layout::zN; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor 
dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - uint32_t MRound = layoutSrc.shape(0) * layoutSrc.shape(1); - uint32_t KRound = layoutSrc.shape(2) * layoutSrc.shape(3); - uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; - uint32_t KLoops = CeilDiv(KRound, KL0Alignment); - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(MRound / ELE_NUM_PER_C0); - loadDataParams.srcStride = static_cast(KRound / KL0Alignment); - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < KLoops; i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * MRound * KL0Alignment], - srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); - } - } -}; -////////////////////////////////////////// - -/// Partial specialization for zN in and zZ out. -template -struct CopyL1ToL0A> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0A> { - using LayoutDst = layout::zZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } - } -}; - -/// Partial specialization for int8_t, nZ in and zZ out. 
(Transpose A) -template -struct CopyL1ToL0A> { - using Element = int8_t; - using LayoutDst = layout::zZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0A() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(1)) - 1; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1) * 2], srcTensor[i * layoutSrc.stride(1)], - loadDataParams); - } - } -}; - -///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// - -/// Partial specialization for CopyL1ToL0A, AtlasA2, zN in and zZ out. -template -struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::A2>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); - } - } -}; - -/// Partial specialization for CopyL1ToL0A, AtlasA2, nZ in and zZ out. 
-/// (Transpose A) -template -struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::A2>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); - } - } -}; - -/// Partial specialization for CopyL1ToL0A, AtlasA2, int8_t, nZ in and zZ out. -/// (Transpose A) -template -struct TileCopyTla< - Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::A2>, - std::enable_if_t::value && tla::detail::isnZ::value>> { - using Element = int8_t; - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::A2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterShapeRow = get<0, 1>(srcTensor.shape()); - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = dstOuterShapeCol - 1; - - for (uint32_t i = 0; i < srcOuterShapeRow; i++) { - AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow * 2], - srcTensor.data()[i * srcOuterStrideRow], loadDataParams); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_COPY_L1_TO_L0A_HPP diff --git a/act/gemm/tile/copy_l1_to_l0b.hpp b/act/gemm/tile/copy_l1_to_l0b.hpp deleted file mode 100644 index 6f1ced1d..00000000 --- a/act/gemm/tile/copy_l1_to_l0b.hpp +++ /dev/null @@ -1,537 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. 
- * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP -#define ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/gemm_type.hpp" -#include "../../../act/layout/layout.hpp" -#include "../../../tla/tensor.hpp" - -using namespace tla; - -namespace Act::Gemm::Tile { - -template -struct CopyL1ToL0B { - static_assert(DEPENDENT_FALSE, "Unsupporteded copy l1 to l0, can not find the specialization."); -}; - -//////////////////////////////////////// -/// new add gemm -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N - AscendC::LoadData(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using Element = float; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutSrc.shape(3) / 2); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = static_cast(layoutSrc.shape(3) / 2) - 1; - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { // K N - AscendC::LoadDataWithTranspose(dstTensor[i * layoutSrc.stride(1)], srcTensor[i * layoutSrc.stride(1)], - loadDataParams); - } - } -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using Element = int8_t; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - uint32_t NRound = layoutSrc.shape(2) * layoutSrc.shape(3); - uint32_t KRound = layoutSrc.shape(0) * layoutSrc.shape(1); - uint32_t KL0Alignment = C0_NUM_PER_FRACTAL * 2; - uint32_t KLoops = CeilDiv(KRound, KL0Alignment); - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - 
loadDataParams.repeatTimes = static_cast(NRound / ELE_NUM_PER_C0); - loadDataParams.srcStride = static_cast(KRound / KL0Alignment); - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < KLoops; i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * NRound * KL0Alignment], - srcTensor[i * KL0Alignment * ELE_NUM_PER_C0], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(3); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::nN; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor dstTensor, AscendC::LocalTensor srcTensor, - LayoutDst layoutDst, LayoutSrc layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(1)); - loadDataParams.srcStride = layoutSrc.shape(3); - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutSrc.shape(3); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3)], loadDataParams); - } - } -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = layoutDst.shape(1) * layoutDst.shape(3); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(1) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - AscendC::LoadData(dstTensor, srcTensor, loadDataParams); - }; -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nN; - using Element = float; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t 
ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); - loadDataParams.srcStride = 1; - loadDataParams.dstGap = 0; - loadDataParams.dstFracGap = CeilDiv(layoutDst.orgShape(0)) - 1; - - for (uint32_t i = 0; i < CeilDiv<2 * ELE_NUM_PER_C0>(layoutDst.orgShape(1)); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3) * 2], srcTensor[i * layoutSrc.stride(3)], - loadDataParams); - } - }; -}; - -template -struct CopyL1ToL0B, Act::Gemm::GemmType> { - using LayoutDst = layout::zN; - using LayoutSrc = layout::nZ; - using Element = int8_t; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(0))); - loadDataParams.srcStride = layoutSrc.stride(1) / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(1)); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(3)], srcTensor[i * layoutSrc.stride(3) * 2], - loadDataParams); - } - } -}; -//////////////////////////////////////////// - -/// Partial specialization for int8_t, zN in and nZ out. -template -struct CopyL1ToL0B> { - using Element = int8_t; - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { - AscendC::LoadDataWithTranspose(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1) * 2], - loadDataParams); - } - } -}; - -/// Partial specialization for zN in and nZ out. 
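[Reviewer note] For the zN-to-nZ B-matrix load introduced above, each LoadData2D repeat moves one 512-byte fractal, so transposing a K x N zN tile issues ceil(N / C0) repeats per row of fractals and loops over ceil(K / 16) rows. A sketch of those trip counts, assuming fp16 (C0 = 16 elements); the helper names are ours:

#include <cstdint>

inline uint32_t CeilDiv(uint32_t a, uint32_t b) { return (a + b - 1) / b; }

struct LoadPlan {
    uint32_t repeatTimes; // fractals per LoadData issue
    uint32_t outerLoops;  // LoadData issues per tile
};

inline LoadPlan PlanBLoad(uint32_t kRound, uint32_t nRound, uint32_t c0 = 16)
{
    return LoadPlan{CeilDiv(nRound, c0), CeilDiv(kRound, 16)};
}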
-template -struct CopyL1ToL0B> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::zN; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(CeilDiv(layoutDst.orgShape(1))); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < CeilDiv(layoutDst.orgShape(0)); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], loadDataParams); - } - } -}; - -/// Partial specialization for nZ in and nZ out. (Transpose B) -template -struct CopyL1ToL0B> { - using LayoutDst = layout::nZ; - using LayoutSrc = layout::nZ; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - CopyL1ToL0B() {}; - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &dstTensor, AscendC::LocalTensor const &srcTensor, - LayoutDst const &layoutDst, LayoutSrc const &layoutSrc) - { - AscendC::LoadData2DParams loadDataParams; - if (layoutSrc.shape(3) == layoutDst.shape(3)) { - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(1) * layoutDst.shape(3)); - loadDataParams.srcStride = 1; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - AscendC::LoadData(dstTensor, srcTensor, loadDataParams); - } else { - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = static_cast(layoutDst.shape(3)); - loadDataParams.srcStride = layoutSrc.stride(3) / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = layoutDst.stride(3) / ELE_NUM_PER_FRACTAL - 1; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < layoutDst.shape(1); i++) { - AscendC::LoadData(dstTensor[i * layoutDst.stride(1)], srcTensor[i * layoutSrc.stride(1)], - loadDataParams); - } - } - } -}; - -///////////////////////////////////////////TileCopyTla////////////////////////////////////////////////////// -/// Partial specialization for CopyL1ToL0B, AtlasA2, zN in and nZ out. 
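[Reviewer note] Before the tla variant below, note the two regimes visible in the nZ in/nZ out "Transpose B" path above: when the source already matches the L0B fractal order, ifTranspose stays false and a single dense LoadData can cover the whole tile; when the source is zN, each fractal must be transposed on the fly and the repeats step across fractal columns. A trivially small sketch of that dispatch, illustrative only:

struct LoadConfig {
    bool ifTranspose; // transpose each 16 x C0 fractal during the load
    bool singleIssue; // one dense LoadData covers the tile
};

inline LoadConfig ConfigForBLoad(bool srcIsZn, bool denseTile)
{
    return LoadConfig{srcIsZn, !srcIsZn && denseTile};
}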
-template -struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::B2>, - std::enable_if_t::value && - tla::detail::iszN::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = true; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); - } - } -}; - -/// Partial specialization for CopyL1ToL0B, AtlasA2, nZ in and nZ out. -/// (Transpose B) -template -struct TileCopyTla, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::B2>, - std::enable_if_t::value && - tla::detail::isnZ::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterShapeCol = get<1, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2DParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = dstOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL; - loadDataParams.sid = 0; - loadDataParams.dstGap = 0; - loadDataParams.ifTranspose = false; - loadDataParams.addrMode = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadData(dstTensor.data()[i * dstOuterStrideRow], srcTensor.data()[i * srcOuterStrideRow], - loadDataParams); - } - } -}; - -/// Partial specialization for CopyL1ToL0B, AtlasA2, int8_t, zN in and nZ out. 
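[Reviewer note] The int8 specialization that follows divides the source stride by 2 and doubles destination offsets. The reason, as we read it: with int8 a 32-byte C0 holds 32 elements, and LoadDataWithTranspose emits a transposed 32 x 32 element block, i.e. two 512-byte fractals per step, versus one for fp16. A sketch of that ratio (valid for the fp16/int8 cases in this file, not for float):

#include <cstdint>

constexpr uint32_t kBytePerC0 = 32;
constexpr uint32_t kC0PerFractal = 16;

// fp16 (elemSize = 2): 1 fractal per transpose step; int8 (elemSize = 1): 2.
inline uint32_t FractalsPerTransposeStep(uint32_t elemSize)
{
    return (kBytePerC0 / elemSize) / kC0PerFractal;
}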
-template -struct TileCopyTla< - Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::A1>, - Tensor, LayoutDst_, AscendC::TPosition::B2>, - std::enable_if_t::value && tla::detail::iszN::value>> { - using Element = int8_t; - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::B2>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::A1>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - const uint32_t srcOuterShapeCol = get<1, 1>(srcTensor.shape()); - const uint32_t srcOuterStrideRow = get<0, 1>(srcTensor.stride()); - const uint32_t srcOuterStrideCol = get<1, 1>(srcTensor.stride()); - const uint32_t dstOuterShapeRow = get<0, 1>(dstTensor.shape()); - const uint32_t dstOuterStrideRow = get<0, 1>(dstTensor.stride()); - - AscendC::LoadData2dTransposeParams loadDataParams; - - loadDataParams.startIndex = 0; - loadDataParams.repeatTimes = srcOuterShapeCol; - loadDataParams.srcStride = srcOuterStrideCol / ELE_NUM_PER_FRACTAL / 2; - loadDataParams.dstGap = 1; - loadDataParams.dstFracGap = 0; - - for (uint32_t i = 0; i < dstOuterShapeRow; i++) { - AscendC::LoadDataWithTranspose(dstTensor.data()[i * dstOuterStrideRow], - srcTensor.data()[i * srcOuterStrideRow * 2], loadDataParams); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_COPY_L1_TO_L0B_HPP diff --git a/act/gemm/tile/copy_ub_to_gm.hpp b/act/gemm/tile/copy_ub_to_gm.hpp deleted file mode 100644 index 87d86e3b..00000000 --- a/act/gemm/tile/copy_ub_to_gm.hpp +++ /dev/null @@ -1,80 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMM_TILE_COPY_UB_TO_GM_HPP -#define ACT_GEMM_TILE_COPY_UB_TO_GM_HPP - -#include "../../../act/act.hpp" -#include "../../../tla/tensor.hpp" - -namespace Act::Gemm::Tile { - -/// Partial specialization for AtlasA2, RowMajor in and RowMajor out. 
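[Reviewer note] The UB-to-GM copy that follows is the mirror image of the GM-to-UB case: here the UB-side (source) gap is counted in 32-byte blocks and the GM-side (destination) gap in bytes. Spelled out with the same hypothetical-helper style as before (pitches in elements, assumed block-aligned):

#include <cstdint>

struct UbToGmPlan {
    uint32_t blockCount; // rows
    uint32_t blockLenB;  // bytes per row
    uint32_t srcGapBlk;  // 32-byte blocks between UB rows
    uint32_t dstGapB;    // bytes between GM rows
};

inline UbToGmPlan PlanUbToGm(uint32_t rows, uint32_t cols, uint32_t ubPitch,
                             uint32_t gmPitch, uint32_t elemSize)
{
    constexpr uint32_t kBytePerBlk = 32;
    return UbToGmPlan{rows, cols * elemSize,
                      (ubPitch - cols) * elemSize / kBytePerBlk,
                      (gmPitch - cols) * elemSize};
}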
-template -struct TileCopyTla< - Arch::AtlasA2, Tensor, LayoutSrc_, AscendC::TPosition::VECCALC>, - Tensor, LayoutDst_, AscendC::TPosition::GM>, - std::enable_if_t::value && tla::detail::isRowMajor::value>> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTla() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - AscendC::DataCopyExtParams dataCopyParams( - get<0>(dstTensor.shape()), get<1>(dstTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, - (get<0>(dstTensor.stride()) - get<1>(dstTensor.shape())) * sizeof(ElementSrc), 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); - }; -}; - -/// Partial specialization for AtlasA2, RowMajor in and PaddingRowMajor out. -template -struct TileCopyTlaExt, LayoutSrc_, AscendC::TPosition::VECCALC>, - Tensor, LayoutDst_, AscendC::TPosition::GM>, layout::RowMajor, - layout::PaddingRowMajor> { - using LayoutDst = LayoutDst_; - using LayoutSrc = LayoutSrc_; - using TensorDst = Tensor, LayoutDst, AscendC::TPosition::GM>; - using TensorSrc = Tensor, LayoutSrc, AscendC::TPosition::VECCALC>; - - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(ElementSrc); - - // Methods - - ACT_DEVICE - TileCopyTlaExt() {}; - - ACT_DEVICE - void operator()(TensorDst const &dstTensor, TensorSrc const &srcTensor) - { - AscendC::DataCopyExtParams dataCopyParams( - get<1, 1>(dstTensor.shape()), get<1, 0>(dstTensor.shape()) * sizeof(ElementSrc), - (get<0>(srcTensor.stride()) - get<1>(srcTensor.shape())) / ELE_NUM_PER_C0, - (get<1, 1>(dstTensor.stride()) - get<1, 0>(dstTensor.shape())) * sizeof(ElementSrc), 0); - AscendC::DataCopyPad(dstTensor.data(), srcTensor.data(), dataCopyParams); - }; -}; - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_COPY_UB_TO_GM_HPP diff --git a/act/gemm/tile/tile_copy.hpp b/act/gemm/tile/tile_copy.hpp deleted file mode 100644 index c7135709..00000000 --- a/act/gemm/tile/tile_copy.hpp +++ /dev/null @@ -1,183 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */
-
-#ifndef ACT_GEMM_TILE_TILE_COPY_HPP
-#define ACT_GEMM_TILE_TILE_COPY_HPP
-
-#include "../../../act/act.hpp"
-#include "../../../act/detail/tag_to_layout.hpp"
-
-namespace Act::Gemm::Tile {
-
-template
-struct TileCopyTla {
-    static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTla: cannot find a matching specialization.");
-};
-
-template
-struct TileCopyTlaExt {
-    static_assert(DEPENDENT_FALSE, "Unsupported TileCopyTlaExt: cannot find a matching specialization.");
-};
-} // namespace Act::Gemm::Tile
-
-#include "../../../act/gemm/helper.hpp"
-#include "../../../act/gemm/tile/copy_gm_to_l1.hpp"
-#include "../../../act/gemm/tile/copy_gm_to_ub.hpp"
-#include "../../../act/gemm/tile/copy_l0c_to_gm.hpp"
-#include "../../../act/gemm/tile/copy_l1_to_l0a.hpp"
-#include "../../../act/gemm/tile/copy_l1_to_l0b.hpp"
-#include "../../../act/gemm/tile/copy_ub_to_gm.hpp"
-
-namespace Act::Gemm::Tile {
-
-template <
-    /// Tag indicating architecture
-    class ArchTag,
-    /// GemmType for A matrix operand
-    class AType,
-    /// GemmType for B matrix operand
-    class BType,
-    /// GemmType for C matrix operand
-    class CType,
-    /// GemmType for Bias operand
-    class BiasType = void>
-struct TileCopy {
-    using ElementA = typename AType::Element;
-    using ElementB = typename BType::Element;
-    using ElementAccumulator =
-        typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator;
-
-    using CopyGmToL1A = Gemm::Tile::CopyGmToL1;
-    using CopyGmToL1B = Gemm::Tile::CopyGmToL1;
-    using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A::L1AType>;
-    using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B::L1BType>;
-    using CopyL0CToGm = Gemm::Tile::CopyL0CToGm;
-};
-
-/// TileCopy variant whose L1/L0 tile types are selected from the operand layouts
-template <
-    /// Tag indicating architecture
-    class ArchTag,
-    /// GemmType for A matrix operand
-    class AType,
-    /// GemmType for B matrix operand
-    class BType,
-    /// GemmType for C matrix operand
-    class CType,
-    /// GemmType for Bias operand
-    class BiasType = void>
-struct TileCopyGemm {
-    using ElementA = typename AType::Element;
-    using ElementB = typename BType::Element;
-    using ElementAccumulator =
-        typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator;
-    // L1/L0 tile types chosen by the layout selector helpers
-    using L1AType = typename helper::L1ATypeSelectorGemm::L1AType;
-    using L1BType = typename helper::L1BTypeSelectorGemm::L1BType;
-    using L0AType = typename helper::L0ATypeSelector::L0AType;
-    using L0BType = typename helper::L0BTypeSelectorGemm::L0BType;
-
-    using CopyGmToL1A = Gemm::Tile::CopyGmToL1;
-    using CopyGmToL1B = Gemm::Tile::CopyGmToL1;
-    using CopyL1ToL0A = Gemm::Tile::CopyL1ToL0A;
-    using CopyL1ToL0B = Gemm::Tile::CopyL1ToL0B;
-    using CopyL0CToGm = Gemm::Tile::CopyL0CToGm;
-};
-
-template <
-    /// Tag indicating architecture
-    class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC,
-    class TensorBias = void, class LayoutTagBias = void>
-struct PackedTileCopyTla {
-    using ElementA = typename TensorA::Element;
-    using ElementB = typename TensorB::Element;
-    using ElementAccumulator =
-        typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator;
-
-    using LayoutL1A =
-        detail::TagToLayout_t>::L1AType::Layout>;
-    using LayoutL1B =
-        detail::TagToLayout_t>::L1BType::Layout>;
-    using LayoutL0A = detail::TagToLayout_t;
-    using LayoutL0B = detail::TagToLayout_t;
-    using LayoutL0C = typename detail::LayoutL0C;
-
-    using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>;
-    using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>;
-    using TensorL0A = Tensor, LayoutL0A, AscendC::TPosition::A2>;
-    using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>;
-    using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>;
-
-    using L1AAlignHelper = Gemm::helper::L1AlignHelper;
-    using L1BAlignHelper = Gemm::helper::L1AlignHelper;
-
-    using CopyGmToL1A = Gemm::Tile::TileCopyTla;
-    using CopyGmToL1B = Gemm::Tile::TileCopyTla;
-    using CopyL1ToL0A = Gemm::Tile::TileCopyTla;
-    using CopyL1ToL0B = Gemm::Tile::TileCopyTla;
-    using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla;
-};
-
-template <
-    /// Tag indicating architecture
-    class ArchTag, class TensorA, class LayoutTagA, class TensorB, class LayoutTagB, class TensorC, class LayoutTagC,
-    class TensorBias = void, class LayoutTagBias = void, bool IS_PADDING_A = false, bool IS_PADDING_B = false>
-struct PaddingPackedTileCopyTla {
-    static_assert(std::is_same_v || std::is_same_v,
-                  "Unsupported layout: only RowMajor and ColumnMajor are supported");
-    static_assert(std::is_same_v || std::is_same_v,
-                  "Unsupported layout: only RowMajor and ColumnMajor are supported");
-    using ElementA = typename TensorA::Element;
-    using ElementB = typename TensorB::Element;
-    using ElementAccumulator =
-        typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator;
-
-    using LayoutTagL1A = typename helper::L1ATypeSelector>::L1AType::Layout;
-    using LayoutTagL1B = typename helper::L1BTypeSelector>::L1BType::Layout;
-    using LayoutL1A = detail::TagToLayout_t;
-    using LayoutL1B = detail::TagToLayout_t;
-    using LayoutL0A = detail::TagToLayout_t;
-    using LayoutL0B = detail::TagToLayout_t;
-    using LayoutL0C = typename detail::LayoutL0C;
-
-    using TensorL1A = Tensor, LayoutL1A, AscendC::TPosition::A1>;
-    using TensorL1B = Tensor, LayoutL1B, AscendC::TPosition::A1>;
-    using TensorL0A = Tensor, LayoutL0A, AscendC::TPosition::A2>;
-    using TensorL0B = Tensor, LayoutL0B, AscendC::TPosition::B2>;
-    using TensorL0C = Tensor, LayoutL0C, AscendC::TPosition::CO1>;
-
-    using L1AAlignHelper = Gemm::helper::L1AlignHelper;
-    using L1BAlignHelper = Gemm::helper::L1AlignHelper;
-
-    using LayoutPaddingTagA = std::conditional_t, layout::PaddingRowMajor,
-                                                 layout::PaddingColumnMajor>;
-    using LayoutPaddingTagB = std::conditional_t, layout::PaddingRowMajor,
-                                                 layout::PaddingColumnMajor>;
-
-    using CopyGmToL1A =
-        std::conditional_t,
-                           Gemm::Tile::TileCopyTla>;
-    using CopyGmToL1B =
-        std::conditional_t,
-                           Gemm::Tile::TileCopyTla>;
-
-    using CopyL1ToL0A = Gemm::Tile::TileCopyTla;
-    using CopyL1ToL0B = Gemm::Tile::TileCopyTla;
-    using CopyL0CToGm = Gemm::Tile::CopyL0CToGmTla;
-};
-} // namespace Act::Gemm::Tile
-
-#endif // ACT_GEMM_TILE_TILE_COPY_HPP

diff --git a/act/gemm/tile/tile_mmad.hpp b/act/gemm/tile/tile_mmad.hpp
deleted file mode 100644
index 7beacdf7..00000000
--- a/act/gemm/tile/tile_mmad.hpp
+++ /dev/null
@@ -1,110 +0,0 @@
-/*
- * Copyright (c) 2025 Huawei Technologies Co., Ltd.
- * This file is a part of the CANN Open Software.
- * Licensed under CANN Open Software License Agreement Version 1.0 (the
- * "License"). Please refer to the License for details. You may not use this
- * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN
- * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED,
- * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS
- * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository
- * for the full text of the License.
- */ - -#ifndef ACT_GEMM_TILE_TILE_MMAD_HPP -#define ACT_GEMM_TILE_TILE_MMAD_HPP - -#include "../../../act/act.hpp" -#include "../../../act/gemm/helper.hpp" -namespace Act::Gemm::Tile { - -/////////////////////////////////////////////////////////// - -template < - /// Tag indicating architecture - class ArchTag_, - /// GemmType for A matrix operand - class AType_, - /// GemmType type for B matrix operand - class BType_, - /// GemmType type for Bias operand - class BiasType_> -struct TileMmad { - using ElementA = typename AType_::Element; - using ElementB = typename BType_::Element; - using ElementAccumulator = - typename Gemm::helper::ElementAccumulatorSelector::ElementAccumulator; - - // Methods - - ACT_DEVICE - TileMmad() {} - - ACT_DEVICE - void operator()(AscendC::LocalTensor const &l0CTensor, - AscendC::LocalTensor const &l0ATensor, AscendC::LocalTensor const &l0BTensor, - uint32_t m, uint32_t n, uint32_t k, bool initC = true, uint8_t unitFlag = 0) - { - AscendC::MmadParams mmadParams; - mmadParams.m = m; - mmadParams.n = n; - mmadParams.k = k; - mmadParams.unitFlag = unitFlag; - mmadParams.cmatrixInitVal = initC; - - AscendC::Mmad(l0CTensor, l0ATensor, l0BTensor, mmadParams); - - const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; - if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { - AscendC::PipeBarrier(); - } - } -}; - -///////////////////////////////////////////TileMmadTla///////////////////////////////////////////////// - -template < - /// Tag indicating architecture - class ArchTag_, - /// Tensor type for A matrix operand - class TensorA, - /// Tensor type for B matrix operand - class TensorB, - /// Tensor type for C matrix operand - class TensorC, - /// Tensor type for Bias operand - class TensorBias = void> -struct TileMmadTla { - // Methods - - ACT_DEVICE - TileMmadTla() {} - - ACT_DEVICE - void operator()(TensorC const &l0CTensor, TensorA const &l0ATensor, TensorB const &l0BTensor, bool initC = true, - uint8_t unitFlag = 0) - { - const uint32_t m = get<0>(l0ATensor.orgShape()); - const uint32_t n = get<1>(l0BTensor.orgShape()); - const uint32_t k = get<1>(l0ATensor.orgShape()); - - AscendC::MmadParams mmadParams; - mmadParams.m = m; - mmadParams.n = n; - mmadParams.k = k; - mmadParams.unitFlag = unitFlag; - mmadParams.cmatrixInitVal = initC; - - AscendC::Mmad(l0CTensor.data(), l0ATensor.data(), l0BTensor.data(), mmadParams); - - const uint32_t PIPE_M_BARRIER_THRESHOLD = 10; - if ((m / C0_NUM_PER_FRACTAL) * (n / C0_NUM_PER_FRACTAL) < PIPE_M_BARRIER_THRESHOLD) { - AscendC::PipeBarrier(); - } - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////////////////// - -} // namespace Act::Gemm::Tile - -#endif // ACT_GEMM_TILE_TILE_MMAD_HPP diff --git a/act/gemm_coord.hpp b/act/gemm_coord.hpp deleted file mode 100644 index 2e8dbb56..00000000 --- a/act/gemm_coord.hpp +++ /dev/null @@ -1,159 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_GEMM_COORD_HPP -#define ACT_GEMM_COORD_HPP - -#include "../act/coord.hpp" - -namespace Act { - -/// Shape of a matrix multiply-add operation -template < - /// Rows of matrix product - uint32_t M_ = 1, - /// Columns of matrix product - uint32_t N_ = 1, - /// Inner dimension of matrix product - uint32_t K_ = 1> -struct GemmShape { - static constexpr uint32_t M = M_; - static constexpr uint32_t N = N_; - static constexpr uint32_t K = K_; - - static constexpr int64_t MN = M * N; - static constexpr int64_t MK = M * K; - static constexpr int64_t KN = N * K; - static constexpr int64_t MNK = M * N * K; - - static constexpr int64_t COUNT = MNK; - - /// Returns a Coord object - ACT_HOST_DEVICE - static Coord<3> ToCoord() - { - return MakeCoord(M, N, K); - } - - ACT_HOST_DEVICE - static Coord<2> ToCoordMN() - { - return MakeCoord(M, N); - } - - ACT_HOST_DEVICE - static Coord<2> ToCoordMK() - { - return MakeCoord(M, K); - } - - ACT_HOST_DEVICE - static Coord<2> ToCoordKN() - { - return MakeCoord(K, N); - } -}; - -/// GemmCoord is a structure derived from Coord<3> that specifies a location -/// within the coordinate space of a Gemm problem. -struct GemmCoord : public Coord<3, uint32_t> { - /// Integer-valued index - using Index = uint32_t; - - /// Base type is a Coord of rank=3 - using Base = Coord<3, Index>; - - /// Gemm M dimension - rows of the output C matrix - static constexpr int M_INDEX = 0; - - /// Gemm N dimension - columns of the output C matrix - static constexpr int N_INDEX = 1; - - /// Gemm K dimension - inner dimension of the Gemm problem - static constexpr int K_INDEX = 2; - - /// Default ctor - ACT_HOST_DEVICE - GemmCoord() {} - - /// Constructs from Coord<3> and a batch - ACT_HOST_DEVICE - GemmCoord(Coord<3, Index> const &coord) : Base(coord) {} - - /// Helper to construct from a K, N, M, batch variables - ACT_HOST_DEVICE - GemmCoord(Index m, Index n, Index k) : Base(MakeCoord(m, n, k)) {} - - /// Returns the Gemm M coordinate - ACT_HOST_DEVICE - Index const &m() const - { - return this->At(M_INDEX); - } - - /// Returns reference to the Gemm M coordinate - ACT_HOST_DEVICE - Index &m() - { - return this->At(M_INDEX); - } - - /// Returns the Gemm N coordinate - ACT_HOST_DEVICE - Index const &n() const - { - return this->At(N_INDEX); - } - - /// Returns reference to the Gemm N coordinate - ACT_HOST_DEVICE - Index &n() - { - return this->At(N_INDEX); - } - - /// Returns the Gemm K coordinate - ACT_HOST_DEVICE - Index const &k() const - { - return this->At(K_INDEX); - } - - /// Returns reference to the Gemm K coordinate - ACT_HOST_DEVICE - Index &k() - { - return this->At(K_INDEX); - } - - ACT_HOST_DEVICE - auto GetCoordMN() const - { - return this->GetCoordByAxis(); - } - - ACT_HOST_DEVICE - auto GetCoordMK() const - { - return this->GetCoordByAxis(); - } - - ACT_HOST_DEVICE - auto GetCoordKN() const - { - return this->GetCoordByAxis(); - } -}; - -} // namespace Act - -#endif // ACT_GEMM_COORD_HPP diff --git a/act/gemv_coord.hpp b/act/gemv_coord.hpp deleted file mode 100644 index 2e925c4a..00000000 --- a/act/gemv_coord.hpp +++ /dev/null @@ -1,107 +0,0 @@ -/* - * Copyright (c) 2024 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. 
THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_GEMV_COORD_HPP -#define ACT_GEMV_COORD_HPP - -#include "../act/coord.hpp" - -namespace Act { - -/// Shape of a matrix multiply-add operation -template < - /// Rows of matrix product - uint32_t M_ = 1, - /// Columns of the matrix (number of elements in the input vector) - uint32_t N_ = 1> -struct GemvShape { - static constexpr uint32_t M = M_; - static constexpr uint32_t N = N_; - - static constexpr int64_t MN = M * N; - - static constexpr int64_t COUNT = MN; - - /// Returns a Coord object - ACT_HOST_DEVICE - static Coord<2> ToCoord() - { - return MakeCoord(M, N); - } -}; - -/// GemvCoord is a structure derived from Coord<2> that specifies a location -/// within the coordinate space of a GEMV problem. -struct GemvCoord : public Coord<2, uint32_t> { - /// Integer-valued index - using Index = uint32_t; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - - /// GEMV M dimension - rows of the output vector (y) - static constexpr int M_INDEX = 0; - - /// GEMV N dimension - columns of the matrix (length of the input vector x) - static constexpr int N_INDEX = 1; - - /// Default ctor - ACT_HOST_DEVICE - GemvCoord() {} - - /// Constructs from Coord<2> and a batch - ACT_HOST_DEVICE - GemvCoord(Coord<2, Index> const &coord) : Base(coord) {} - - /// Helper to construct from M, N coordinates - ACT_HOST_DEVICE - GemvCoord(Index m, Index n) : Base(MakeCoord(m, n)) {} - - /// Returns the GEMV M coordinate (row of the result y) - ACT_HOST_DEVICE - Index const &m() const - { - return this->At(M_INDEX); - } - - /// Returns reference to the GEMV M coordinate - ACT_HOST_DEVICE - Index &m() - { - return this->At(M_INDEX); - } - - /// Returns the GEMV N coordinate (column of the matrix A or the input vector - /// x) - ACT_HOST_DEVICE - Index const &n() const - { - return this->At(N_INDEX); - } - - /// Returns reference to the GEMV N coordinate - ACT_HOST_DEVICE - Index &n() - { - return this->At(N_INDEX); - } - - ACT_HOST_DEVICE - auto GetCoordMN() const - { - return this->GetCoordByAxis(); - } -}; - -} // namespace Act - -#endif // ACT_GEMV_COORD_HPP diff --git a/act/layout/layout.hpp b/act/layout/layout.hpp deleted file mode 100644 index 5282545e..00000000 --- a/act/layout/layout.hpp +++ /dev/null @@ -1,20 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_LAYOUT_LAYOUT_HPP -#define ACT_LAYOUT_LAYOUT_HPP - -#include "../../act/act.hpp" -#include "../../act/layout/matrix.hpp" -#include "../../act/layout/vector.hpp" - -#endif // ACT_LAYOUT_LAYOUT_HPP diff --git a/act/layout/matrix.hpp b/act/layout/matrix.hpp deleted file mode 100644 index be705ce0..00000000 --- a/act/layout/matrix.hpp +++ /dev/null @@ -1,1184 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. - */ - -#ifndef ACT_LAYOUT_MATRIX_HPP -#define ACT_LAYOUT_MATRIX_HPP - -#include "../../act/act.hpp" -#include "../../act/coord.hpp" -#include "../../act/detail/alignment.hpp" -#include "../../act/matrix_coord.hpp" - -namespace Act::layout { - -/// Mapping function for row-major matrices -struct RowMajor { -public: - /// Logical rank of tensor - static constexpr int RANK = 2; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - /// Constructor - ACT_HOST_DEVICE - RowMajor(Index rows = 0, Index cols = 0) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(cols), LongIndex(1))) - {} - - /// Constructor - ACT_HOST_DEVICE - RowMajor(Index rows, Index cols, LongIndex ldm) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(ldm, LongIndex(1))) - {} - - /// Ctor - ACT_HOST_DEVICE - RowMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - template - ACT_HOST_DEVICE static RowMajor MakeLayoutInUb(MatrixCoord const &shape) - { - return RowMajor(shape.row(), shape.column(), RoundUp(shape.column())); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - return LongIndex(coord.row()) * stride_[0] + LongIndex(coord.column()); - } - - /// Returns the layout of a tile. 
- ACT_HOST_DEVICE - RowMajor GetTileLayout(MatrixCoord const &tileShape) const - { - return RowMajor(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - // - // Data members - // - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/// Mapping function for col-major matrices -struct ColumnMajor { -public: - /// Logical rank of tensor - static constexpr int RANK = 2; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - // Methods - - /// Constructor - ACT_HOST_DEVICE - ColumnMajor(Index rows = 0, Index cols = 0) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), LongIndex(rows))) - {} - - /// Constructor - ACT_HOST_DEVICE - ColumnMajor(Index rows, Index cols, LongIndex ldm) - : shape_(MakeCoord(rows, cols)), stride_(MakeCoord(LongIndex(1), ldm)) - {} - - /// Ctor - ACT_HOST_DEVICE - ColumnMajor(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - return LongIndex(coord.row()) + LongIndex(coord.column()) * stride_[1]; - } - - /// Returns the layout of a tile. 
- ACT_HOST_DEVICE - ColumnMajor GetTileLayout(MatrixCoord const &tileShape) const - { - return ColumnMajor(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - // - // Data members - // - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/// Mapping function for nZ matrices which is col-major inside fractal and -/// row-major between fractal -struct nZ { -public: - /// Logical rank of tensor - static constexpr int RANK = 4; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - // Methods - - /// Constructor - ACT_HOST_DEVICE constexpr nZ( - Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) - {} - - /// Ctor - ACT_HOST_DEVICE constexpr nZ(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) - {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE constexpr static nZ MakeLayout(Index orgRows, Index orgCols) - { - constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return nZ(orgRows, orgCols, ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, - colsRound / C0_NUM_PER_FRACTAL, 1, colsRound * ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL); - } 
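To make the fractal addressing concrete: with the usual 32-byte C0 block and 512-byte fractal and 2-byte elements, MakeLayout above turns a 32x32 matrix into shape (16, 2, 16, 2) and stride (1, 512, 16, 256). The following is a minimal, self-contained host-side sketch of the arithmetic that nZ's GetOffset (just below) performs; the function name and the hard-coded sizes are illustrative assumptions, not part of the library.

#include <cstdint>
#include <iostream>

// Offset of logical element (row, col) in an nZ layout of a 32x32 matrix of
// 2-byte elements: col-major inside each 16x16 fractal, row-major between
// fractals. Strides as nZ::MakeLayout would compute them.
int64_t NzOffsetSketch(int64_t row, int64_t col)
{
    const int64_t rowsInFractal = 16;         // shape_[0] (ELE_NUM_PER_C0)
    const int64_t colsInFractal = 16;         // shape_[2] (C0_NUM_PER_FRACTAL)
    const int64_t strideRowsInFractal = 1;    // stride_[0], col-major inside
    const int64_t strideRowsByFractal = 512;  // stride_[1] = colsRound * ELE_NUM_PER_C0
    const int64_t strideColsInFractal = 16;   // stride_[2] = ELE_NUM_PER_C0
    const int64_t strideColsByFractal = 256;  // stride_[3] = ELE_NUM_PER_FRACTAL
    // Split (row, col) into fractal-grid coordinates and in-fractal coordinates,
    // exactly as GetOffset does with shape_ and stride_.
    return row / rowsInFractal * strideRowsByFractal + col / colsInFractal * strideColsByFractal +
           row % rowsInFractal * strideRowsInFractal + col % colsInFractal * strideColsInFractal;
}

int main()
{
    std::cout << NzOffsetSketch(17, 3) << '\n';  // 1*512 + 0*256 + 1*1 + 3*16 = 561
}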
- - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] + - (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2]; - } - - /// Returns the layout of a tile. - ACT_HOST_DEVICE - nZ GetTileLayout(MatrixCoord const &tileOriShape) const - { - auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2), - CeilDiv(tileOriShape.column(), shape(2))); - return nZ(tileOriShape, tileShape, stride()); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const - { - return orgShape_[idx]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) - { - return orgShape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - /// Origin Shape data member - OrgShape orgShape_; - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/// Mapping function for zN matrices which is row-major inside fractal and -/// col-major between fractal -struct zN { -public: - /// Logical rank of tensor - static constexpr int RANK = 4; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - // Methods - - /// Constructor - ACT_HOST_DEVICE constexpr zN( - Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - 
shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)),
-          stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal))
-    {}
-
-    /// Ctor
-    ACT_HOST_DEVICE constexpr zN(OrgShape orgShape, Shape shape, Stride stride)
-        : orgShape_(orgShape), shape_(shape), stride_(stride)
-    {}
-
-    /// Make the layout of a coordinate (row, column)
-    template
-    ACT_HOST_DEVICE constexpr static zN MakeLayout(Index orgRows, Index orgCols)
-    {
-        constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element);
-        constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element);
-        Index rowsRound = RoundUp(orgRows);
-        Index colsRound = RoundUp(orgCols);
-        return zN(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0,
-                  colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, ELE_NUM_PER_FRACTAL, 1, rowsRound * ELE_NUM_PER_C0);
-    }
-
-    ACT_HOST_DEVICE
-    static zN MakeLayoutInL0C(MatrixCoord const &shape)
-    {
-        return zN(shape.row(), shape.column(), C0_NUM_PER_FRACTAL, CeilDiv(shape.row()),
-                  C0_NUM_PER_FRACTAL, CeilDiv(shape.column()), C0_NUM_PER_FRACTAL,
-                  C0_NUM_PER_FRACTAL * C0_NUM_PER_FRACTAL, 1,
-                  RoundUp(shape.row()) * C0_NUM_PER_FRACTAL);
-    }
-
-    /// Returns the offset of a coordinate in linear memory.
-    /// Assumes coordinate has convention (row, column)
-    ACT_HOST_DEVICE
-    LongIndex GetOffset(MatrixCoord const &coord) const
-    {
-        return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3] +
-               (LongIndex(coord.row()) % shape_[0]) * stride_[0] + (LongIndex(coord.column()) % shape_[2]) * stride_[2];
-    }
-
-    /// Returns the layout of a tile.
-    ACT_HOST_DEVICE
-    zN GetTileLayout(MatrixCoord const &tileOriShape) const
-    {
-        auto tileShape = MakeCoord(shape(0), CeilDiv(tileOriShape.row(), shape(0)), shape(2),
-                                   CeilDiv(tileOriShape.column(), shape(2)));
-        return zN(tileOriShape, tileShape, stride());
-    }
-
-    /// Returns the origin shape of the layout
-    ACT_HOST_DEVICE
-    typename OrgShape::Index orgShape(int idx) const
-    {
-        return orgShape_[idx];
-    }
-
-    /// Returns the origin shape of the layout
-    ACT_HOST_DEVICE
-    typename OrgShape::Index &orgShape(int idx)
-    {
-        return orgShape_[idx];
-    }
-
-    /// Returns the shape of the layout
-    ACT_HOST_DEVICE
-    Shape shape() const
-    {
-        return shape_;
-    }
-
-    /// Returns the shape of the layout
-    ACT_HOST_DEVICE
-    Shape &shape()
-    {
-        return shape_;
-    }
-
-    /// Returns the shape of the layout
-    ACT_HOST_DEVICE
-    typename Shape::Index shape(int idx) const
-    {
-        return shape_[idx];
-    }
-
-    /// Returns the shape of the layout
-    ACT_HOST_DEVICE
-    typename Shape::Index &shape(int idx)
-    {
-        return shape_[idx];
-    }
-
-    /// Returns the stride of the layout
-    ACT_HOST_DEVICE
-    Stride stride() const
-    {
-        return stride_;
-    }
-
-    /// Returns the stride of the layout
-    ACT_HOST_DEVICE
-    Stride &stride()
-    {
-        return stride_;
-    }
-
-    /// Returns the stride of the layout
-    ACT_HOST_DEVICE
-    typename Stride::Index stride(int idx) const
-    {
-        return stride_[idx];
-    }
-
-    /// Returns the stride of the layout
-    ACT_HOST_DEVICE
-    typename Stride::Index &stride(int idx)
-    {
-        return stride_[idx];
-    }
-
-private:
-    /// Origin Shape data member
-    OrgShape orgShape_;
-
-    /// Shape data member
-    Shape shape_;
-
-    /// Stride data member
-    Stride stride_;
-};
-
-/// Mapping function for zZ matrices which is row-major inside fractal and
-/// row-major between fractal
-struct zZ {
-public:
-    /// Logical rank of tensor
-    static constexpr int RANK = 4;
-
-    ///
Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - // Methods - - /// Constructor - ACT_HOST_DEVICE constexpr zZ( - Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, strideRowsByFractal, strideColsInFractal, strideColsByFractal)) - {} - - /// Ctor - ACT_HOST_DEVICE constexpr zZ(OrgShape orgShape, Shape shape, Stride stride) - : orgShape_(orgShape), shape_(shape), stride_(stride) - {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE constexpr static zZ MakeLayout(Index orgRows, Index orgCols) - { - constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return zZ(orgRows, orgCols, C0_NUM_PER_FRACTAL, rowsRound / C0_NUM_PER_FRACTAL, ELE_NUM_PER_C0, - colsRound / ELE_NUM_PER_C0, ELE_NUM_PER_C0, colsRound * C0_NUM_PER_FRACTAL, 1, ELE_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. 
- /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const - { - return orgShape_[idx]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) - { - return orgShape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - /// Origin Shape data member - OrgShape orgShape_; - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/// Mapping function for padding rowmajor matrices -/// A special data layout designed to improve the efficiency of matrix -/// operations in non-512B aligned scenarios. This layout is row-major within -/// blocks and also row-major between blocks. -struct PaddingRowMajor { -public: - /// Logical rank of tensor - static constexpr int RANK = 4; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - /// Constructor - ACT_HOST_DEVICE - PaddingRowMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), - stride_(MakeCoord((LongIndex)blockCols, (LongIndex)blockRows * (LongIndex)RoundUp(orgCols, blockCols), - (LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols)) - {} - - /// Returns the offset of a coordinate in linear memory. 
- /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - LongIndex blockRows = (LongIndex)shape_[0]; - LongIndex blockCols = (LongIndex)shape_[2]; - return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + - (LongIndex)coord.row() % blockRows * stride_[0] + (LongIndex)coord.column() % blockCols; - } - - ACT_HOST_DEVICE - PaddingRowMajor GetTileLayout(MatrixCoord const &tileShape) const - { - return PaddingRowMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const - { - return orgShape_[idx]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) - { - return orgShape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - // - // Data members - // - - /// Origin Shape data member - OrgShape orgShape_; - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/// Mapping function for padding columnmajor matrices -/// A special data layout designed to improve the efficiency of matrix -/// operations in non-512B aligned scenarios. This layout is column-major within -/// blocks and also column-major between blocks. -struct PaddingColumnMajor { -public: - /// Logical rank of tensor - static constexpr int RANK = 4; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - /// Constructor - ACT_HOST_DEVICE - PaddingColumnMajor(Index orgRows, Index orgCols, Index blockRows, Index blockCols) - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(blockRows, CeilDiv(orgRows, blockRows), blockCols, CeilDiv(orgCols, blockCols))), - stride_(MakeCoord((LongIndex)1, (LongIndex)blockRows * (LongIndex)blockCols, (LongIndex)blockRows, - (LongIndex)RoundUp(orgRows, blockRows) * (LongIndex)blockCols)) - {} - - /// Returns the offset of a coordinate in linear memory. 
- /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - LongIndex blockRows = (LongIndex)shape_[0]; - LongIndex blockCols = (LongIndex)shape_[2]; - return (LongIndex)coord.row() / blockRows * stride_[1] + (LongIndex)coord.column() / blockCols * stride_[3] + - (LongIndex)coord.row() % blockRows + (LongIndex)coord.column() % blockCols * stride_[2]; - } - - ACT_HOST_DEVICE - PaddingColumnMajor GetTileLayout(MatrixCoord const &tileShape) const - { - return PaddingColumnMajor(tileShape.row(), tileShape.column(), shape_[0], shape_[2]); - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const - { - return orgShape_[idx]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) - { - return orgShape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - // - // Data members - // - - /// Origin Shape data member - OrgShape orgShape_; - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; - -/////////////////////// -// new add layout nN -// nN layout -struct nN { -public: - /// Logical rank of tensor - static constexpr int RANK = 4; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Logical rank of orgshape - static constexpr int ORG_SHAPE_RANK = 2; - - /// Logical coordinate - using OrgShape = Coord; - - /// Logical coordinate - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - -public: - // Methods - - /// Constructor - ACT_HOST_DEVICE - nN(Index orgRows = 0, /// Number of rows of origin matrices - Index orgCols = 0, /// Number of cols of origin matrices - - Index rowsInFractal = 0, /// Number of rows inside the fractal - Index rowsByFractal = 0, /// number of rows by the fractal - Index colsInFractal = 0, /// number of cols inside the fractal - Index colsByFractal = 0, /// number of cols by the fractal - - LongIndex strideRowsInFractal = 0, /// number of elements between adjacent rows inside the fractal - LongIndex strideRowsByFractal = 0, /// number of elements between adjacent fractal rows - LongIndex strideColsInFractal = 0, /// number of elements between adjacent cols inside the fractal - LongIndex strideColsByFractal = 0) /// number of elements between adjacent fractal cols - : orgShape_(MakeCoord(orgRows, orgCols)), - shape_(MakeCoord(rowsInFractal, rowsByFractal, colsInFractal, colsByFractal)), - stride_(MakeCoord(strideRowsInFractal, 
strideRowsByFractal, strideColsInFractal, strideColsByFractal)) - {} - - /// Ctor - ACT_HOST_DEVICE - nN(OrgShape orgShape, Shape shape, Stride stride) : orgShape_(orgShape), shape_(shape), stride_(stride) {} - - /// Make the layout of a coordinate (row, column) - template - ACT_HOST_DEVICE static nN MakeLayout(Index orgRows, Index orgCols) - { - static constexpr uint32_t ELE_NUM_PER_C0 = BYTE_PER_C0 / sizeof(Element); - static constexpr uint32_t ELE_NUM_PER_FRACTAL = BYTE_PER_FRACTAL / sizeof(Element); - Index rowsRound = RoundUp(orgRows); - Index colsRound = RoundUp(orgCols); - return nN(orgRows, orgCols, - - ELE_NUM_PER_C0, rowsRound / ELE_NUM_PER_C0, C0_NUM_PER_FRACTAL, colsRound / C0_NUM_PER_FRACTAL, - - 1, ELE_NUM_PER_FRACTAL, ELE_NUM_PER_C0, rowsRound * C0_NUM_PER_FRACTAL); - } - - /// Returns the offset of a coordinate in linear memory. - /// Assumes coordinate has convention (row, column) - ACT_HOST_DEVICE - LongIndex GetOffset(MatrixCoord const &coord) const - { - return LongIndex(coord.row()) / shape_[0] * stride_[1] + LongIndex(coord.column()) / shape_[2] * stride_[3]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index orgShape(int idx) const - { - return orgShape_[idx]; - } - - /// Returns the origin shape of the layout - ACT_HOST_DEVICE - typename OrgShape::Index &orgShape(int idx) - { - return orgShape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - /// Origin Shape data member - OrgShape orgShape_; - - /// Shape data member - Shape shape_; - - /// Stride data member - Stride stride_; -}; -} // namespace Act::layout - -#endif // ACT_LAYOUT_MATRIX_HPP diff --git a/act/layout/vector.hpp b/act/layout/vector.hpp deleted file mode 100644 index 8b62f92a..00000000 --- a/act/layout/vector.hpp +++ /dev/null @@ -1,133 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_LAYOUT_VECTOR_HPP -#define ACT_LAYOUT_VECTOR_HPP - -#include "../../act/act.hpp" -#include "../../act/coord.hpp" - -namespace Act::layout { - -struct VectorLayout { -public: - /// Logical rank of tensor - static constexpr int RANK = 1; - - /// Index type used for coordinates - using Index = uint32_t; - - /// Long index type used for offsets - using LongIndex = int64_t; - - /// Shape vector - using Shape = Coord; - - /// Stride vector - using Stride = Coord; - - /// Logical coordinate - using TensorCoord = Coord; - -public: - // Methods - - ACT_HOST_DEVICE - VectorLayout(Index size = 0) : shape_(MakeCoord(size)), stride_(MakeCoord(LongIndex(1))) {} - - ACT_HOST_DEVICE - VectorLayout(Shape shape, Stride stride) : shape_(shape), stride_(stride) {} - - template - ACT_HOST_DEVICE static VectorLayout MakeLayoutInUb(TensorCoord const &tileShape) - { - return VectorLayout{RoundUp(tileShape[0])}; - } - - ACT_HOST_DEVICE - LongIndex GetOffset(TensorCoord const &coord) const - { - return stride_[0] * coord[0]; - } - - /// Returns the layout of a tile. - ACT_HOST_DEVICE - VectorLayout GetTileLayout(TensorCoord const &tileShape) const - { - return VectorLayout(tileShape, stride()); - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape shape() const - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - Shape &shape() - { - return shape_; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index shape(int idx) const - { - return shape_[idx]; - } - - /// Returns the shape of the layout - ACT_HOST_DEVICE - typename Shape::Index &shape(int idx) - { - return shape_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride stride() const - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - Stride &stride() - { - return stride_; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index stride(int idx) const - { - return stride_[idx]; - } - - /// Returns the stride of the layout - ACT_HOST_DEVICE - typename Stride::Index &stride(int idx) - { - return stride_[idx]; - } - -private: - /// Stride data member - Shape shape_; - Stride stride_; -}; - -} // namespace Act::layout - -#endif // ACT_LAYOUT_VECTOR_HPP diff --git a/act/matrix_coord.hpp b/act/matrix_coord.hpp deleted file mode 100644 index a9018db4..00000000 --- a/act/matrix_coord.hpp +++ /dev/null @@ -1,115 +0,0 @@ -/* - * Copyright (c) 2025 Huawei Technologies Co., Ltd. - * This file is a part of the CANN Open Software. - * Licensed under CANN Open Software License Agreement Version 1.0 (the - * "License"). Please refer to the License for details. You may not use this - * file except in compliance with the License. THIS SOFTWARE IS PROVIDED ON AN - * "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, - * INCLUDING BUT NOT LIMITED TO NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS - * FOR A PARTICULAR PURPOSE. See LICENSE in the root of the software repository - * for the full text of the License. 
- */ - -#ifndef ACT_MATRIX_COORD_HPP -#define ACT_MATRIX_COORD_HPP - -#include "../act/coord.hpp" - -namespace Act { - -template -struct MatrixShape { - static constexpr uint32_t ROW = ROW_; - static constexpr uint32_t COLUMN = COLUMN_; - - static constexpr int64_t COUNT = ROW * COLUMN; - - ACT_HOST_DEVICE - static Coord<2> ToCoord() - { - return MakeCoord(ROW, COLUMN); - } -}; - -/// MatrixCoord wraps Coord<2, uint32_t> to provide a helper for accessing named -/// dimensions. Classes expecting a coordinate in the rank=2 index space of a -/// matrix should use MatrixCoord. -struct MatrixCoord : public Coord<2, uint32_t> { - /// Integer-valued index - using Index = uint32_t; - - /// Base type is a Coord of rank=2 - using Base = Coord<2, Index>; - - /// LongIndex type - using LongIndex = typename Base::LongIndex; - - /// Rows dimension - static constexpr uint32_t ROW_INDEX = 0; - - /// Columns dimension - static constexpr uint32_t COLUMN_INDEX = 1; - - /// Default ctor - ACT_HOST_DEVICE - MatrixCoord() {} - - /// Constructs from Coord<2> - ACT_HOST_DEVICE - MatrixCoord(Coord<2, Index> const &coord) : Base(coord) {} - - /// Helper to construct from a row and column - ACT_HOST_DEVICE - MatrixCoord(Index row, Index column) : Base(MakeCoord(row, column)) {} - - /// Helper to construct from a row and column, which are LongIndex based - ACT_HOST_DEVICE - MatrixCoord(LongIndex row, LongIndex column) : Base(MakeCoord(Index(row), Index(column))) {} - - /// Returns the row of the coordinate - ACT_HOST_DEVICE - Index const &row() const - { - return this->At(ROW_INDEX); - } - - /// Returns the row of the coordinate - ACT_HOST_DEVICE - Index &row() - { - return this->At(ROW_INDEX); - } - - /// Returns the column of the coordinate - ACT_HOST_DEVICE - Index const &column() const - { - return this->At(COLUMN_INDEX); - } - - /// Returns the column of the coordinate - ACT_HOST_DEVICE - Index &column() - { - return this->At(COLUMN_INDEX); - } - - /// Element-wise addition - ACT_HOST_DEVICE - MatrixCoord operator+(Base const &b) const - { - return MatrixCoord(Base::operator+(b)); - } - - /// In-place addition - ACT_HOST_DEVICE - MatrixCoord &operator+=(Base const &b) - { - Base::operator+=(b); - return *this; - } -}; - -} // namespace Act - -#endif From d463e4bcceab2e5e4b4a35fde35cc6d8683dd0b4 Mon Sep 17 00:00:00 2001 From: Kaniel_Zhou Date: Tue, 21 Oct 2025 11:34:32 +0800 Subject: [PATCH 4/6] 1.un padding for fused moe --- .github/workflows/pr-test-npu.yml | 4 +-- csrc/deepep/deep_ep.cpp | 57 ++----------------------------- 2 files changed, 5 insertions(+), 56 deletions(-) diff --git a/.github/workflows/pr-test-npu.yml b/.github/workflows/pr-test-npu.yml index 87b418cc..5be7fd2c 100644 --- a/.github/workflows/pr-test-npu.yml +++ b/.github/workflows/pr-test-npu.yml @@ -66,7 +66,7 @@ jobs: run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py - - name: Run test deepep eplb + - name: Run test fused deep moe timeout-minutes: 10 env: HCCL_BUFFSIZE: 2000 @@ -121,7 +121,7 @@ jobs: run: | python3 $GITHUB_WORKSPACE/tests/python/deepep/test_low_latency.py - - name: Run test deepep eplb + - name: Run test fused deep moe timeout-minutes: 10 env: HCCL_BUFFSIZE: 2000 diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp index 02966361..e8d95ebd 100644 --- a/csrc/deepep/deep_ep.cpp +++ b/csrc/deepep/deep_ep.cpp @@ -608,46 +608,6 @@ std::vector Buffer::fused_deep_moe(const at::Tensor &x, const at::Te EP_HOST_ASSERT(expert_scales_optional.dim() == 2); this->is_padding = false; - at::Tensor new_x = 
-    this->new_topk_idx = expert_ids;
-    at::Tensor new_scales = expert_scales_optional;
-
-    if (expert_ids.size(0) < PADDING_SIZE) {
-        this->is_padding = true;
-        this->padding_cnt = PADDING_SIZE - expert_ids.size(0);
-
-        std::vector<at::Tensor> x_blocks;
-        std::vector<at::Tensor> idx_blocks;
-
-        if (expert_ids.size(0) != 0) {
-            x_blocks.emplace_back(x);
-            idx_blocks.emplace_back(expert_ids);
-        } else {
-            this->ori_x = x.clone();  // store the original input when the batch is completely empty
-        }
-
-        int topk = static_cast<int>(expert_ids.size(1));
-        for (int i = 0; i < this->padding_cnt; i++) {
-            at::Tensor tmp_x = torch::ones({1, x.size(1)}, x.options());
-            at::Tensor tmp_idx =
-                torch::randperm(num_experts, expert_ids.options()).slice(0, 0, topk).reshape({1, topk});
-            x_blocks.emplace_back(tmp_x);
-            idx_blocks.emplace_back(tmp_idx);
-        }
-        new_x = torch::cat(x_blocks, 0);
-        this->new_topk_idx = torch::cat(idx_blocks, 0);
-
-        // padding expert_scales_optional
-        std::vector<at::Tensor> scales_blocks;
-        if (this->padding_cnt != PADDING_SIZE) {
-            scales_blocks.emplace_back(expert_scales_optional);
-        }
-        for (int i = 0; i < this->padding_cnt; i++) {
-            at::Tensor tmp_scales = torch::zeros({1, expert_scales_optional.size(1)}, expert_scales_optional.options());
-            scales_blocks.emplace_back(tmp_scales);
-        }
-        new_scales = torch::cat(scales_blocks, 0);
-    }

     char hcom_ep_name[128];
     if (!moe_all_to_all_group_name.empty()) {
@@ -657,10 +617,9 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te
     }

     int64_t global_bs = std::max(new_topk_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
-
     auto x_shape = x.sizes();
     int h = x_shape[1];
-    int bs = this->new_topk_idx.size(0);
+    int bs = expert_ids.size(0);

     at::Tensor output = at::empty({bs, h}, x.options());

@@ -670,24 +629,14 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te

     EXEC_NPU_CMD(aclnnFusedDeepMoe,
                  // input
-                 new_x, this->new_topk_idx, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
-                 gmm2_weight_scale, static_cast(nullptr), new_scales,
+                 x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
+                 gmm2_weight_scale, static_cast(nullptr), expert_scales_optional,
                  // attr
                  hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num,
                  shared_expert_rank_num, quant_mode, global_bs,
                  // output
                  output, ep_recv_count);

-    // ---------- unpadding ----------
-    if (this->is_padding) {
-        if (expert_ids.size(0) == 0) {
-            output = this->ori_x;
-        } else {
-            output = output.slice(0, 0, PADDING_SIZE - this->padding_cnt);
-        }
-        this->is_padding = false;
-    }
-
     return {output, ep_recv_count};
 }
 }  // namespace deep_ep

From 72089a5adbb0be669cd04e1fe8529107dd2ca231 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Tue, 21 Oct 2025 11:37:46 +0800
Subject: [PATCH 5/6] Remove padding for fused MoE

---
 csrc/deepep/deep_ep.cpp | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp
index e8d95ebd..564ebe0a 100644
--- a/csrc/deepep/deep_ep.cpp
+++ b/csrc/deepep/deep_ep.cpp
@@ -629,8 +629,8 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te

     EXEC_NPU_CMD(aclnnFusedDeepMoe,
                  // input
-                 x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight,
-                 gmm2_weight_scale, static_cast(nullptr), expert_scales_optional,
+                 x, expert_ids, gmm1_permuted_weight, gmm1_permuted_weight_scale, gmm2_weight, gmm2_weight_scale,
+                 static_cast(nullptr), expert_scales_optional,
                  // attr
                  hcom_ep_name, num_ranks, rank, num_experts, shared_expert_num,
                 shared_expert_rank_num, quant_mode, global_bs,

From b8ad6f30153f469fbef87dbcee6cc7f2e669ca56 Mon Sep 17 00:00:00 2001
From: Kaniel_Zhou
Date: Tue, 21 Oct 2025 19:03:59 +0800
Subject: [PATCH 6/6] Remove padding for fused MoE

---
 csrc/deepep/deep_ep.cpp | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/csrc/deepep/deep_ep.cpp b/csrc/deepep/deep_ep.cpp
index 564ebe0a..213adbc7 100644
--- a/csrc/deepep/deep_ep.cpp
+++ b/csrc/deepep/deep_ep.cpp
@@ -616,7 +616,7 @@ std::vector<at::Tensor> Buffer::fused_deep_moe(const at::Tensor &x, const at::Te
         HCCL_CHECK(HcclGetCommName(ep_comm, hcom_ep_name));
     }

-    int64_t global_bs = std::max(new_topk_idx.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
+    int64_t global_bs = std::max(expert_ids.size(0), num_max_dispatch_tokens_per_rank) * num_ranks;
     auto x_shape = x.sizes();
     int h = x_shape[1];
     int bs = expert_ids.size(0);
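
Taken together, PATCH 4/6 through 6/6 drop the host-side padding workaround from fused_deep_moe: x, expert_ids, and expert_scales_optional are no longer padded up to PADDING_SIZE with dummy rows before the kernel launch, the output is no longer sliced back afterwards, and every batch size is derived from the unpadded expert_ids. The following self-contained C++ sketch (all values are hypothetical examples; the real inputs come from the fused_deep_moe arguments) illustrates the sizing arithmetic the final revision leaves in place:

    // Standalone illustration of the post-patch sizing logic; the numbers are
    // made-up examples, not taken from a real run.
    #include <algorithm>
    #include <cstdint>
    #include <iostream>

    int main()
    {
        // Assumed inputs; in deep_ep.cpp these come from expert_ids and the
        // fused_deep_moe parameters.
        const int64_t num_tokens = 37;                         // expert_ids.size(0), unpadded
        const int64_t num_max_dispatch_tokens_per_rank = 128;  // dispatch capacity per rank
        const int64_t num_ranks = 8;
        const int64_t h = 7168;                                // hidden size, x_shape[1]

        // PATCH 6/6: global_bs is computed from the raw token count rather
        // than from a padded new_topk_idx.
        const int64_t global_bs = std::max(num_tokens, num_max_dispatch_tokens_per_rank) * num_ranks;

        // PATCH 4/6: the output batch dimension is the raw token count too,
        // so no output.slice(...) un-padding step is needed afterwards.
        const int64_t bs = num_tokens;

        std::cout << "global_bs = " << global_bs << ", output = {" << bs << ", " << h << "}\n";
        // Prints: global_bs = 1024, output = {37, 7168}
        return 0;
    }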