diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index 6fc19aa2f1c..ae3956b4430 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -6,13 +6,39 @@ # Without approval from a member of this team, PRs cannot be merged to release branches. # * @NVIDIA/trt-llm-release-branch-approval +## TensorRT-LLM Infra +### CI +/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs +### Setup +/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs +### Github workflows +/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs +/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs + +## TensorRT-LLM - Docs +/docs @NVIDIA/trt-llm-doc-owners + +## Examples +/examples @NVIDIA/trt-llm-doc-owners + +## TensorRT-LLM - Triton backend +/triton_backend @NVIDIA/trt-llm-triton-backend-devs + # TensorRT-LLM Pytorch backend /tensorrt_llm/_torch @NVIDIA/trt-llm-torch-devs + +## TensorRT-LLM Pytorch - Modules +/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules + +## TensorRT-LLM Pytorch Models +/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs +/examples/models @NVIDIA/trt-llm-torch-models-devs @NVIDIA/trt-llm-doc-owners + ## TensorRT-LLM Pytorch backend - runtime /tensorrt_llm/_torch/pyexecutor @NVIDIA/trt-llm-torch-runtime-devs ## TensorRT-LLM Pytorch backend - AutoDeploy flow /tensorrt_llm/_torch/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs -/tensorrt_llm/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs +/examples/auto_deploy @NVIDIA/trt-llm-torch-autodeploy-devs @NVIDIA/trt-llm-doc-owners ## TensorRT-LLM Pytorch - Speculative Decoding /tensorrt_llm/_torch/speculative @NVIDIA/trt-llm-torch-spec-decoding @@ -31,12 +57,6 @@ /tensorrt_llm/_torch/attention_backend @NVIDIA/trt-llm-torch-attention-devs /tensorrt_llm/_torch/modules/attention.py @NVIDIA/trt-llm-torch-attention-devs -## TensorRT-LLM Pytorch - Modules -/tensorrt_llm/_torch/modules @NVIDIA/trt-llm-torch-modules - - -## TensorRT-LLM Pytorch Models -/tensorrt_llm/_torch/models @NVIDIA/trt-llm-torch-models-devs ### TensorRT-LLM Pytorch - Models - Gemma /tensorrt_llm/_torch/models/modeling_gemma3.py @NVIDIA/trt-llm-torch-models-gemma-devs @NVIDIA/trt-llm-torch-models-devs @@ -108,8 +128,6 @@ /cpp/tensorrt_llm/runtime/loraUtils.cpp @NVIDIA/trt-llm-torch-peft /cpp/tensorrt_llm/runtime/loraUtils.h @NVIDIA/trt-llm-torch-peft -## TensorRT-LLM - Triton backend -/triton_backend @NVIDIA/trt-llm-triton-backend-devs ## TensorRT-LLM trtllm-bench Reviewers /tensorrt_llm/bench @NVIDIA/trtllm-bench-reviewers @@ -121,10 +139,9 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /tensorrt_llm/executor @NVIDIA/trt-llm-llmapi-devs ## TensorRT-LLM LLM Disaggregated -/examples/disaggregated @NVIDIA/trt-llm-disagg-devs +/examples/disaggregated @NVIDIA/trt-llm-disagg-devs @NVIDIA/trt-llm-doc-owners /tensorrt_llm/disaggregated_params.py @NVIDIA/trt-llm-disagg-devs /tensorrt_llm/_torch/pyexecutor/kv_cache_transceiver.py @NVIDIA/trt-llm-disagg-devs -/tensorrt_llm/_torch/pyexecutor/py_executor.py @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheFormatter.h @NVIDIA/trt-llm-disagg-devs /cpp/tensorrt_llm/batch_manager/cacheTransBuffer.cpp @NVIDIA/trt-llm-disagg-devs @@ -135,19 +152,6 @@ docs/source/performance/perf-benchmarking.md @NVIDIA/trtllm-bench-reviewers /cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @NVIDIA/trt-llm-disagg-devs 
/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @NVIDIA/trt-llm-disagg-devs -## TensorRT-LLM Infra - -### CI -/jenkins @NVIDIA/trt-llm-ci-infra-devs @NVIDIA/trt-llm-infra-devs -### Setup -/docker @NVIDIA/trt-llm-setup-infra-devs @NVIDIA/trt-llm-infra-devs -### Github workflows -/tensorrt_llm/.github @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs -/tensorrt_llm/.coderabbit.yaml @NVIDIA/trt-llm-gh-workflows-infra-devs @NVIDIA/trt-llm-infra-devs - -## TensorRT-LLM - Docs -/docs @NVIDIA/trt-llm-doc-owners -/examples @NVIDIA/trt-llm-doc-owners # The rule below requires that any PR modifying public APIs must be approved by at least one member # of the NVIDIA/trt-llm-committed-api-review-committee or NVIDIA/trt-llm-noncommitted-api-review-committee team. diff --git a/README.md b/README.md index 5ab7fb51b7f..83cad6eb028 100644 --- a/README.md +++ b/README.md @@ -9,7 +9,7 @@ TensorRT-LLM [![python](https://img.shields.io/badge/python-3.10-green)](https://www.python.org/downloads/release/python-31012/) [![cuda](https://img.shields.io/badge/cuda-12.9.1-green)](https://developer.nvidia.com/cuda-downloads) [![trt](https://img.shields.io/badge/TRT-10.11.0-green)](https://developer.nvidia.com/tensorrt) -[![version](https://img.shields.io/badge/release-1.0.0rc6-green)](./tensorrt_llm/version.py) +[![version](https://img.shields.io/badge/release-1.1.0rc0-green)](./tensorrt_llm/version.py) [![license](https://img.shields.io/badge/license-Apache%202-blue)](./LICENSE) [Architecture](./docs/source/torch/arch_overview.md)   |   [Performance](./docs/source/performance/perf-overview.md)   |   [Examples](https://nvidia.github.io/TensorRT-LLM/quick-start-guide.html)   |   [Documentation](./docs/source/)   |   [Roadmap](https://github.com/NVIDIA/TensorRT-LLM/issues?q=is%3Aissue%20state%3Aopen%20label%3Aroadmap) diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h index a232230c4ff..09a96a56eee 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheEventManager.h @@ -18,6 +18,7 @@ #include "tensorrt_llm/executor/executor.h" +#include #include #include #include @@ -36,7 +37,8 @@ using BlockPtr = std::shared_ptr; class KVCacheEventManager { public: - explicit KVCacheEventManager(size_t maxKVEventEntries); + explicit KVCacheEventManager(size_t maxKVEventEntries, std::optional attentionDpRank = std::nullopt, + std::optional attentionDpSize = std::nullopt, SizeType32 attentionDpEventsGatherPeriodMs = 5); ~KVCacheEventManager(); KVCacheEventManager(KVCacheEventManager& other) = delete; @@ -61,14 +63,19 @@ class KVCacheEventManager // Worker thread which adds events to mEvents. 
void worker(); + // Thread which exchanges events if attentionDP is enabled + void exchangeAttentionDpThread(); + private: // Add an event to mEventQueue void enqueueEvent(executor::KVCacheEvent&& event); /// @brief Flag to terminate the worker - bool mRun; + std::atomic mRun; /// @brief Worker thread std::thread mWorkerThread; + /// @brief Exchange thread for attention DP events + std::thread mExchangeAttentionDpThread; /// @brief The deque of events std::deque mEvents; @@ -91,6 +98,17 @@ class KVCacheEventManager size_t mMaxSize; /// @brief An auto-incrementing event id counter size_t mEventId; + + /// @brief Attention DP ranks and size + /// If set, we will exchange KV cache events and accumulate on rank 0 + std::optional mAttentionDpRank; + std::optional mAttentionDpSize; + + /// @brief The period in milliseconds to gather attention DP events across rank + SizeType32 mAttentionDpEventsGatherPeriodMs; + + /// @brief MPI communicator for attention DP + std::unique_ptr mMpiComm; }; } // namespace tensorrt_llm::batch_manager::kv_cache_manager diff --git a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h index a0234cbbe49..a49527a6157 100644 --- a/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h +++ b/cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h @@ -536,8 +536,7 @@ class WindowBlockManager SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse); + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse); ~WindowBlockManager(); @@ -633,11 +632,6 @@ class WindowBlockManager return mAllBlocksById.at(blockId); } - [[nodiscard]] BlockMapIterRange getBlocksByHash(size_t hash) const - { - return mContextBlocksByHash.equal_range(hash); - } - [[nodiscard]] SizeType32 getTokensPerBlock() const noexcept { return mTokensPerBlock; @@ -723,10 +717,6 @@ class WindowBlockManager //! \param blockIds Id of each block. void storeBlocks(std::vector const& blockKeys, std::vector const& blockIds); - void addBlockToHashMap(BlockPtr const& block); - - void removeBlockFromHashMap(BlockPtr const& block); - [[nodiscard]] bool verifyQueueIntegrity(); // Only needed when sliding window attention + paged context fmha are used together. @@ -808,8 +798,6 @@ class WindowBlockManager SizeType32 mTokensPerBlock; // List of all blocks by idx std::vector mAllBlocksById; - // List of all context blocks by hash - BlockMap mContextBlocksByHash; // Dummy block acting as root for BlockToken searches BlockPtr mCachedBlocksRoot; // KV cache type (self or cross) @@ -841,8 +829,6 @@ class WindowBlockManager double mReusedTokens; // Total number of input tokens double mTotalInputTokens; - // Whether or not to maintain a hashmap of blocks. - bool mEnableHashKey; // Whether blocks that are partially matched should be reused. bool mEnablePartialReuse; // Whether partially matched blocks that are already in use should be copied and reused. 
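For readers of the kvCacheEventManager.h change above, a minimal usage sketch (not part of this diff): the event manager is constructed with the new attention-DP arguments and then handed to the cache manager through its existing eventManager parameter. The diff rendering strips the std::optional template arguments; SizeType32 is assumed here, and the helper name and values are illustrative only.

#include <memory>

#include "tensorrt_llm/batch_manager/kvCacheEventManager.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

// Hypothetical helper: create an event manager that exchanges KV cache events
// across attention-DP ranks and accumulates them on rank 0, polling every
// attentionDpEventsGatherPeriodMs milliseconds (declared default: 5).
std::shared_ptr<kvcm::KVCacheEventManager> makeAttentionDpEventManager(int dpRank, int dpSize)
{
    return std::make_shared<kvcm::KVCacheEventManager>(
        /*maxKVEventEntries=*/1024, dpRank, dpSize,
        /*attentionDpEventsGatherPeriodMs=*/5);
}

The resulting shared_ptr can then be passed as the eventManager argument of BlockManager/KVCacheManager, whose constructors above change only by dropping the enableHashKey flag.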
@@ -863,8 +849,8 @@ class BlockManager std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnPartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnPartialReuse = true); BlockManager(BlockManager const&) = delete; BlockManager& operator=(BlockManager const&) = delete; @@ -1081,11 +1067,6 @@ class BlockManager return mWindowBlockManagers.at(windowSize).getBlockById(blockId); } - [[nodiscard]] WindowBlockManager::BlockMapIterRange getBlocksByHash(size_t hash, SizeType32 windowSize) const - { - return mWindowBlockManagers.at(windowSize).getBlocksByHash(hash); - } - [[nodiscard]] SizeType32 getNumPrimaryBlocks() const { return sumWindows([](auto const& manager) { return manager.getNumPrimaryBlocks(); }); @@ -1096,16 +1077,6 @@ class BlockManager return getPool(poolIdx).containsBlockScales; } - void addBlockToHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).addBlockToHashMap(block); - } - - void removeBlockFromHashMap(BlockPtr const& block, SizeType32 windowSize) - { - mWindowBlockManagers.at(windowSize).removeBlockFromHashMap(block); - } - //! \brief Store context blocks void storeContextBlocks(GenerationRequest& sequence, LlmRequest const& llmRequest); @@ -1385,8 +1356,8 @@ class KVCacheManager : public BaseKVCacheManager SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse = false, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(std::vector const& numKvHeadsPerLayer, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1405,8 +1376,8 @@ class KVCacheManager : public BaseKVCacheManager SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse = true, bool onboardBlocks = true, CacheType cacheType = CacheType::kSELF, std::optional secondaryOffloadMinPriority = std::nullopt, - std::shared_ptr eventManager = nullptr, bool enableHashKey = false, - bool enablePartialReuse = true, bool copyOnpartialReuse = true); + std::shared_ptr eventManager = nullptr, bool enablePartialReuse = true, + bool copyOnpartialReuse = true); KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, SizeType32 sizePerHead, SizeType32 tokensPerBlock, BlocksPerWindow const& blocksPerWindow, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, @@ -1692,8 +1663,6 @@ class KVCacheManager : public BaseKVCacheManager std::unordered_map mSequences; // Whether to cache KV pages for reuse bool mEnableBlockReuse; - // Whether enable finding blocks by their hash, ignored when reuse enabled - bool mEnableHashKey; // Mutex to protect access to mSequences mutable std::mutex mSequencesMtx; // buffers for static tensors, will be created after allocating pools diff --git a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h 
b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h index 0d087d96c0f..e4d13c9e17b 100644 --- a/cpp/include/tensorrt_llm/batch_manager/llmRequest.h +++ b/cpp/include/tensorrt_llm/batch_manager/llmRequest.h @@ -828,8 +828,10 @@ class GenericLlmRequest // for enc-dec models, pause means saving generated tokens to prompt but need to re-do encoder phase mState = mEncoderTokens.has_value() || mEncoderInputFeatures ? LlmRequestState::kENCODER_INIT : LlmRequestState::kCONTEXT_INIT; - mContextCurrentPosition = 0; - mPrepopulatedPromptLen = 0; + mContextCurrentPositionTarget = 0; + mContextCurrentPositionDraft = 0; + mPrepopulatedPromptLenTarget = 0; + mPrepopulatedPromptLenDraft = 0; mContextChunkSize = mPromptLen; mSeqSlot.reset(); } @@ -1049,7 +1051,7 @@ class GenericLlmRequest [[nodiscard]] SizeType32 getPrepopulatedPromptLen() const { - return mPrepopulatedPromptLen; + return mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; } void setPrepopulatedPromptLen(SizeType32 prepopulatedPromptLen, SizeType32 kvTokensPerBlock) @@ -1066,7 +1068,10 @@ class GenericLlmRequest "Invalid state: prepopulatedPromptLen (%d) >= promptLen (%d) for request %lu", prepopulatedPromptLen, promptLen, mRequestId); TLLM_CHECK(prepopulatedPromptLen < promptLen); - mPrepopulatedPromptLen = prepopulatedPromptLen; + + auto& prePromptLen = mUseDraftModel ? mPrepopulatedPromptLenDraft : mPrepopulatedPromptLenTarget; + auto& contextCurrentPosition = mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; + prePromptLen = prepopulatedPromptLen; if (prepopulatedPromptLen > 0) { @@ -1081,7 +1086,7 @@ class GenericLlmRequest chunkSize = flooredEndPosition - prepopulatedPromptLen; TLLM_CHECK(chunkSize <= getContextChunkSize()); } - setContextCurrentPosition(prepopulatedPromptLen); + contextCurrentPosition = prepopulatedPromptLen; setContextChunkSize(chunkSize); if (!isLastContextChunk()) @@ -1522,14 +1527,15 @@ class GenericLlmRequest void setContextCurrentPosition(SizeType32 contextCurrentPosition) { - mContextCurrentPosition = contextCurrentPosition; + mContextCurrentPositionDraft = contextCurrentPosition; + mContextCurrentPositionTarget = contextCurrentPosition; } /// When chunked, the position of the current chunk is returned. Otherwise, only the beginning /// or end of the context is returned. [[nodiscard]] SizeType32 getContextCurrentPosition() const noexcept { - return mContextCurrentPosition; + return mUseDraftModel ? mContextCurrentPositionDraft : mContextCurrentPositionTarget; } /// Return the length of the context that has not yet been processed. @@ -1570,14 +1576,16 @@ class GenericLlmRequest { // The number of cached token is encountered in mContextCurrentPosition, // so the start position of the context is mPrepopulatedPromptLen. - return mContextCurrentPosition == mPrepopulatedPromptLen; + return getContextCurrentPosition() == getPrepopulatedPromptLen(); } /// Move the cursor forward one chunk. When not chunked, move forward to the end of the context. 
void moveToNextContextChunk() { TLLM_CHECK_WITH_INFO(isContextInitState(), "Chunking is only possible during the context phase."); - mContextCurrentPosition += getContextChunkSize(); + + mContextCurrentPositionDraft += getContextChunkSize(); + mContextCurrentPositionTarget += getContextChunkSize(); setContextChunkSize(0); } @@ -1843,6 +1851,16 @@ class GenericLlmRequest return mIsDummyRequest; } + void setUseDraftModel(bool useDraftModel) + { + mUseDraftModel = useDraftModel; + } + + [[nodiscard]] bool useDraftModel() const + { + return mUseDraftModel; + } + RequestIdType mRequestId; SizeType32 mPromptLen; SizeType32 mMaxNewTokens; @@ -1885,7 +1903,8 @@ class GenericLlmRequest // Number of tokens already in KV cache before context phase. // A value > 0 indicates cached KV cache blocks were reused. // Up to inputLen - 1 tokens can be reused. - SizeType32 mPrepopulatedPromptLen{0}; + SizeType32 mPrepopulatedPromptLenTarget{0}; + SizeType32 mPrepopulatedPromptLenDraft{0}; SizeType32 mMaxSentTokenLen; @@ -1916,7 +1935,8 @@ class GenericLlmRequest // The size of the context chunk must be multiple of the KV-Cache block size except the last one. // Value `0` means Chunked-Context is disabled. SizeType32 mContextChunkSize{0}; - SizeType32 mContextCurrentPosition{0}; + SizeType32 mContextCurrentPositionTarget{0}; + SizeType32 mContextCurrentPositionDraft{0}; std::vector mLogProbs; // [beamSize, seqLen] VecLogProbs mCumLogProbs; // [beamSize] @@ -2017,6 +2037,8 @@ class GenericLlmRequest bool mIsDummyRequest{false}; + bool mUseDraftModel{false}; + private: void initialize(VecTokens const& inputTokens, bool outputLogProbs) { @@ -2027,7 +2049,7 @@ class GenericLlmRequest // Scatter the input tokens to other beam mTokens = BeamTokens(mSamplingConfig.beamWidth, inputTokens); - mLastTokens = VecTokens(mSamplingConfig.beamWidth); + mLastTokens = VecTokens(mSamplingConfig.beamWidth, inputTokens.back()); // Init mUniqueTokens VecUniqueTokens uniqueTokens{inputTokens.size()}; @@ -2347,6 +2369,9 @@ class LlmRequest : public GenericLlmRequest void movePromptEmbeddingTableToGpu(runtime::BufferManager const& manager); void moveLoraWeightsToGpu(runtime::BufferManager const& manager); + + // Remove LoRA weights and LoRA config tensors + void removeLoraTensors(); }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/include/tensorrt_llm/common/quantization.h b/cpp/include/tensorrt_llm/common/quantization.h index 836faa258fe..50aae114e0c 100644 --- a/cpp/include/tensorrt_llm/common/quantization.h +++ b/cpp/include/tensorrt_llm/common/quantization.h @@ -122,6 +122,16 @@ class QuantMode return QuantMode(BaseType(1u) << 14); } + static constexpr QuantMode w4a8Mxfp4Mxfp8() noexcept + { + return QuantMode(BaseType(1u) << 15); + } + + static constexpr QuantMode w4a16Mxfp4() noexcept + { + return QuantMode(BaseType(1u) << 16); + } + constexpr BaseType value() const noexcept { return mValue; @@ -202,6 +212,16 @@ class QuantMode return isSet(w4a8Mxfp4Fp8()); } + constexpr bool hasW4a8Mxfp4Mxfp8() const noexcept + { + return isSet(w4a8Mxfp4Mxfp8()); + } + + constexpr bool hasW4a16Mxfp4() const noexcept + { + return isSet(w4a16Mxfp4()); + } + constexpr bool hasKvCacheQuant() const noexcept { return hasInt8KvCache() || hasFp8KvCache() || hasFp4KvCache(); @@ -209,7 +229,8 @@ class QuantMode static constexpr QuantMode fromDescription(bool quantizeWeights, bool quantizeActivations, bool perToken, bool perChannel, bool perGroup, bool useInt4Weights, bool useInt8KvCache, bool useFp8KvCache, bool useFp8Qdq, - bool 
useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8) + bool useFp8RowWise, bool useW4a8QServe, bool useFp4Quant, bool useFp8BlockScales, bool useW4a8Mxfp4Fp8, + bool useW4a8Mxfp4Mxfp8, bool useW4a16Mxfp4) { QuantMode quantMode{}; if (quantizeWeights) @@ -278,25 +299,35 @@ class QuantMode quantMode += w4a8Mxfp4Fp8(); } + if (useW4a8Mxfp4Mxfp8) + { + quantMode += w4a8Mxfp4Mxfp8(); + } + + if (useW4a16Mxfp4) + { + quantMode += w4a16Mxfp4(); + } + return quantMode; } static constexpr QuantMode useSmoothQuant(bool perToken = false, bool perChannel = false) { - return fromDescription( - true, true, perToken, perChannel, false, false, false, false, false, false, false, false, false, false); + return fromDescription(true, true, perToken, perChannel, false, false, false, false, false, false, false, false, + false, false, false, false); } static constexpr QuantMode useQServe(bool perGroup) { - return fromDescription( - true, true, false, false, perGroup, true, false, false, false, false, true, false, false, false); + return fromDescription(true, true, false, false, perGroup, true, false, false, false, false, true, false, false, + false, false, false); } static constexpr QuantMode useWeightOnly(bool useInt4Weights = false, bool perGroup = false) { return fromDescription(true, false, false, false, perGroup, useInt4Weights, false, false, false, false, false, - false, false, false); + false, false, false, false, false); } static QuantMode const fromQuantAlgo( @@ -353,28 +384,38 @@ class QuantMode } else if (quantAlgo == "FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, true, false, false, false, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, true, false, false, + false, false, false, false, false); } else if (quantAlgo == "FP8_ROWWISE") { - quantMode = fromDescription( - false, false, true, true, false, false, false, false, false, true, false, false, false, false); + quantMode = fromDescription(false, false, true, true, false, false, false, false, false, true, false, false, + false, false, false, false); } else if (quantAlgo == "FP4") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, true, false, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + true, false, false, false, false); } else if (quantAlgo == "FP8_BLOCK_SCALES") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, true, false); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, true, false, false, false); } else if (quantAlgo == "W4A8_MXFP4_FP8") { - quantMode = fromDescription( - false, false, false, false, false, false, false, false, false, false, false, false, false, true); + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, true, false, false); + } + else if (quantAlgo == "W4A8_MXFP4_MXFP8") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, true, false); + } + else if (quantAlgo == "W4A16_MXFP4") + { + quantMode = fromDescription(false, false, false, false, false, false, false, false, false, false, false, + false, false, false, false, true); } if (kvCacheQuantAlgo == "INT8") diff --git 
a/cpp/include/tensorrt_llm/executor/executor.h b/cpp/include/tensorrt_llm/executor/executor.h index 6d592654ffd..0a58298c279 100644 --- a/cpp/include/tensorrt_llm/executor/executor.h +++ b/cpp/include/tensorrt_llm/executor/executor.h @@ -1001,6 +1001,7 @@ class KvCacheConfig std::optional const& crossKvCacheFraction = std::nullopt, std::optional secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0, bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false, + SizeType32 attentionDpEventsGatherPeriodMs = 5, std::optional const& runtimeDefaults = std::nullopt); [[nodiscard]] bool getEnableBlockReuse() const; @@ -1016,6 +1017,7 @@ class KvCacheConfig [[nodiscard]] std::optional getSecondaryOffloadMinPriority() const; [[nodiscard]] size_t getEventBufferMaxSize() const; [[nodiscard]] bool getUseUvm() const; + [[nodiscard]] SizeType32 getAttentionDpEventsGatherPeriodMs() const; void setEnableBlockReuse(bool enableBlockReuse); void setEnablePartialReuse(bool enablePartialReuse); @@ -1030,6 +1032,7 @@ class KvCacheConfig void setSecondaryOffloadMinPriority(std::optional secondaryOffloadMinPriority); void setEventBufferMaxSize(size_t eventBufferMaxSize); void setUseUvm(bool useUvm); + void setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs); void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults); @@ -1085,6 +1088,9 @@ class KvCacheConfig /// @brief Whether to use UVM for the KV cache. bool mUseUvm; + + /// @brief The period in milliseconds to gather attention DP events across ranks + SizeType32 mAttentionDpEventsGatherPeriodMs; }; /// @brief Configuration class for the runtime perf knobs @@ -1702,6 +1708,12 @@ struct KVCacheUpdatedData explicit KVCacheUpdatedData(IdType blockHash) : blockHash{blockHash} {}; + explicit KVCacheUpdatedData(IdType blockHash, std::optional> cacheLevel, + std::optional> priority) + : blockHash{blockHash} + , cacheLevel{cacheLevel} + , priority{priority} {}; + KVCacheUpdatedData& cacheLevelUpdated(SizeType32 oldValue, SizeType32 newValue) { cacheLevel = KVCacheEventDiff{oldValue, newValue}; @@ -1726,8 +1738,8 @@ using KVCacheEventData = std::variant attentionDpRank = std::nullopt); /// @brief The unique id of this event IdType eventId; @@ -1735,6 +1747,8 @@ struct KVCacheEvent KVCacheEventData data; /// @brief The sliding window size SizeType32 windowSize; + /// @brief The attention DP rank of the event, if applicable + std::optional attentionDpRank; }; /// @brief Exposes a limited set of KV cache manager functionalities diff --git a/cpp/include/tensorrt_llm/executor/serialization.h b/cpp/include/tensorrt_llm/executor/serialization.h index b2ecfc66c84..c370a652350 100644 --- a/cpp/include/tensorrt_llm/executor/serialization.h +++ b/cpp/include/tensorrt_llm/executor/serialization.h @@ -302,6 +302,53 @@ class Serialization [[nodiscard]] static std::vector deserializeRequestStatsPerIterationVec( std::vector& buffer); + // KVCacheEvent deque + [[nodiscard]] static std::vector serialize(std::deque const& kvCacheEvents); + [[nodiscard]] static std::deque deserializeKVCacheEvents(std::vector& buffer); + + // KVCacheEvent + [[nodiscard]] static size_t serializedSize(KVCacheEvent const& event); + static void serialize(KVCacheEvent const& event, std::ostream& os); + [[nodiscard]] static KVCacheEvent deserializeKVCacheEvent(std::istream& is); + + // KVCacheCreatedData + [[nodiscard]] static size_t serializedSize(KVCacheCreatedData const& data); + static void 
serialize(KVCacheCreatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheCreatedData deserializeKVCacheCreatedData(std::istream& is); + + // KVCacheStoredData + [[nodiscard]] static size_t serializedSize(KVCacheStoredData const& data); + static void serialize(KVCacheStoredData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredData deserializeKVCacheStoredData(std::istream& is); + + // KVCacheStoredBlockData + [[nodiscard]] static size_t serializedSize(KVCacheStoredBlockData const& data); + static void serialize(KVCacheStoredBlockData const& data, std::ostream& os); + [[nodiscard]] static KVCacheStoredBlockData deserializeKVCacheStoredBlockData(std::istream& is); + + // KVCacheRemovedData + [[nodiscard]] static size_t serializedSize(KVCacheRemovedData const& data); + static void serialize(KVCacheRemovedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheRemovedData deserializeKVCacheRemovedData(std::istream& is); + + // KVCacheEventDiff + template + [[nodiscard]] static size_t serializedSize(KVCacheEventDiff const& data); + template + static void serialize(KVCacheEventDiff const& data, std::ostream& os); + template + [[nodiscard]] static KVCacheEventDiff deserializeKVCacheEventDiff(std::istream& is); + + // KVCacheUpdateData + [[nodiscard]] static size_t serializedSize(KVCacheUpdatedData const& data); + static void serialize(KVCacheUpdatedData const& data, std::ostream& os); + [[nodiscard]] static KVCacheUpdatedData deserializeKVCacheUpdatedData(std::istream& is); + + // UniqueToken + [[nodiscard]] static size_t serializedSize(tensorrt_llm::runtime::UniqueToken const& token); + static void serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os); + [[nodiscard]] static tensorrt_llm::runtime::UniqueToken deserializeUniqueToken(std::istream& is); + // String static std::string deserializeString(std::istream& is); diff --git a/cpp/include/tensorrt_llm/runtime/decoderState.h b/cpp/include/tensorrt_llm/runtime/decoderState.h index e4fe9c38010..8166156a0cc 100644 --- a/cpp/include/tensorrt_llm/runtime/decoderState.h +++ b/cpp/include/tensorrt_llm/runtime/decoderState.h @@ -51,13 +51,13 @@ class DecoderState DecoderState(); //! @brief Setup buffers for the decoder excluding speculative decoding. - void setup(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setup(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, SizeType32 sinkTokenLength, SizeType32 maxSequenceLength, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); //! @brief Setup buffers for the cache indirection. //! @details This is used for beam search on pipeline parallel ranks without a decoder. - void setupCacheIndirection(SizeType32 maxBatchSize, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, + void setupCacheIndirection(SizeType32 maxNumSequences, SizeType32 maxBeamWidth, SizeType32 maxAttentionWindow, BufferManager const& bufferManager); //! @brief Setup buffers for speculative decoding. @@ -134,7 +134,7 @@ class DecoderState //! @returns [batchSize, maxAcceptedDraftTokensPerStep], accepted paths packed into continuous tensor, on gpu [[nodiscard]] TensorPtr getAcceptedPackedPaths() const; - [[nodiscard]] SizeType32 getMaxBatchSize() const; + [[nodiscard]] SizeType32 getMaxNumSequences() const; [[nodiscard]] SizeType32 getMaxBeamWidth() const; @@ -187,10 +187,10 @@ class DecoderState //! 
@param generationSteps The generation steps for all requests in the batch. void setGenerationSteps(std::vector const& generationSteps); - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingInput& getJointDecodingInput() const; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. [[nodiscard]] DecodingOutput& getJointDecodingOutput() const; private: @@ -209,13 +209,13 @@ class DecoderState SizeType32 maxTokensPerEngineStep, ModelConfig const& modelConfig, WorldConfig const& worldConfig, BufferManager const& bufferManager); - SizeType32 mMaxBatchSize{}; + SizeType32 mMaxNumSequences{}; SizeType32 mMaxBeamWidth{}; SizeType32 mMaxSequenceLength{}; - //! @brief Stateful inputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful inputs for the decoder. Allocated for maxNumSequences slots. DecodingInputPtr mJointDecodingInput; - //! @brief Stateful outputs for the decoder. Allocated for maxBatchSize slots. + //! @brief Stateful outputs for the decoder. Allocated for maxNumSequences slots. DecodingOutputPtr mJointDecodingOutput; //! @brief Workspace for beam search in streaming mode. diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoder.h b/cpp/include/tensorrt_llm/runtime/gptDecoder.h index 90690c90fc0..7e0cc1bb56d 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoder.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoder.h @@ -71,7 +71,7 @@ class IGptDecoder = 0; static std::unique_ptr create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr const& speculativeDecodingModule = nullptr); }; @@ -84,7 +84,7 @@ class GptDecoder : public virtual IGptDecoder using CudaStreamPtr = BufferManager::CudaStreamPtr; using TensorPtr = std::shared_ptr; - GptDecoder(executor::DecodingMode const& mode, size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, + GptDecoder(executor::DecodingMode const& mode, size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, CudaStreamPtr const& stream, std::shared_ptr speculativeDecodingModule = nullptr); @@ -114,7 +114,7 @@ class GptDecoder : public virtual IGptDecoder SamplingConfig mSamplingConfig; - size_t mMaxBatchSize; + size_t mMaxNumSequences; size_t mVocabSize; size_t mVocabSizePadded; @@ -122,7 +122,7 @@ class GptDecoder : public virtual IGptDecoder }; inline std::unique_ptr IGptDecoder::create(executor::DecodingMode const& mode, nvinfer1::DataType dtype, - size_t maxBatchSize, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, + size_t maxNumSequences, size_t maxBeamWidth, size_t vocabSize, size_t vocabSizePadded, BufferManager::CudaStreamPtr const& stream, std::shared_ptr const& speculativeDecodingModule) { @@ -130,10 +130,10 @@ inline std::unique_ptr IGptDecoder::create(executor::DecodingMode c { case nvinfer1::DataType::kFLOAT: return std::make_unique>( - mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); case nvinfer1::DataType::kHALF: return std::make_unique>( - 
mode, maxBatchSize, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); + mode, maxNumSequences, maxBeamWidth, vocabSize, vocabSizePadded, stream, speculativeDecodingModule); default: TLLM_THROW("Unsupported decoder data type: %d. Use either kFLOAT or kHALF.", static_cast(dtype)); return nullptr; diff --git a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h index d5dfe9b7b19..d0a9e726d13 100644 --- a/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/gptDecoderBatched.h @@ -47,7 +47,7 @@ class GptDecoderBatched : public IGptDecoderBatched explicit GptDecoderBatched(CudaStreamPtr stream); - void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) override; void disableLookahead(RequestVector const& genRequests, TensorPtr const& batchSlots) override; diff --git a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h index 327af71f8a7..606ba3c98a4 100644 --- a/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h +++ b/cpp/include/tensorrt_llm/runtime/iGptDecoderBatched.h @@ -86,7 +86,7 @@ class IGptDecoderBatched using TensorPtr = std::shared_ptr; //! @brief Setup the decoder before calling `forward()` - virtual void setup(executor::DecodingMode const& mode, SizeType32 maxBatchSize, SizeType32 maxBeamWidth, + virtual void setup(executor::DecodingMode const& mode, SizeType32 maxNumSequences, SizeType32 maxBeamWidth, nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig) = 0; diff --git a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h index 4443d422ab8..32c086c84ee 100644 --- a/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h +++ b/cpp/include/tensorrt_llm/runtime/utils/mpiTags.h @@ -68,6 +68,10 @@ enum class MpiTag : int // LogitsThread kSpecDecLogitsId = 129, kSpecDecLogitsData = 1025, + + // KvCacheEventManager + kKvCacheEventSize = 1026, + kKvCacheEvent = 1027 }; } // namespace tensorrt_llm::mpi diff --git a/cpp/kernels/fmha_v2/fmha_test.py b/cpp/kernels/fmha_v2/fmha_test.py index f9f28978e66..d02e3cc31c0 100644 --- a/cpp/kernels/fmha_v2/fmha_test.py +++ b/cpp/kernels/fmha_v2/fmha_test.py @@ -1,7 +1,12 @@ import subprocess import pytest -from cuda import cuda, nvrtc + +try: + from cuda.bindings import driver as cuda + from cuda.bindings import nvrtc +except ImportError: + from cuda import cuda, nvrtc def ASSERT_DRV(err): @@ -50,7 +55,7 @@ def getSMVersion(): ids=["fp16", "bf16", "fp16-fp32", "e4m3"]) @pytest.mark.parametrize('flag', [ "-s-q 128 -paged-kv", "-s-q 63 -paged-kv", "-paged-kv", - "-softcapping-scale-bmm1 30", "-contiguous-q-kv" + "-softcapping-scale-bmm1 30", "-contiguous-q-kv", "-use-attention-sinks" ]) @pytest.mark.parametrize('tiled_kernel', ["", "-force-non-tiled"]) def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): @@ -117,8 +122,8 @@ def test_trtllm_flash_attention_fmha(d, s, dtype, flag, tiled_kernel): f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -custom-mask -gqa 2 -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, check=True) - # alibi and softcapping-scale-bmm1 are mutually exclusive. 
- if '-softcapping-scale-bmm1' not in flag: + # alibi doesn't work with softcapping-scale-bmm1/use-attention-sinks. + if '-softcapping-scale-bmm1' not in flag and '-use-attention-sinks' not in flag: subprocess.run( f"bin/fmha.exe -d {d} -h 16 -b 8 -s {s} -min-s 128 -causal-mask -alibi -v {verbose} {dtype} {epsilon} {flag} {tiled_kernel}", shell=True, diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h index 65e56dbf5de..eed6f852da3 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/compute.h @@ -326,9 +326,6 @@ struct Compute uint32_t smem_v = __cvta_generic_to_shared(&shared->smem_v[0]); Compute_tile_o ctile_o(0, smem_v); - // BMM2 epilogue - Tile_o_epilogue tile_o_epilogue(params); - // Mutex between two compute groups. OrderedMutexAccessor mutex_accessor(shared->compute_mutex, warpgroup_id, SYNC_BARRIER); // Notify warpgroup 0 to execute HGMMA first (overlap HGMMA and Softmax Math Instructions). @@ -368,6 +365,9 @@ struct Compute sage_scale_row = head_info.bidb * params.h + head_info.bidh; } + // BMM2 epilogue + Tile_o_epilogue tile_o_epilogue(params, head_info); + int q_step_idx = warpgroup_id; // Compute work. @@ -490,7 +490,7 @@ struct Compute if (valid_run) { // Final step's update. - tile_o_epilogue.scale(ctile_o, p_sum); + tile_o_epilogue.scale(ctile_o, p_max, p_sum); // Store o_tile to gmem. gmem_o.store(ctile_o.acc_); } diff --git a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h index 217e8c08722..99ea1643cd0 100644 --- a/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h +++ b/cpp/kernels/fmha_v2/src/fmha/warpspec/epilogue.h @@ -454,7 +454,7 @@ struct Softmax_base #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - uint32_t const scale = float_to_half2(correction_[mi]); + const uint32_t scale = float_to_half2(correction_[mi]); // Assume only N has multiple MMAs (MMAS_M = 1). // MMAS_N > 1 when N dimension is split. @@ -477,9 +477,15 @@ struct Softmax_base } // BMM1 scale. - uint32_t const scale_bmm1_; + const uint32_t scale_bmm1_; // BMM1 softcapping scale. float const softcapping_scale_bmm1_; + + // The sliding window size. + int const sliding_window_size_; + // The log2 attention chunk size. + int const log2_chunked_attention_size_; + // The thread idx in the warp group. int tidx_; // The col index for the mma thread layout. @@ -487,15 +493,10 @@ struct Softmax_base // The row index for the mma thread layout. int quad_row_; - // The sliding window size. - int const sliding_window_size_; - // The log2 attention chunk size. - int const log2_chunked_attention_size_; - // The packed mask ptr. uint32_t const* packed_mask_ptr_; // The packed mask k-dim stride in bytes; - int64_t const params_packed_mask_stride_in_bytes_; + const int64_t params_packed_mask_stride_in_bytes_; // Unpacked BMM1 output buffer. float elt_[Mma_tile_p::CORES_M][Mma_tile_p::CORES_N * 2]; @@ -1072,20 +1073,53 @@ struct Tile_o_epilogue_base // The MMA tile for the BMM2. using Mma_tile_o = typename Kernel_traits::Mma_tile_o; - template - inline __device__ Tile_o_epilogue_base(Params const& params) + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum + { + EXP2F_OPTIMIZATION = Kernel_traits::EXP2F_OPTIMIZATION + }; + + template + inline __device__ Tile_o_epilogue_base(Params const& params, Block_info& block_info) { - ; // nothing to construct. 
+ has_attention_sink_ = params.attention_sinks != nullptr; + head_idx_ = block_info.bidh; + attention_sink_ = has_attention_sink_ ? params.attention_sinks[block_info.bidh] : 0.f; + // It is only need when the exp2f optimization is enabled, so params.scale_bmm1 is always float. + scale_bmm1_f_ = reinterpret_cast(params.scale_bmm1_d ? *params.scale_bmm1_d : params.scale_bmm1); }; + // The attention sinks. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(attention_sink_ * M_LOG2E - max * scale_bmm1_f_); + } + else + { + sum += expf(attention_sink_ - max); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1096,12 +1130,21 @@ struct Tile_o_epilogue_base { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } } + + // Whether the attention sink is enabled. + bool has_attention_sink_ = false; + // The attention sink value. + float attention_sink_ = 0.f; + // The float scale of bmm1 outputs. + float scale_bmm1_f_ = 1.f; + // The head idx. + int head_idx_ = 0; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -1138,14 +1181,21 @@ struct Tile_o_epilogue using Base::Tile_o_epilogue_base; // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { - global_sum[mi] = global_sum[mi] == 0.f ? 1.f : 1.0f / global_sum[mi]; - uint32_t const scale = float_to_half2(global_sum[mi]); + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + this->add_attention_sink(global_sum_mi, global_max[mi]); + // The scale. + float scale = global_sum_mi == 0.f ? 1.f : 1.0f / global_sum_mi; + // The scale. + const uint32_t scale_h = float_to_half2(scale); // Assume only N has multiple MMAs (MMAS_M = 1). #pragma unroll @@ -1155,7 +1205,7 @@ struct Tile_o_epilogue for (int ni = 0; ni < Mma_tile_o::CORES_N; ni++) { uint32_t& reg = ctile_o.acc_[0][mma_ni].reg(ni * Mma_tile_o::CORES_M + mi); - reg = hmul2(reg, scale); + reg = hmul2(reg, scale_h); } } } @@ -1215,27 +1265,58 @@ struct Tile_o_epilogue // The MMA tile for the BMM2. 
using Mma_tile_o = typename Base::Mma_tile_o; + // Apply the exp2f optimization (fuse bmm1_scale and -max into FMAs). + enum + { + EXP2F_OPTIMIZATION = Base::EXP2F_OPTIMIZATION + }; + // Ctor. - template - inline __device__ Tile_o_epilogue(Params const& params) - : Base(params) + template + inline __device__ Tile_o_epilogue(Params const& params, Block_info& block_info) + : Base(params, block_info) , scale_bmm2_(*params.scale_bmm2_d) { } + // Add the attention sink to the global sum. + inline __device__ void add_attention_sink(float& sum, float max) + { + if (this->has_attention_sink_) + { + // The global max needs to be scaled by the bmm1 scale if exp2f optimization is enabled. + // Take the log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE) into account as the same scale has been applied to sum. + float quant_scale_in_log2 = log2f(Traits_o::SOFTMAX_FP_QUANT_SCALE); + if constexpr (EXP2F_OPTIMIZATION) + { + sum += exp2f(this->attention_sink_ * M_LOG2E - max * this->scale_bmm1_f_ + quant_scale_in_log2); + } + else + { + sum += expf(this->attention_sink_ - max + quant_scale_in_log2); + } + } + } + // Scale ctile_o output by 1/sum - inline __device__ void scale(Compute_tile_o& ctile_o, float (&global_sum)[Mma_tile_o::CORES_M]) + inline __device__ void scale( + Compute_tile_o& ctile_o, float (&global_max)[Mma_tile_o::CORES_M], float (&global_sum)[Mma_tile_o::CORES_M]) { // Final step's update. #pragma unroll for (int mi = 0; mi < Mma_tile_o::CORES_M; mi++) { + // The global sum. + float global_sum_mi = global_sum[mi]; + // Add the attention sink to the global sum. + add_attention_sink(global_sum_mi, global_max[mi]); #ifdef UNIFIED_EPILOGUE_SCALE // Descaling factor float const scale_bmm2_f_ = reinterpret_cast(scale_bmm2_); - global_sum[mi] = global_sum[mi] == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum[mi]; + // The scale. + float scale = global_sum_mi == 0.f ? scale_bmm2_f_ : scale_bmm2_f_ / global_sum_mi; #else - global_sum[mi] = global_sum[mi] == 0.f ? 1.0f : 1.0f / global_sum[mi]; + float scale = global_sum_mi == 0.f ? 1.0f : 1.0f / global_sum_mi; #endif // Assume only N has multiple MMAs (MMAS_M = 1). 
#pragma unroll @@ -1246,8 +1327,8 @@ struct Tile_o_epilogue { float& reg0 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi); float& reg1 = ctile_o.acc_[0][mma_ni].elt(2 * ni * Mma_tile_o::CORES_M + 2 * mi + 1); - reg0 *= global_sum[mi]; - reg1 *= global_sum[mi]; + reg0 *= scale; + reg1 *= scale; } } } diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp index 6d9811ac071..6cf52fcf4c9 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.cpp @@ -29,30 +29,33 @@ using Kv_block_array = fmha::Kv_block_array; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, 
int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,11 +84,11 @@ void run_sage_quant(unsigned int batch_size, unsigned int head_num, unsigned int //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, - void* qkv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, - void* cu_q_seqlens_d, size_t const b, size_t const s, size_t const h, size_t const d, size_t const dv, - int const runs, int const warps_m, int const warps_n, bool const has_alibi) + void* qkv_d, void* vt_d, void* mask_d, void* attention_sinks_d, void* p_d, void* s_d, void* tmp_d, void* o_d, + void* softmax_sum_d, void* cu_q_seqlens_d, const size_t b, const size_t s, const size_t h, const size_t d, + const size_t dv, int const runs, int const warps_m, int const warps_n, bool const has_alibi) { cudaStream_t stream = 0; @@ -106,28 +109,28 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_BF16 && acc_type == DATA_TYPE_FP32) { - run_softmax_bf16(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_bf16(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, softcapping_scale_bmm1, - warps_n, has_alibi); + run_softmax_fp32(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, scale_softmax, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax_int8(s_d, p_d, mask_d, attention_sinks_d, softmax_sum_d, cu_q_seqlens_d, s, s, b, h, scale_bmm1, + scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } else { @@ -179,7 +182,7 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s, size_t const h, size_t const d, size_t const packed_mask_stride, + const size_t b, const size_t s, const size_t h, const size_t d, const size_t 
packed_mask_stride, // device pointers void* qkv_d, void* packed_mask_d, void* o_d, void* p_d, void* s_d, // scale factors @@ -235,17 +238,17 @@ static inline void set_params(bert::Fused_multihead_attention_params_v1& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void set_params(bert::Fused_multihead_attention_params_v2& params, Launch_params const launch_params, +static inline void set_params(bert::Fused_multihead_attention_params_v2& params, const Launch_params launch_params, // types Data_type data_type, Data_type acc_type, Data_type output_dtype, // attention input layout Attention_input_layout input_layout, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const h_kv, size_t const d, - size_t const dv, size_t const total, const size_t num_grouped_heads, const size_t sliding_window_size, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t h_kv, const size_t d, + const size_t dv, const size_t total, const size_t num_grouped_heads, const size_t sliding_window_size, const size_t chunked_attention_size, // paged kv cache block size. - size_t const tokens_per_block, + const size_t tokens_per_block, // device pointers void* qkv_packed_d, // contiguous q. @@ -261,8 +264,10 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // offsets for different blocks in terms of the start address. int32_t* paged_block_offsets, // mask input. - void* packed_mask_d, void* cu_mask_rows_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, - void* s_d, void* softmax_stats_d, void* scale_bmm2_d, + void* packed_mask_d, void* cu_mask_rows_d, + // attention sinks. + void* attention_sinks_d, void* cu_kv_seqlens_d, void* cu_q_seqlens_d, void* o_packed_d, void* p_d, void* s_d, + void* softmax_stats_d, void* scale_bmm2_d, // scale factors float const scale_bmm1, float const scale_softmax, float const scale_bmm2, float const softcapping_scale_bmm1, // flags @@ -329,6 +334,9 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, // The N dimension has to be aligned. params.packed_mask_stride_in_bytes = (align_to(int64_t(s_kv), int64_t(fmha::FLASH_ATTEN_MASK_N_ALIGNMENT))) / 8; + // Attention sinks. 
+ params.attention_sinks = reinterpret_cast(attention_sinks_d); + #if defined(STORE_P) params.p_ptr = p_d; params.p_stride_in_bytes = get_size_in_bytes(b * h * s_kv, acc_type); @@ -412,13 +420,13 @@ static inline void set_params(bert::Fused_multihead_attention_params_v2& params, //////////////////////////////////////////////////////////////////////////////////////////////////// -static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, size_t const s, - size_t const d, Attention_mask_type const attention_mask_type, Attention_input_layout const input_layout, +static inline void determine_launch_params(Launch_params& launch_params, Data_type data_type, int sm, const size_t s, + const size_t d, const Attention_mask_type attention_mask_type, const Attention_input_layout input_layout, bool const interleaved, bool const ignore_b1opt, bool const force_unroll, bool const use_tma, bool const force_non_flash_attention, bool const force_non_warp_specialization, bool const force_non_granular_tiling, bool const force_fp32_acc, // device props - cudaDeviceProp const props) + const cudaDeviceProp props) { // Set launch params to choose kernels @@ -573,6 +581,9 @@ int main(int argc, char** argv) // SageAttention block sizes int sage_block_size_q = 0, sage_block_size_k = 0, sage_block_size_v = 0; + // Use attention sinks (added to the denominator of softmax) + bool use_attention_sinks = false; + // Read the parameters from the command-line. for (int ii = 1; ii < argc; ++ii) { @@ -865,13 +876,16 @@ int main(int argc, char** argv) { sage_block_size_v = strtol(argv[ii], nullptr, 10); } + else if (!strcmp(argv[ii], "-use-attention-sinks")) + { + use_attention_sinks = true; + } else { fprintf(stderr, "Unrecognized option: %s. Aborting!\n", argv[ii]); return -1; } } - if (save_softmax == true) { if (input_layout != Attention_input_layout::CONTIGUOUS_Q_KV) @@ -1043,11 +1057,11 @@ int main(int argc, char** argv) force_non_granular_tiling, force_fp32_acc, props); // The Q, K and V matrices are packed into one big matrix of size S x B x H x 3 x D. - size_t const qkv_size = s * b * h * (2 * d + dv); + const size_t qkv_size = s * b * h * (2 * d + dv); // Allocate on the host. float* qkv_h = (float*) malloc(qkv_size * sizeof(float)); // The size in bytes. - size_t const qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); + const size_t qkv_size_in_bytes = get_size_in_bytes(qkv_size, data_type); // Allocate on the device. void *qkv_sbh3d_d = nullptr, *qkv_bsh3d_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&qkv_sbh3d_d, qkv_size_in_bytes)); @@ -1057,7 +1071,7 @@ int main(int argc, char** argv) // The shape is [B, 2, S, H, D]. const size_t kv_size = b * s * h_kv * (d + dv); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the host. void* contiguous_kv_h = malloc(kv_size_in_bytes); // Memset the buffer. 
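Since the test harness now threads attention_sinks_d through set_params and the run_softmax_* references, here is a standalone reference sketch (not taken from the kernels above) of what the sink does: the per-head sink value only enlarges the softmax denominator, matching the "-use-attention-sinks" comment above and the non-exp2f branch of Tile_o_epilogue_base::add_attention_sink; the exp2f and FP8 quant-scale variants are omitted.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <limits>
#include <vector>

// Reference-only: softmax over one row of BMM1 outputs with a per-head
// attention sink added to the normalization term. The sink contributes
// exp(sink - rowMax) to the denominator but emits no output weight.
std::vector<float> softmaxRowWithSink(std::vector<float> const& row, float sink)
{
    float rowMax = -std::numeric_limits<float>::infinity();
    for (float x : row)
    {
        rowMax = std::max(rowMax, x);
    }

    float sum = std::exp(sink - rowMax);
    for (float x : row)
    {
        sum += std::exp(x - rowMax);
    }

    std::vector<float> probs(row.size());
    for (std::size_t i = 0; i < row.size(); ++i)
    {
        probs[i] = std::exp(row[i] - rowMax) / sum;
    }
    return probs;
}

In the device code above the same augmented sum then feeds the 1/sum (or scale_bmm2/sum) rescaling of the O tile in Tile_o_epilogue::scale.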
@@ -1071,13 +1085,13 @@ int main(int argc, char** argv) void** kv_cache_ptrs_h = nullptr; void* kv_cache_pool_ptr = nullptr; int32_t *kv_cache_block_offsets_h, *kv_cache_block_offsets_d = nullptr; - size_t const max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; - size_t const num_total_blocks = b * 2 * max_blocks_per_seq; + const size_t max_blocks_per_seq = (s + tokens_per_block - 1) / tokens_per_block; + const size_t num_total_blocks = b * 2 * max_blocks_per_seq; kv_cache_ptrs_h = (void**) malloc(num_total_blocks * sizeof(void*)); kv_cache_block_offsets_h = (int32_t*) malloc(num_total_blocks * sizeof(int32_t)); - size_t const paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); + const size_t paged_kv_block_size_in_bytes = get_size_in_bytes(tokens_per_block * h_kv * std::gcd(d, dv), data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_block_offsets_d), num_total_blocks * sizeof(int32_t))); - size_t const kv_cache_pool_sz + const size_t kv_cache_pool_sz = get_size_in_bytes(num_total_blocks * tokens_per_block * h_kv * (d + dv) / 2, data_type); FMHA_CHECK_CUDA(cudaMalloc((void**) (&kv_cache_pool_ptr), kv_cache_pool_sz)); size_t ptr_index = 0; @@ -1104,7 +1118,7 @@ int main(int argc, char** argv) // Q will always be [B, S, H, Dh] with paged kv cache. void* q_d; - size_t const q_size = s * b * h * d; + const size_t q_size = s * b * h * d; FMHA_CHECK_CUDA(cudaMalloc(&q_d, get_size_in_bytes(q_size, data_type))); // K has [B, S, H_kv, D] with separate kv cache. @@ -1122,11 +1136,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&scale_bmm2_d, sizeof(uint32_t))); // The mask for dropout or any mask patterns. - size_t const mask_size = s * b * s; + const size_t mask_size = s * b * s; // Allocate on the host. float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; if (!skip_checks) @@ -1158,7 +1172,7 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. size_t mmas_m = (s + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. @@ -1182,7 +1196,7 @@ int main(int argc, char** argv) packed_mask_size = b * mmas_m * mmas_n * threads_per_cta; } // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Set it to 0 (indicates that all elements are valid). @@ -1190,12 +1204,30 @@ int main(int argc, char** argv) // Allocate on the device. void* packed_mask_d = nullptr; + // The size of the attention sinks. + const size_t attention_sinks_size_in_bytes = h * sizeof(float); + + // The attention sinks. + void* attention_sinks_d = nullptr; + if (use_attention_sinks) + { + // Allocate on the host. + float* attention_sinks_h = (float*) malloc(attention_sinks_size_in_bytes); + // Randomly initialize the attention sinks. 
+ random_init("attention_sinks", attention_sinks_h, 1, h, 1, false, 5.f, 1.f, verbose); + // Allocate on the device. + FMHA_CHECK_CUDA(cudaMalloc(&attention_sinks_d, attention_sinks_size_in_bytes)); + // Copy from the host to the device. + FMHA_CHECK_CUDA( + cudaMemcpy(attention_sinks_d, attention_sinks_h, attention_sinks_size_in_bytes, cudaMemcpyDefault)); + } + // The O matrix is packed as S * B * H * D. - size_t const o_size = s * b * h * dv; + const size_t o_size = s * b * h * dv; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); @@ -1206,7 +1238,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMemset(softmax_stats_d, 0x00, 2 * sizeof(float) * b * s * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -1220,9 +1252,9 @@ int main(int argc, char** argv) float* softmax_sum_h = (float*) malloc(b * s * h * sizeof(float)); // The P matrix is stored as one big matrix of size S x B x H x S. - size_t const p_size = s * b * h * s; + const size_t p_size = s * b * h * s; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; if (!skip_checks) @@ -1238,7 +1270,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. void* s_d = nullptr; if (!skip_checks) @@ -1327,7 +1359,7 @@ int main(int argc, char** argv) std::vector seqlens(b, 0); // randomly draw a batch of sequence lengths >= min_s std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -1415,7 +1447,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_packed_d, mqa_qkv_packed_size_in_bytes)); FMHA_CHECK_CUDA(cudaMalloc(&mqa_qkv_d, mqa_qkv_size_in_bytes)); - size_t const o_packed_size = cu_seqlens.back() * h * dv; + const size_t o_packed_size = cu_seqlens.back() * h * dv; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); void* o_packed_d = nullptr; @@ -1676,9 +1708,9 @@ int main(int argc, char** argv) total, num_grouped_heads, sliding_window_size, chunked_attention_size, // Paged kv cache. 
tokens_per_block, qkv_d_view, q_d, k_d, v_d, contiguous_kv_d, kv_cache_pool_ptr, kv_cache_block_offsets_d, - packed_mask_d, cu_mask_rows_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, softmax_stats_ptr, - scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, use_int8_scale_max, interleaved, - is_s_padded, has_alibi); + packed_mask_d, cu_mask_rows_d, attention_sinks_d, cu_seqlens_d, cu_q_seqlens_d, o_d_view, p_d, s_d, + softmax_stats_ptr, scale_bmm2_d, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, + use_int8_scale_max, interleaved, is_s_padded, has_alibi); // total number of tokens is needed to set TMA desc on the host. launch_params.total_q_seqlen = q_seqlens[b]; @@ -1894,8 +1926,8 @@ int main(int argc, char** argv) ground_truth(bmm1, bmm2, data_type, acc_type, scale_bmm1, scale_softmax, scale_bmm2, softcapping_scale_bmm1, qkv_sbh3d_d, vt_d, // WAR pass in V' - mask_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, warps_m, warps_n, - has_alibi); + mask_d, attention_sinks_d, p_d, s_d, tmp_d, o_d, softmax_stats_d, cu_seqlens_d, b, s, h, d, dv, runs, + warps_m, warps_n, has_alibi); timer.stop(); FMHA_CHECK_CUDA(cudaPeekAtLastError()); FMHA_CHECK_CUDA(cudaDeviceSynchronize()); @@ -2009,7 +2041,6 @@ int main(int argc, char** argv) // Extract the last s_q tokens from the output. extract_and_transpose_output( o_ref_trans_h.data(), o_ref_h, seqlens, q_seqlens, s, s_q, b, h, dv, is_s_padded); - if (verbose) { printf("\nChecking .....: O = V * S\n"); diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h index f77e3f14d0c..16e2f9a8db5 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention.h @@ -197,6 +197,9 @@ struct Fused_multihead_attention_params_v2 : Fused_multihead_attention_params_ba // The stride between rows of softmax_stats_ptr int64_t softmax_stats_stride_in_bytes; + // The attention sinks (per head). + float* attention_sinks; + // array of length b+1 holding prefix sum of actual q sequence lengths. int* cu_q_seqlens; // array of length b+1 holding prefix sum of actual kv sequence lengths. diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h index 76670971e57..bacb4938cf2 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h +++ b/cpp/kernels/fmha_v2/src/fused_multihead_attention_demo_bert_params.h @@ -87,6 +87,8 @@ struct Fused_multihead_attention_params_v2 fmha::Kv_block_array paged_kv_cache; // The mask to implement drop-out. void* packed_mask_ptr; + // The attention sinks (per head). + float* attention_sinks; // The O matrix (output). 
void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp index 8a2e7a8fc0c..6e37fc6ab43 100644 --- a/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp +++ b/cpp/kernels/fmha_v2/src/fused_multihead_cross_attention.cpp @@ -23,25 +23,27 @@ using Launch_params = bert::Fused_multihead_attention_launch_params; //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi); +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_seqlens_q_d, - int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, float softcapping_scale_bmm1, int warps_n, - bool has_alibi); +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_seqlens_q_d, int s_inner, int s_outer, int b, int h, float scale_i2f, float scale_f2i, + float softcapping_scale_bmm1, int warps_n, bool has_alibi); //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -57,10 +59,10 @@ void run_conversion_fp32_to_e4m3(void* dst, void const* src, int s, int b, int h //////////////////////////////////////////////////////////////////////////////////////////////////// -void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_type const acc_type, +void ground_truth(RefBMM& bmm1, RefBMM& bmm2, const Data_type data_type, const Data_type acc_type, float const scale_bmm1, float const scale_softmax, float const scale_bmm2, void* q_d, void* kv_d, void* vt_d, void* mask_d, void* p_d, void* s_d, void* tmp_d, void* o_d, void* softmax_sum_d, void* 
cu_seqlens_q_d, - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, int const runs, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, int const runs, int const warps_m, int const warps_n, bool has_alibi) { @@ -84,20 +86,22 @@ void ground_truth(RefBMM& bmm1, RefBMM& bmm2, Data_type const data_type, Data_ty // Softmax. if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP16) { - run_softmax_fp16(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp16( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_FP16 && acc_type == DATA_TYPE_FP32) { - run_softmax_fp32(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); + run_softmax_fp32( + s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_E4M3 && acc_type == DATA_TYPE_FP32) { - run_softmax_e4m3(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, 0.f, - warps_n, has_alibi); + run_softmax_e4m3(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_softmax, + 0.f, warps_n, has_alibi); } else if (data_type == DATA_TYPE_INT8 && acc_type == DATA_TYPE_INT32) { - run_softmax_int8(s_d, p_d, mask_d, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, + run_softmax_int8(s_d, p_d, mask_d, nullptr, softmax_sum_d, cu_seqlens_q_d, s_kv, s_q, b, h, scale_bmm1, scale_softmax, 0.f, warps_n, has_alibi); } else @@ -148,8 +152,8 @@ static inline void set_params(bert::Fused_multihead_attention_params_mhca& param // types Data_type data_type, Data_type acc_type, // sizes - size_t const b, size_t const s_q, size_t const s_kv, size_t const h, size_t const d, size_t const d_padded, - size_t const total, + const size_t b, const size_t s_q, const size_t s_kv, const size_t h, const size_t d, const size_t d_padded, + const size_t total, // device pointers void* q_packed_d, void* kv_packed_d, void* cu_seqlens_q_d, void* cu_seqlens_kv_d, void* o_packed_d, void* p_d, void* s_d, @@ -515,17 +519,17 @@ int main(int argc, char** argv) launch_params.use_tma = use_tma; // The Q matrix of size S_Q x B x H x D. - size_t const q_size = s_q * b * h * d; + const size_t q_size = s_q * b * h * d; // The K and V matrices are packed into one big matrix of size S_KV x B x H x 2 x D. - size_t const kv_size = s_kv_padded * b * h * 2 * d; + const size_t kv_size = s_kv_padded * b * h * 2 * d; // Allocate on the host. float* q_h = (float*) malloc(q_size * sizeof(float)); // Allocate on the host. float* kv_h = (float*) malloc(kv_size * sizeof(float)); // The size in bytes. - size_t const q_size_in_bytes = get_size_in_bytes(q_size, data_type); + const size_t q_size_in_bytes = get_size_in_bytes(q_size, data_type); // The size in bytes. - size_t const kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); + const size_t kv_size_in_bytes = get_size_in_bytes(kv_size, data_type); // Allocate on the device. void* q_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&q_d, q_size_in_bytes)); @@ -534,11 +538,11 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMalloc(&kv_d, kv_size_in_bytes)); // The mask for dropout. - size_t const mask_size = s_q * b * s_kv_padded; + const size_t mask_size = s_q * b * s_kv_padded; // Allocate on the host. 
float* mask_h = (float*) malloc(mask_size * sizeof(float)); // The size in bytes. - size_t const mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); + const size_t mask_size_in_bytes = get_size_in_bytes(mask_size, DATA_TYPE_INT8); // Allocate on the device. void* mask_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&mask_d, mask_size_in_bytes)); @@ -554,28 +558,28 @@ int main(int argc, char** argv) v1 ? 1 : 2); // The number of threads per CTA. - size_t const threads_per_cta = warps_m * warps_n * warps_k * 32; + const size_t threads_per_cta = warps_m * warps_n * warps_k * 32; // The number of mmas in the M dimension. We use one uint32_t per MMA in the M dimension. - size_t const mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); + const size_t mmas_m = (s_q + 16 * warps_m - 1) / (16 * warps_m); // The number of mmas in the N dimension. - size_t const mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); + const size_t mmas_n = (s_kv_padded + 16 * warps_n - 1) / (16 * warps_n); // We do not support more than 4 MMAS in the N dimension (as each MMA needs 8 bits in the mask). assert(!v1 || mmas_n <= 4); // The packed mask for dropout (in the fused kernel). Layout is B * MMAS_M * THREADS_PER_CTA. - size_t const packed_mask_size = b * mmas_m * threads_per_cta; + const size_t packed_mask_size = b * mmas_m * threads_per_cta; // The size in bytes. - size_t const packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); + const size_t packed_mask_size_in_bytes = packed_mask_size * sizeof(uint32_t); // Allocate on the host. uint32_t* packed_mask_h = (uint32_t*) malloc(packed_mask_size_in_bytes); // Allocate on the device. void* packed_mask_d = nullptr; // The O matrix is packed as S_Q * B * H * D. - size_t const o_size = s_q * b * h * d; + const size_t o_size = s_q * b * h * d; // Allocate on the host. float* o_h = (float*) malloc(o_size * sizeof(float)); // The size in bytes. - size_t const o_size_in_bytes = get_size_in_bytes(o_size, data_type); + const size_t o_size_in_bytes = get_size_in_bytes(o_size, data_type); // Allocate on the device. void* o_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&o_d, o_size_in_bytes)); @@ -587,7 +591,7 @@ int main(int argc, char** argv) FMHA_CHECK_CUDA(cudaMemset(softmax_max_d, 0x00, sizeof(float) * b * s_q * h)); // The size in bytes. - size_t const tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); + const size_t tmp_size_in_bytes = get_size_in_bytes(o_size, acc_type); // Allocate on the device. void* tmp_d = nullptr; if (data_type != acc_type) @@ -599,9 +603,9 @@ int main(int argc, char** argv) float* o_ref_h = (float*) malloc(o_size * sizeof(float)); // The P matrix is stored as one big matrix of size S_Q x B x H x S_KV. - size_t const p_size = s_q * b * h * s_kv_padded; + const size_t p_size = s_q * b * h * s_kv_padded; // The size in bytes. - size_t const p_size_in_bytes = get_size_in_bytes(p_size, acc_type); + const size_t p_size_in_bytes = get_size_in_bytes(p_size, acc_type); // Allocate on the device. void* p_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&p_d, p_size_in_bytes)); @@ -614,7 +618,7 @@ int main(int argc, char** argv) #endif // defined(STORE_P) // The size in bytes of the S matrix (the data type may be different from P for int8). - size_t const s_size_in_bytes = get_size_in_bytes(p_size, data_type); + const size_t s_size_in_bytes = get_size_in_bytes(p_size, data_type); // Allocate on the device. 
void* s_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&s_d, s_size_in_bytes)); @@ -634,9 +638,9 @@ int main(int argc, char** argv) // WAR fOR MISSING CUBLAS FP8 NN SUPPORT. // Transpose V, so that we can do a TN BMM2, i.e. O = S x V' instead of O = S x V. - size_t const v_size = s_kv_padded * b * h * d; + const size_t v_size = s_kv_padded * b * h * d; // The size in bytes. - size_t const v_size_in_bytes = get_size_in_bytes(v_size, data_type); + const size_t v_size_in_bytes = get_size_in_bytes(v_size, data_type); float* vt_h = (float*) malloc(v_size * sizeof(float)); void* vt_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&vt_d, v_size_in_bytes)); @@ -676,7 +680,7 @@ int main(int argc, char** argv) = [min_s, fix_s, b](int s, std::vector& seqlens, std::vector& cu_seqlens, void** cu_seqlens_d) { std::transform(seqlens.begin(), seqlens.end(), seqlens.begin(), - [=](uint32_t const) + [=](const uint32_t) { if (fix_s) { @@ -728,7 +732,7 @@ int main(int argc, char** argv) void* kv_packed_d = nullptr; FMHA_CHECK_CUDA(cudaMalloc(&kv_packed_d, kv_packed_size_in_bytes)); - size_t const o_packed_size = cu_seqlens_q.back() * h * d; + const size_t o_packed_size = cu_seqlens_q.back() * h * d; // Allocate on the host. float* o_packed_h = (float*) malloc(o_packed_size * sizeof(float)); float* o_ref_packed_h = (float*) malloc(o_packed_size * sizeof(float)); diff --git a/cpp/kernels/fmha_v2/src/softmax_bf16.cu b/cpp/kernels/fmha_v2/src/softmax_bf16.cu index 5212d317174..79b681b5023 100644 --- a/cpp/kernels/fmha_v2/src/softmax_bf16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_bf16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_bf16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_bf16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp16.cu b/cpp/kernels/fmha_v2/src/softmax_fp16.cu index 1fb68b1136d..9df37605a2e 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp16.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp16.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp16(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp16(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, + h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp32.cu b/cpp/kernels/fmha_v2/src/softmax_fp32.cu index 2b3bb6acbb7..12bcd8624d9 100644 --- 
a/cpp/kernels/fmha_v2/src/softmax_fp32.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp32.cu @@ -12,9 +12,10 @@ #include "softmax_impl.h" -void run_softmax_fp32(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, bool has_alibi) +void run_softmax_fp32(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float softcapping_scale_bmm1, int warps_n, + bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, 0.f, - softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, 0.f, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_fp8.cu b/cpp/kernels/fmha_v2/src/softmax_fp8.cu index 0a8e5f50299..26c2f5e88d7 100644 --- a/cpp/kernels/fmha_v2/src/softmax_fp8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_fp8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_e4m3(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax_e4m3(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_softmax, float softcapping_scale_bmm1, + int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, 0.f, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, + b, h, 0.f, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/fmha_v2/src/softmax_impl.h b/cpp/kernels/fmha_v2/src/softmax_impl.h index 2bc9f3380be..ca652627442 100644 --- a/cpp/kernels/fmha_v2/src/softmax_impl.h +++ b/cpp/kernels/fmha_v2/src/softmax_impl.h @@ -10,6 +10,7 @@ * its affiliates is strictly prohibited. */ +#include #include #include #include @@ -33,6 +34,8 @@ struct Softmax_params Src_type const* src; // Masks. int8_t const* mask; + // Attention sinks (per head). + float const* attention_sinks; // Softmax sum pointer. float* softmax_sum; // ALiBi @@ -148,7 +151,8 @@ static inline __device__ float apply_exp_(float x, float max) //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&mask)[N][1], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][1], const int8_t (&mask)[N][1], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. @@ -233,7 +237,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma } // Normalize. 
- float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -244,7 +248,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][1], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&mask)[N][2], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][2], const int8_t (&mask)[N][2], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. #pragma unroll @@ -401,7 +406,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -413,7 +418,8 @@ static inline __device__ void reduce(float (&data_fp32)[N][2], int8_t const (&ma //////////////////////////////////////////////////////////////////////////////////////////////////// template -static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&mask)[N][4], int warps_n, float& sum_fp32) +static inline __device__ void reduce( + float (&data_fp32)[N][4], const int8_t (&mask)[N][4], int warps_n, float& sum_fp32, float const attention_sink) { // Apply the masks. @@ -824,7 +830,7 @@ static inline __device__ void reduce(float (&data_fp32)[N][4], int8_t const (&ma } // Normalize. - float inv_sum_fp32 = 1.f / sum_fp32; + float inv_sum_fp32 = 1.f / (sum_fp32 + expf(attention_sink - max_fp32)); #pragma unroll for (int ii = 0; ii < N; ++ii) { @@ -994,9 +1000,16 @@ static __global__ void softmax_kernel(Softmax_params params) } } + // The attention sink value. + float attention_sink = -FLT_MAX; + if (params.attention_sinks != nullptr) + { + attention_sink = params.attention_sinks[hi]; + } + // Do the reduction. float sum_fp32 = 0.f; - reduce(data_fp32, mask_, params.warps_n, sum_fp32); + reduce(data_fp32, mask_, params.warps_n, sum_fp32, attention_sink); if (threadIdx.x == 0) { int sum_s = params.cu_q_seqlens[bi]; @@ -1025,9 +1038,9 @@ static __global__ void softmax_kernel(Softmax_params params) //////////////////////////////////////////////////////////////////////////////////////////////////// template -void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum, void* cu_q_seqlens, int s_inner, - int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, int warps_n, - bool has_alibi) +void run_softmax(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum, + void* cu_q_seqlens, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { Softmax_params params; @@ -1039,6 +1052,7 @@ void run_softmax(void* dst, void const* src, void const* mask, void* softmax_sum params.softmax_sum = reinterpret_cast(softmax_sum); params.cu_q_seqlens = reinterpret_cast(cu_q_seqlens); params.mask = reinterpret_cast(mask); + params.attention_sinks = reinterpret_cast(attention_sinks); params.has_alibi = has_alibi; // The dimensions and precomputed values. 
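For reference, the net effect of the reduce()/softmax_kernel changes above can be summarized in a short host-side sketch (illustrative only; the function and variable names below are ours, not from the patch): the sink adds exp(sink - rowMax) to the denominator and contributes nothing to the numerator, and a sink of -FLT_MAX reproduces the original normalization. The XQA kernels further down (mha.cu, mha_sm90.cu) apply the same correction per K-head group, indexing the sink array with attentionSinks + headGrpSize * idxHeadGrp.

#include <algorithm>
#include <cmath>
#include <vector>

// Host-side reference for the sink-augmented softmax normalization.
// `row` holds the masked attention scores of one query row for one head;
// `attentionSink` is that head's sink logit (pass -FLT_MAX to disable it).
void softmaxRowWithSink(std::vector<float>& row, float attentionSink)
{
    float const rowMax = *std::max_element(row.begin(), row.end());
    float sum = 0.f;
    for (float& x : row)
    {
        x = std::exp(x - rowMax);
        sum += x;
    }
    // The only difference from plain softmax: the sink enlarges the denominator.
    sum += std::exp(attentionSink - rowMax);
    for (float& x : row)
    {
        x /= sum;
    }
}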
diff --git a/cpp/kernels/fmha_v2/src/softmax_int8.cu b/cpp/kernels/fmha_v2/src/softmax_int8.cu index 772fe1520ce..28701de9789 100644 --- a/cpp/kernels/fmha_v2/src/softmax_int8.cu +++ b/cpp/kernels/fmha_v2/src/softmax_int8.cu @@ -12,10 +12,10 @@ #include "softmax_impl.h" -void run_softmax_int8(void* dst, void const* src, void const* mask, void* softmax_sum_d, void* cu_q_seqlens_d, - int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, float softcapping_scale_bmm1, - int warps_n, bool has_alibi) +void run_softmax_int8(void* dst, void const* src, void const* mask, void const* attention_sinks, void* softmax_sum_d, + void* cu_q_seqlens_d, int s_inner, int s_outer, int b, int h, float scale_bmm1, float scale_softmax, + float softcapping_scale_bmm1, int warps_n, bool has_alibi) { - run_softmax(dst, src, mask, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, scale_bmm1, - scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); + run_softmax(dst, src, mask, attention_sinks, softmax_sum_d, cu_q_seqlens_d, s_inner, s_outer, b, h, + scale_bmm1, scale_softmax, softcapping_scale_bmm1, warps_n, has_alibi); } diff --git a/cpp/kernels/xqa/mha.cu b/cpp/kernels/xqa/mha.cu index c9690cbc6b0..69d93e901c3 100644 --- a/cpp/kernels/xqa/mha.cu +++ b/cpp/kernels/xqa/mha.cu @@ -1379,6 +1379,19 @@ __device__ inline ThrdRegRowMax mergeRowMax( return mergedRowMax; } +__device__ inline void addAttentionSinks( + ThrdRegRowMax& globalRowSum, ThrdRegRowMax const globalRowMax, float const* attentionSinks) +{ + for (uint32_t i = 0; i < globalRowSum.size; i++) + { + uint32_t srcOffset = warp_size * i + laneId(); + if (srcOffset < headGrpSize) + { + globalRowSum[i] += expf(attentionSinks[srcOffset] - globalRowMax[i]); + } + } +} + #ifdef NDEBUG __device__ __forceinline__ #else @@ -1405,6 +1418,7 @@ CUBIN_EXPORT __global__ #if SPEC_DEC MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32)]. #endif + float const* attentionSinks, // [headGrpSize] #ifdef NDEBUG KVCacheList const& cacheList, #if BEAM_WIDTH > 1 @@ -2371,6 +2385,12 @@ CUBIN_EXPORT __global__ float voScale = (isKVCacheQuantized ? kvCacheScale[0] : 1.F); if (seqIterInit < nbSeqIters) { // otherwise rcpRowSum will be NAN. + // The attention sinks are moved to the multi-block reduction part if the multi-block is enabled. + if (!isMultiBlock && attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(globalRowSum, globalRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } ThrdRegRowMax const rcpRowSum = __frcp_rn(globalRowSum); #if LOW_PREC_OUTPUT voScale *= rcpOutScale[0]; @@ -2559,6 +2579,11 @@ CUBIN_EXPORT __global__ assert(std::isfinite(mergedRowSum[0])); } } + if (attentionSinks != nullptr) + { + // Attention sinks are per head. + addAttentionSinks(mergedRowSum, mergedRowMax, attentionSinks + headGrpSize * idxHeadGrp); + } __syncthreads(); rescaleAcc(warp, sumAcc, fullRescaleMask, __frcp_rn(mergedRowSum)); GemmOutRegTile const mergedOutTile = toFp16(sumAcc); @@ -2615,6 +2640,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( MaskType const* __restrict__ mask, // [qSeqLen, divUp(qSeqLen, 32))] uint2 (each bit represents mask for one col // position). 
#endif + float const* attentionSinks, // [headGrpSize] KVCacheList const cacheList, #if BEAM_WIDTH > 1 BeamSearchParams const beamSearchParams, @@ -2640,7 +2666,7 @@ CUBIN_EXPORT __global__ __launch_bounds__(256, nbCtaPerSM) void kernel_mha( #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2667,6 +2693,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -2760,7 +2787,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif @@ -2788,7 +2815,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #if SPEC_DEC mask, #endif - cacheList, + attentionSinks, cacheList, #if BEAM_WIDTH > 1 beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mha.h b/cpp/kernels/xqa/mha.h index 39c94f985ec..d35ad48104a 100644 --- a/cpp/kernels/xqa/mha.h +++ b/cpp/kernels/xqa/mha.h @@ -101,6 +101,7 @@ void launchMHA(cudaDeviceProp const& prop, uint32_t const nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -140,6 +141,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, diff --git a/cpp/kernels/xqa/mha_sm90.cu b/cpp/kernels/xqa/mha_sm90.cu index 88d4c75e30b..9a438df9a2a 100644 --- a/cpp/kernels/xqa/mha_sm90.cu +++ b/cpp/kernels/xqa/mha_sm90.cu @@ -428,6 +428,7 @@ __device__ RegColWiseVec computeWarpColSum(Gemm0Acc& src); __device__ void storeGemm0AccToShm( uint32_t warpRank, uint32_t lane, SharedMem::XBuffer& smemX, CtaBarrier& barConsumed, Gemm0Acc const& acc); __device__ RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smemVec); +__device__ RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound); #else __device__ RegRowWiseVec computeWarpGrpRowMax_sync(uint32_t warpRank, ShmQWiseVec& smemColMax, Gemm0Acc const& src); __device__ void warpGrpApplyMask(Gemm0Acc& acc, uint32_t validColBeg, uint32_t validColEnd); @@ -453,7 +454,8 @@ __device__ void rescaleGemm1AccForNewColMax_sync(uint32_t warpRank, ShmQWiseVec template __device__ void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads = 0 /* only for final result in spec dec. */); + ShmQWiseVec const& accColSum, ShmQWiseVec const& accColMax, ShmQWiseVec const* attentionSinksVec, + uint32_t nbKHeads = 0 /* only for final result in spec dec. 
*/); #else __device__ void transposeVTile( uint32_t warpRank, uint32_t lane, SharedMem::VTBuffer& dst, SharedMem::VBuffer const& src); @@ -651,6 +653,7 @@ CUBIN_EXPORT __global__ #else IOHead const* __restrict__ const q, // [nbReq][beamWidth][nbQHeads], #endif + float const* attentionSinks, // [headGrpSize] KVCacheList const cacheList, #if USE_BEAM_SEARCH BeamSearchParams const beamSearchParams, @@ -1252,7 +1255,7 @@ CUBIN_EXPORT __global__ IOHead* const dst = (scratchMem.tokens() + idxChunk).template cast(); #if SWAP_AB finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, - smem.gemm1WarpGrpBar, smem.gemm1AccColSum); + smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, nullptr); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, 1, ctaNbValidTokens); @@ -1262,9 +1265,16 @@ CUBIN_EXPORT __global__ { uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); OutputHead* const dst = &output[outOffset]; + ShmQWiseVec const* attentionSinksVec = nullptr; + if (attentionSinks != nullptr) + { + attentionSinksVec + = reinterpret_cast(attentionSinks + headGrpSize * idxHeadGrp); + } #if SWAP_AB finalizeAndWriteOut_sync(threadIdx.x, warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, - xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, nbKHeads); + xvoScale, smem.gemm1WarpGrpBar, smem.gemm1AccColSum, smem.gemm1AccColMax, attentionSinksVec, + nbKHeads); #else finalizeAndWriteOut_sync(warpRank, dst, smem.outSwizzleBuf(idxXBuf), acc, xvoScale, smem.gemm1AccColSum, nbKHeads, ctaNbValidTokens); @@ -1585,6 +1595,17 @@ CUBIN_EXPORT __global__ } unused(bar.consumed.arrive()); } + // Add the attention sinks. + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headsPerWarp; i++) + { + uint32_t const idxHead = wid + nbMathWarps * i; + float sink = expf( + attentionSinks[mha::min(idxHead, headGrpSize - 1) + idxHeadGrp * headGrpSize] - states[i].max); + states[i].sum += sink; + } + } __syncthreads(); uint32_t const outOffset = headGrpSize * (nbKHeads * (beamWidth * ctaInputTokBeg) + idxHeadGrp); auto const dst = &output[outOffset]; @@ -2029,6 +2050,22 @@ __device__ inline RegColWiseVec loadShmColWiseVecWithDup(ShmQWiseVec const& smem return ret; } +__device__ inline RegColWiseVec loadGmemColWiseVecWithDup(ShmQWiseVec const& gmemVec, uint32_t bound) +{ + RegColWiseVec ret; + constexpr uint32_t nbThrdsPerInstNBase = exactDiv(gmma::instNBase, GmmaAccCoreMat::cols); + auto const idx = laneId() % nbThrdsPerInstNBase; +#pragma unroll + for (uint32_t i = 0; i < exactDiv(ShmQWiseVec::size, gmma::instNBase); i++) + { + static_assert(nbThrdsPerInstNBase * RegColWiseVec::size == exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)); + ret[i] = reinterpret_cast< + Vec, exactDiv(ShmQWiseVec::size, GmmaAccCoreMat::cols)> const&>( + gmemVec)[mha::min(i * nbThrdsPerInstNBase + idx, bound)]; + } + return ret; +} + __device__ inline void warpGrpApplyMask(uint32_t warpRank, Gemm0Acc& acc, uint32_t validRowBeg, uint32_t validRowEnd) { uint32_t const idxInQuad = laneId() % 4; @@ -2878,12 +2915,19 @@ __device__ inline void saveTransposedOutput(uint32_t threadRank, uint32_t warpRa template __device__ inline void finalizeAndWriteOut_sync(uint32_t threadRank, uint32_t warpRank, DstHead* dst, SharedMem::OutSwizzleBuf& swizzleBuf, Gemm1Acc& acc, float xvoScale, CtaBarrier& warpGrpBar, - ShmQWiseVec const& accColSum, uint32_t nbKHeads) + ShmQWiseVec const& accColSum, ShmQWiseVec 
const& accColMax, ShmQWiseVec const* attentionSinksVec, uint32_t nbKHeads) { // @fixme: if ctaNbQHeads is large, use loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of mufu.rcp // static_assert(ctaNbQHeads <= 8, "Warning: consider using loadShmColWiseVecNoDup + rcp + shfl to avoid 8x waste of // mufu.rcp"); - auto const regColSum = loadShmColWiseVecWithDup(accColSum); + auto regColSum = loadShmColWiseVecWithDup(accColSum); + if (attentionSinksVec != nullptr) + { + auto const regAccColMax = loadShmColWiseVecWithDup(accColMax); + auto const regAttentionSinks = loadGmemColWiseVecWithDup(attentionSinksVec[0], headGrpSize - 1); + auto regColSinks = expf(regAttentionSinks - regAccColMax); + regColSum = regColSum + regColSinks; + } auto const regOutScale = __frcp_rn(regColSum) * xvoScale; rescaleAcc(acc, regOutScale); @@ -3175,6 +3219,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else InputHead const* q, #endif + float const* attentionSinks, // [headGrpSize] #if USE_PAGED_KV_CACHE #if PAGED_KV_CACHE_LAYOUT == 1 GMemCacheHead* kCacheVLLM, GMemCacheHead* vCacheVLLM, @@ -3286,7 +3331,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif @@ -3322,7 +3367,7 @@ void launchHopperF8MHA(cudaDeviceProp const& prop, uint32_t nbKHeads, #else q, #endif - cacheList, + attentionSinks, cacheList, #if USE_BEAM_SEARCH beamSearchParams, #endif diff --git a/cpp/kernels/xqa/mla_sm120.cu b/cpp/kernels/xqa/mla_sm120.cu index 74877512a7d..072908fe3e8 100644 --- a/cpp/kernels/xqa/mla_sm120.cu +++ b/cpp/kernels/xqa/mla_sm120.cu @@ -1859,12 +1859,13 @@ CUtensorMap makeTensorMapForQ( #endif // IS_MLA void launchMLA(cudaDeviceProp const& prop, - uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed + uint32_t inputSeqLen, // uniform for all requests and causal mask is assumed float qScale, OutputHead* output, InputHead const* q, + float* attentionSinks, // [headGrpSize], not supported. #if USE_PAGED_KV_CACHE - GMemCacheHead* pool, // global pool of pages + GMemCacheHead* pool, // global pool of pages KVCachePageIndex const* - kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] + kvCachePageList, // device pointer. shape: KVCachePage[batchSize][beamWidth][2][maxNbPagesPerSeq] #else GMemKVCacheHead* kvCacheData, #endif diff --git a/cpp/kernels/xqa/test/refAttention.cpp b/cpp/kernels/xqa/test/refAttention.cpp index d8f1a688f5d..dd356c101c0 100644 --- a/cpp/kernels/xqa/test/refAttention.cpp +++ b/cpp/kernels/xqa/test/refAttention.cpp @@ -45,7 +45,7 @@ using Vector = Matrix; template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { uint32_t const nbTiles = divUp(seqLen, tileSize); auto gemm1Acc = Eigen::Matrix::Zero().eval(); @@ -113,6 +113,16 @@ Eigen::Matrix refFlashAt } rowSum += tileRowSum; } + + // Add the attention sinks. 
+ if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } + Eigen::Matrix out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -123,7 +133,7 @@ Eigen::Matrix refFlashAt template Eigen::Matrix \ refFlashAttention(IOHead const* q, \ CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, \ - float qScale, float kvScale, float xScale, uint32_t slidingWinSize) + float qScale, float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) INSTANTIATE_refFlashAttention(CacheElem, 64, false, false); INSTANTIATE_refFlashAttention(CacheElem, 64, false, true); @@ -143,7 +153,7 @@ Eigen::Matrix refAttenti #else Eigen::Matrix refAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize) + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks) { #endif float const rcpXScale = 1.f / xScale; @@ -184,7 +194,7 @@ Eigen::Matrix refAttenti Eigen::Matrix x = (gemm0Acc.colwise() - rowMax).array().exp().eval(); - Eigen::Vector const rowSum = x.rowwise().sum().eval(); + Eigen::Vector rowSum = x.rowwise().sum().eval(); std::for_each(x.data(), x.data() + x.size(), [&](float& e) { e = float(MathElem(e * rcpXScale)); }); @@ -200,6 +210,18 @@ Eigen::Matrix refAttenti } } } + + // Add the attention sinks. +#if !SPEC_DEC + if (attentionSinks != nullptr) + { + for (uint32_t i = 0; i < headGrpSize; i++) + { + rowSum[i] += expf(attentionSinks[i] - rowMax[i]); + } + } +#endif + Eigen::Matrix out = gemm1Acc.array().colwise() * (xScale * kvScale / rowSum.array()); std::for_each(out.data(), out.data() + out.size(), [](float& e) { e = float(OutputElem(e)); }); @@ -217,7 +239,7 @@ Eigen::Matrix refAttenti template Eigen::Matrix \ refAttention(IOHead const* q, CacheSeq const& k, \ CacheSeq const& v, uint32_t seqLen, float qScale, float kvScale, float xScale, \ - uint32_t slidingWinSize) + uint32_t slidingWinSize, float* attentionSinks) #endif INSTANTIATE_refAttention(InputElem, false, false); INSTANTIATE_refAttention(InputElem, false, true); diff --git a/cpp/kernels/xqa/test/refAttention.h b/cpp/kernels/xqa/test/refAttention.h index bfab1418294..a073ed0e801 100644 --- a/cpp/kernels/xqa/test/refAttention.h +++ b/cpp/kernels/xqa/test/refAttention.h @@ -83,7 +83,7 @@ struct CacheSeq template Eigen::Matrix refFlashAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); template #if SPEC_DEC @@ -93,7 +93,7 @@ Eigen::Matrix refAttenti #else Eigen::Matrix refAttention(IOHead const* q, CacheSeq const& k, CacheSeq const& v, uint32_t seqLen, float qScale, - float kvScale, float xScale, uint32_t slidingWinSize); + float kvScale, float xScale, uint32_t slidingWinSize, float* attentionSinks); #endif template diff --git a/cpp/kernels/xqa/test/test.cpp b/cpp/kernels/xqa/test/test.cpp index b9228578623..91b35f3e1a4 100644 --- a/cpp/kernels/xqa/test/test.cpp +++ b/cpp/kernels/xqa/test/test.cpp @@ -130,7 +130,7 @@ template #endif #endif void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, bool verbose = false, - bool saveData = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) + bool saveData = false, 
bool hasAttentionSinks = false, uint32_t ctxLen = ~0U, uint32_t slidingWinSize = 1U << 30) { #if IS_MLA if (nbKHeads != 1) @@ -613,6 +613,17 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, } } + // Allocate the attention sinks (per head) + auto attentionSinks = ManagedMemBuf(nbQHeads); + // The attention sinks ptr. + float* attentionSinksPtr = hasAttentionSinks ? reinterpret_cast(attentionSinks.get()) : nullptr; + // Initialize the attention sinks (use large values to detect the potential bugs). + for (uint32_t i = 0; i < nbQHeads; i++) + { + // Range: [2, 5] + attentionSinks.get()[i] = 2.f + float(i % 4); + } + if (verbose) { printf("migrating data to gpu\n"); @@ -640,6 +651,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, #if BEAM_WIDTH > 1 cacheIndir.prefetch(dev, stream); #endif + attentionSinks.prefetch(dev, stream); }; prefetchToDevice(device); checkCuda(cudaMemsetAsync(semaphores.get(), 0, 4 * nbSemaphores, stream)); @@ -720,6 +732,7 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, &qHeads[0][0][0], #endif #endif + attentionSinksPtr, #if PAGED_KV_CACHE_LAYOUT == 1 && USE_PAGED_KV_CACHE cacheKHeads.get(), cacheVHeads.get(), #else @@ -1028,10 +1041,13 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, hostMask, qSeqLen, q_len); #else Eigen::Matrix refOutput; + auto const refAttentionSinks + = hasAttentionSinks ? attentionSinksPtr + headGrpSize * idxKHead : nullptr; if (useQGMMA) { refOutput = refFlashAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, + refAttentionSinks); // refOutput = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, // vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); } @@ -1039,8 +1055,9 @@ void runTest(uint32_t batchSize, uint32_t seqLen, bool testPerf, bool refCheck, { // refOutput = refFlashAttention(&qHeads[req][b][headGrpSize * idxKHead], // kCacheSeq, vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale); - refOutput = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, - vCacheSeq, seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize); + refOutput + = refAttention(&qHeads[req][b][headGrpSize * idxKHead], kCacheSeq, vCacheSeq, + seqLen, qScaleForRef, kvCacheScale[0], xScale, slidingWinSize, refAttentionSinks); } #endif if (lowPrecOutput) @@ -1196,11 +1213,23 @@ TEST(RefCheck, llama_V2_70b) runTest<2>(2, 514, false, true); runTest<1>(1, 4096, false, true); #if SLIDING_WINDOW - runTest<2>(2, 4096, false, true, false, false, ~0, 256); - runTest<2>(2, 400, false, true, false, false, ~0U, 256); + runTest<2>(2, 4096, false, true, false, false, false, ~0, 256); + runTest<2>(2, 400, false, true, false, false, false, ~0U, 256); #endif runTest<8>(120, 367, false, true); - // runTest<8>(1792, 2048, false, true); + runTest<8>(1792, 2048, false, true); +} + +TEST(RefCheck, attention_sinks) +{ + auto runAttentionSinksTest = [](uint32_t batchSize, uint32_t seqLen) + { runTest<8>(batchSize, seqLen, false, true, false, false, /*hasAttentionSinks*/ true); }; + + runAttentionSinksTest(2, 2); + runAttentionSinksTest(2, 15); + runAttentionSinksTest(2, 256); + runAttentionSinksTest(2, 514); + runAttentionSinksTest(1, 4096); } TEST(Perf, tracing_long) @@ -1264,7 +1293,7 @@ TEST(Perf, mlperf_gptj) #ifndef NDEBUG 
GTEST_SKIP() << "Skipping perf tests for debug build"; #endif - runTest<32>(396, 800 + 224, true, false, false, false, 800); + runTest<32>(396, 800 + 224, true, false, false, false, false, 800); } TEST(Perf, mlperf_llama) diff --git a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h index 565c170e1df..2559ae54840 100644 --- a/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h +++ b/cpp/micro_benchmarks/mixtureOfExpertsBackendBenchmarkFixture.h @@ -53,6 +53,7 @@ using namespace CUTLASS_MOE_GEMM_KERNELS_NAMESPACE; using CUTLASS_MOE_GEMM_NAMESPACE::TmaWarpSpecializedGroupedGemmInput; using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::CutlassMoeFCRunner; using CUTLASS_MOE_GEMM_NAMESPACE::ActivationType; +using CUTLASS_MOE_GEMM_KERNELS_NAMESPACE::ActivationParams; using CUTLASS_MOE_GEMM_NAMESPACE::isGatedActivation; static BufferManager::CudaStreamPtr streamPtr; @@ -980,11 +981,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture auto stream = streamPtr->get(); MoeMinLatencyParams min_latency_params; #ifdef USING_OSS_CUTLASS_MOE_GEMM - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, @@ -992,11 +993,11 @@ class MixtureOfExpertsBenchmark : public ::benchmark::Fixture /*enable_alltoall=*/false, mUseLora, mLoraParams[mBufferIndex], /*use_fp8_block_scaling=*/false, /*min_latency_mode=*/false, min_latency_params, stream); #else - mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, + mMoERunner.runMoe(mInputTensor + mInputTensorSize * mBufferIndex, nullptr, true, mSelectedExperts + mSelectedExpertsSize * mBufferIndex, mUseFinalScale ? 
mScaleProbs + mScaleProbsSize * mBufferIndex : nullptr, mExpertWeight1 + mExpertWeight1Size * mBufferIndex, mExpertBias1 + mExpertBias1Size * mBufferIndex, - mActType, mExpertWeight2 + mExpertWeight2Size * mBufferIndex, + ActivationParams(mActType), mExpertWeight2 + mExpertWeight2Size * mBufferIndex, mExpertBias2 + mExpertBias2Size * mBufferIndex, mQuantParams[mBufferIndex], mTotalTokens, mHiddenSize, mInterSize, mNumExperts, mK, mWorkspace + mWorkspaceSize * mBufferIndex, mFinalOutput + mFinalOutputSize * mBufferIndex, diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp index d95ca1b412b..fbe03d1dcc8 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.cpp @@ -75,6 +75,7 @@ BlockRange getBlockRangeForReceiving(BaseKVCacheManager* cacheManager, LlmReques bool CacheFormatter::needSendCache( CacheState const& selfConfig, CacheState const& destConfig, runtime::SizeType32 selfIdx) { + // int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); if (targetInfo.mDupHeadFactor <= 1) { @@ -89,9 +90,8 @@ bool CacheFormatter::needSendCache( = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; selfTpRankInDpGroup = selfTpRank % selfTPNumInDPGroup; } - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - return (destDPRank % targetInfo.mDupHeadFactor) == (selfTpRankInDpGroup % targetInfo.mDupHeadFactor); + return selfTpRankInDpGroup % targetInfo.mDupHeadFactor == 0; } void checkAlternateWindow(BaseKVCacheManager* cacheManager, BaseCacheFormatter::CacheState const& selfConfig, @@ -128,12 +128,11 @@ std::vector CacheFormatter::pickRecvConnections( return ret; } TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); - int selfDPRank = selfConfig.getParallelConfig().mEnableAttentionDP ? selfConfig.getParallelConfig().mDPrank : 0; std::vector ret; for (int i = 0; i < targetInfo.mDomainTPSize; i++) { - if ((i % targetInfo.mPeerDupHeadFactor) == (selfDPRank % targetInfo.mPeerDupHeadFactor)) + if (i % targetInfo.mPeerDupHeadFactor == 0) { for (int j = 0; j < targetInfo.mDomainPPSize; j++) { @@ -361,7 +360,7 @@ void CacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -713,7 +712,7 @@ void CacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) { diff --git a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h index ee199c2fb1c..8ae8ee5f2ca 100644 --- a/cpp/tensorrt_llm/batch_manager/cacheFormatter.h +++ b/cpp/tensorrt_llm/batch_manager/cacheFormatter.h @@ -76,15 +76,6 @@ class BaseCacheFormatter /// @brief Destructor. 
virtual ~BaseCacheFormatter() = default; - - // TODO: better way for context/generation tagging - void markAsSender(bool isSender) - { - kvCacheMeasureHelper.markAsSender(isSender); - } - -protected: - KvCacheMeasureHelper kvCacheMeasureHelper{common::getEnvKVCacheTransferOutputPath()}; }; // Simple cache block copy. Because it does not involve data splitting or merging, it performs best when the diff --git a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp index 93df2f96ec0..84fd13d5c18 100644 --- a/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp +++ b/cpp/tensorrt_llm/batch_manager/createNewDecoderRequests.cpp @@ -176,7 +176,7 @@ void CreateNewDecoderRequests::newRequest(SizeType32 batchSlot, runtime::decoder BufferManager manager{std::make_shared(decoderStream.get())}; - auto const batchSize = decoderState.getMaxBatchSize(); + auto const batchSize = decoderState.getMaxNumSequences(); TLLM_CHECK(0 <= batchSize && batchSlot < batchSize); auto const maxBeamWidth = decoderState.getMaxBeamWidth(); auto const beamWidth = samplingConfig.beamWidth; diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp index a4617c0d53d..522ec80f84a 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.cpp @@ -91,6 +91,43 @@ std::size_t RequestInfo::serializedSize(RequestInfo const& requestInfo) return totalSize; } +void TransferSession::appendMeasure(double delay, double duration, size_t size) +{ + if (!mRecordMeasure) + { + return; + } + auto bandwidth = size * 8 / (duration / 1000) / 1e9; // byte, ms => Gbps + mMeasures.emplace_back(Measure{delay, duration, bandwidth}); +} + +void TransferSession::exportMeasure(std::ofstream& outFile, bool isContext) const +{ + if (mMeasures.empty()) + { + return; + } + // write header if not exist + if (outFile.tellp() == 0) + { + outFile << "RequestID"; + for (size_t i = 0; i < mMeasures.size(); i++) + { + outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; + } + outFile << '\n'; + } + // write measures + TLLM_CHECK(isContext || mRequest->getContextPhaseParams().has_value()); + auto reqId = isContext ? 
mRequest->mRequestId : mRequest->getContextPhaseParams().value().getReqId(); + outFile << reqId; + for (auto const& measure : mMeasures) + { + outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; + } + outFile << '\n' << std::flush; +} + class DataResponder::Impl { public: diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h index 91215ff66c2..ef66cd1382d 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiver.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiver.h @@ -97,15 +97,23 @@ class RequestInfo class TransferSession { public: + struct Measure + { + double delay; // from last token (ctx) or arrival time (gen), in ms + double duration; // in ms + double bandwidth; // in Gbps + }; + TransferSession(std::vector connections, DataContext dataContext, executor::DataTransceiverState const& selfState, executor::DataTransceiverState otherState, - runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr) + runtime::BufferManager const& bufferManager, LlmRequest const* llmRequest = nullptr, bool recordMeasure = false) : mConnections(std::move(connections)) , mDataContext(dataContext) , mSelfState(&selfState) , mOtherState(std::move(otherState)) , mBufferManager(&bufferManager) , mRequest(llmRequest) + , mRecordMeasure(recordMeasure) { TLLM_CHECK(!mConnections.empty()); } @@ -163,6 +171,11 @@ class TransferSession mRequest = &llmRequest; } + void appendMeasure(double delay, double duration, size_t size); + + // TODO: 1. use global id instead of context request id; 2. export to llm metrics instead of file + void exportMeasure(std::ofstream& outFile, bool isContext) const; + private: std::vector mConnections; DataContext mDataContext; @@ -170,6 +183,8 @@ class TransferSession executor::DataTransceiverState mOtherState; runtime::BufferManager const* mBufferManager; LlmRequest const* mRequest; + bool mRecordMeasure; + std::vector mMeasures; }; // Operators required for data transmission in specific communication protocols. @@ -266,79 +281,4 @@ class DataRequester std::unique_ptr mImpl; }; -class KvCacheMeasureHelper -{ -public: - struct Measure - { - double delay; // from last token (ctx) or arrival time (gen), in ms - double duration; // in ms - double bandwidth; // in Gbps - }; - - KvCacheMeasureHelper(std::string output_path) - : mOutputPath(std::move(output_path)) - { - } - - void markAsSender(bool isSender) - { - mIsSender = isSender; - } - - void appendKVCacheTransfer(LlmRequest::RequestIdType requestId, double delay, double duration, size_t size) - { - auto bandwidth = size * 8 / (duration / 1000) / 1e9; - if (mOutputPath.empty()) - { - return; - } - - std::lock_guard lock(mMutex); - mRequestKVCacheTranfserMeasure[requestId].emplace_back(Measure{delay, duration, bandwidth}); - } - - ~KvCacheMeasureHelper() - { - if (!mRequestKVCacheTranfserMeasure.empty() && !mOutputPath.empty()) - { - TLLM_CHECK(mIsSender.has_value()); - auto rank = mpi::MpiComm::world().getRank(); - std::string outFilePath - = mOutputPath + "rank_" + std::to_string(rank) + "_" + (mIsSender.value() ? 
"send" : "recv") + ".csv"; - std::ofstream outFile(outFilePath); - - TLLM_CHECK_WITH_INFO(outFile.is_open(), "Cannot write to file " + outFilePath); - - size_t numTransferMeasure = mRequestKVCacheTranfserMeasure.begin()->second.size(); - - outFile << "RequestID"; - for (size_t i = 0; i < numTransferMeasure; i++) - { - outFile << ",Delay(ms),Duration(ms),Bandwidth(Gbps)"; - } - outFile << '\n'; - - for (auto const& [requestID, measures] : mRequestKVCacheTranfserMeasure) - { - outFile << requestID; - - for (auto const& measure : measures) - { - outFile << "," << measure.delay << "," << measure.duration << "," << measure.bandwidth; - } - outFile << '\n'; - } - - outFile.close(); - } - } - -private: - std::map> mRequestKVCacheTranfserMeasure; - std::string mOutputPath; - std::mutex mMutex; - std::optional mIsSender; -}; - } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp index 9a72bf2d00f..1a5c7fab4dd 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.cpp @@ -21,6 +21,8 @@ #include "tensorrt_llm/executor/cache_transmission/agent_utils/connection.h" #include "tensorrt_llm/runtime/utils/mpiUtils.h" +#include + namespace tensorrt_llm::batch_manager { @@ -30,6 +32,21 @@ static int32_t tagFromRequestId(LlmRequest::RequestIdType requestId) return ((requestId & 0xFFF) << 8) | (kDATA_TAG & 0xFF); } +namespace fs = std::filesystem; + +static fs::path getTransferOutputPath(char const* tag) +{ + auto outputPath = common::getEnvKVCacheTransferOutputPath(); + if (!outputPath.empty()) + { + auto rank = mpi::MpiComm::world().getRank(); + auto path = fs::path(outputPath); + fs::create_directories(path); + return path / ("rank_" + std::to_string(rank) + "_" + tag + ".csv"); + } + return {}; +} + DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, executor::kv_cache::CacheState selfCacheState, SizeType32 selfIndex, std::unique_ptr formatter) : mManager{manager} @@ -39,7 +56,6 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, { TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); - mFormatter->markAsSender(true); } [[nodiscard]] RequestInfo DataSenderImpl::recvRequestInfo() @@ -86,7 +102,8 @@ DataSenderImpl::DataSenderImpl(executor::kv_cache::ConnectionManager* manager, if (it == mRequestToSession.end()) { auto session = TransferSession(std::vector(peerRelativeRanks.size(), nullptr), - DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager); + DataContext{tagFromRequestId(requestId)}, mSelfState, info.getTransState(), mBufferManager, nullptr, + !common::getEnvKVCacheTransferOutputPath().empty()); it = mRequestToSession.emplace(requestId, std::move(session)).first; } it->second.setConnection(peerIdx, connection); @@ -125,6 +142,17 @@ void DataSenderImpl::release(LlmRequest::RequestIdType requestId) auto it = mRequestToSession.find(requestId); TLLM_CHECK(it != mRequestToSession.end()); std::unique_lock lk(mMtxForMap); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("send"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + it->second.exportMeasure(mMeasuresFile, true); + } 
mRequestToSession.erase(it); } @@ -137,7 +165,6 @@ DataReceiverImpl::DataReceiverImpl(executor::kv_cache::ConnectionManager* manage TLLM_CHECK(mManager); TLLM_CHECK(mManager->getCommState().getSelfIdx() == selfIndex); TLLM_CHECK(mFormatter); - mFormatter->markAsSender(false); } TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) @@ -203,12 +230,24 @@ TransferSession DataReceiverImpl::sendRequestInfo(LlmRequest const& llmRequest) } auto const& resource = getReceiveCacheResource(llmRequest); return TransferSession(std::move(counterPartConnections), DataContext{tagFromRequestId(requestId)}, mSelfState, - contextState, resource->mBufferManager, &llmRequest); + contextState, resource->mBufferManager, &llmRequest, !common::getEnvKVCacheTransferOutputPath().empty()); } void DataReceiverImpl::receiveSync(TransferSession& session) { mFormatter->unformat(session); + if (!common::getEnvKVCacheTransferOutputPath().empty()) + { + std::unique_lock lock(mMeasuresFileMutex); + if (!mMeasuresFile.is_open()) + { + auto outputPath = getTransferOutputPath("recv"); + mMeasuresFile.open(outputPath); + TLLM_CHECK_WITH_INFO( + mMeasuresFile.is_open(), "Failed to open transfer output file: %s", outputPath.string().c_str()); + } + session.exportMeasure(mMeasuresFile, false); + } } void DataReceiverImpl::sendRequestInfo(executor::kv_cache::Connection const* connection, RequestInfo const& info) diff --git a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h index fa8d2728329..2f277f14fff 100644 --- a/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h +++ b/cpp/tensorrt_llm/batch_manager/dataTransceiverImpl.h @@ -23,6 +23,8 @@ #include "tensorrt_llm/common/envUtils.h" #include "tensorrt_llm/executor/cache_transmission/cacheSplitConcat.h" +#include + namespace tensorrt_llm::batch_manager { struct TransceiverTag @@ -67,6 +69,7 @@ class DataSenderImpl : public DataSender, public TransceiverTag std::unique_ptr mFormatter; std::mutex mMtxForMap; runtime::BufferManager mBufferManager; + std::ofstream mMeasuresFile; }; class DataReceiverImpl : public DataReceiver, public TransceiverTag @@ -103,6 +106,8 @@ class DataReceiverImpl : public DataReceiver, public TransceiverTag std::unique_ptr mFormatter; std::unordered_map> mProcessToResources; std::mutex mProcessIoResouceMutex; + std::ofstream mMeasuresFile; + std::mutex mMeasuresFileMutex; }; } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp index 040dcd147e9..ea5f0981074 100644 --- a/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp +++ b/cpp/tensorrt_llm/batch_manager/guidedDecoder.cpp @@ -88,8 +88,7 @@ void GuidedDecoder::build(ScheduledRequests const& scheduledRequests) continue; } auto const seqSlot = llmReq->mSeqSlot.value(); - if (llmReq->isContextInitState() - && llmReq->getContextCurrentPosition() == llmReq->getPrepopulatedPromptLen()) + if (llmReq->isContextInitState() && llmReq->isFirstContextChunk()) { // The request is in the first context forward step (considering kv cache reuse). 
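// Aside on the refactor above: isFirstContextChunk() is presumably just an encapsulation of the
// inline condition it replaces, i.e. roughly
//     bool LlmRequest::isFirstContextChunk() const
//     { return getContextCurrentPosition() == getPrepopulatedPromptLen(); }
// so the guided-decoder behavior is unchanged; the helper only makes the intent explicit.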
auto const& guideType = guidedDecodingParams->getGuideType(); diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp index ff2a2f6b787..ac37278d45f 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheEventManager.cpp @@ -18,20 +18,51 @@ #include "tensorrt_llm/batch_manager/kvCacheEventManager.h" #include "tensorrt_llm/batch_manager/kvCacheManager.h" #include "tensorrt_llm/executor/executor.h" +#include "tensorrt_llm/executor/serialization.h" +#include "tensorrt_llm/runtime/utils/mpiUtils.h" namespace tle = tensorrt_llm::executor; namespace tensorrt_llm::batch_manager::kv_cache_manager { -KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries) +KVCacheEventManager::KVCacheEventManager(size_t maxKVEventEntries, std::optional attentionDpRank, + std::optional attentionDpSize, SizeType32 attentionDpEventsGatherPeriodMs) : mRun{true} , mMaxSize{maxKVEventEntries} , mEventId{0} + , mAttentionDpRank{attentionDpRank} + , mAttentionDpSize{attentionDpSize} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { TLLM_CHECK(mMaxSize > 0); - // mWorkerThread = std::thread(std::bind(&KVCacheEventManager::worker, this)); + if (mAttentionDpRank) + { + TLLM_CHECK_WITH_INFO( + mAttentionDpSize.has_value(), "If attention DP rank is set, the attention DP size must also be set"); + TLLM_CHECK_WITH_INFO(mAttentionDpRank.value() < mAttentionDpSize.value(), + "Attention DP rank must be less than attention DP size"); + if (mAttentionDpRank.value() == 0) + { + // Rank 0 will gather events from all other ranks + // Need to increase size + mMaxSize *= mAttentionDpSize.value(); + } + // Create a communicator to be used for event exchange + mMpiComm = std::make_unique(COMM_SESSION.split(0, mAttentionDpRank.value())); + } + else + { + TLLM_CHECK_WITH_INFO( + !mAttentionDpSize.has_value(), "If attention DP rank is not set, the attention DP size must not be set"); + } mWorkerThread = std::thread([this]() { this->worker(); }); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread = std::thread([this]() { this->exchangeAttentionDpThread(); }); + } +#endif }; KVCacheEventManager::~KVCacheEventManager() @@ -40,12 +71,18 @@ KVCacheEventManager::~KVCacheEventManager() mPendingEmptyCV.notify_all(); mEmptyCV.notify_all(); mWorkerThread.join(); +#if ENABLE_MULTI_DEVICE + if (mAttentionDpRank) + { + mExchangeAttentionDpThread.join(); + } +#endif } void KVCacheEventManager::enqueueCreatedEvent( std::vector const& numBlocksPerCacheLevel, SizeType32 windowSize) { - enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheCreatedData{numBlocksPerCacheLevel}, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks, SizeType32 windowSize) @@ -68,7 +105,7 @@ void KVCacheEventManager::enqueueStoredEvent(std::vector const& blocks block->isPrimary() ? 
kPrimaryLevel : kSecondaryLevel, block->getPriority()); } - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 windowSize) @@ -81,13 +118,13 @@ void KVCacheEventManager::enqueueRemovedEvent(BlockPtr const& block, SizeType32 } else { - enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize}); + enqueueEvent({mEventId++, tle::KVCacheRemovedData{{block->getHash()}}, windowSize, mAttentionDpRank}); } } void KVCacheEventManager::enqueueUpdatedEvent(tle::KVCacheUpdatedData const& data, SizeType32 windowSize) { - enqueueEvent({mEventId++, data, windowSize}); + enqueueEvent({mEventId++, data, windowSize, mAttentionDpRank}); } void KVCacheEventManager::enqueueEvent(tle::KVCacheEvent&& event) @@ -120,8 +157,76 @@ void KVCacheEventManager::flush() mPendingEmptyCV.notify_one(); } +void KVCacheEventManager::exchangeAttentionDpThread() +{ +#if ENABLE_MULTI_DEVICE + while (true) + { + TLLM_CHECK(mAttentionDpRank); + + // Check if any of the ranks have been shutdown + int32_t numFinished = 0; + int32_t finished = mRun ? 0 : 1; + mMpiComm->allreduce(&finished, &numFinished, 1, mpi::MpiType::kINT32, mpi::MpiOp::SUM); + if (numFinished > 0) + { + TLLM_LOG_INFO("One of the rank has been shut down, exiting"); + break; + } + + // If we are not rank 0, send events to rank 0 + if (mAttentionDpRank.value() != 0) + { + std::vector serializedEvents; + uint64_t numEvents = 0; + { + std::lock_guard lck(mEventsMutex); + serializedEvents = executor::Serialization::serialize(mEvents); + numEvents = mEvents.size(); + mEvents.clear(); + } + uint64_t vecSize = numEvents > 0 ? serializedEvents.size() : 0; + mMpiComm->send(&vecSize, 1, mpi::MpiType::kUINT64, 0, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + mMpiComm->send(serializedEvents.data(), serializedEvents.size(), mpi::MpiType::kCHAR, 0, + mpi::MpiTag::kKvCacheEvent); + } + } + else + { + TLLM_CHECK(mAttentionDpSize.has_value()); + // Loop until have received events from all ranks + for (int rank = 1; rank < mAttentionDpSize.value(); ++rank) + { + uint64_t vecSize{0}; + mMpiComm->recv(&vecSize, 1, mpi::MpiType::kUINT64, rank, mpi::MpiTag::kKvCacheEventSize); + if (vecSize > 0) + { + std::vector serializedEvents(vecSize); + mMpiComm->recv( + serializedEvents.data(), vecSize, mpi::MpiType::kCHAR, rank, mpi::MpiTag::kKvCacheEvent); + + // Deserialize the events and add them to the local queue + auto rankEvents = executor::Serialization::deserializeKVCacheEvents(serializedEvents); + { + std::lock_guard lck(mEventsMutex); + mEvents.insert(mEvents.end(), rankEvents.begin(), rankEvents.end()); + mEmptyCV.notify_one(); + } + } + } + } + std::this_thread::sleep_for(std::chrono::milliseconds(mAttentionDpEventsGatherPeriodMs)); + } +#else + TLLM_THROW("Multi device support is disabled."); +#endif +} + void KVCacheEventManager::worker() { + while (true) { std::deque events; @@ -151,6 +256,8 @@ void KVCacheEventManager::worker() // If there's still too many events, take from the front of the events queue. 
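// ---------------------------------------------------------------------------------------------
// Illustrative aside: the attention-DP event exchange above follows a size-then-payload pattern
// (rank 0 gathers serialized KVCacheEvents from every other rank over the split communicator).
// The sketch below shows the same pattern against plain MPI rather than the internal mpi::MpiComm
// wrapper; the tag values are placeholders, not the real mpi::MpiTag constants.
#include <mpi.h>
#include <cstdint>
#include <vector>

void sendSerializedEvents(std::vector<char> const& payload, MPI_Comm comm)
{
    // First message: the payload size, so rank 0 can allocate a receive buffer.
    std::uint64_t size = payload.size();
    MPI_Send(&size, 1, MPI_UINT64_T, /*dest=*/0, /*sizeTag=*/100, comm);
    if (size > 0)
    {
        // Second message: the serialized event bytes themselves.
        MPI_Send(payload.data(), static_cast<int>(size), MPI_CHAR, /*dest=*/0, /*dataTag=*/101, comm);
    }
}

std::vector<char> recvSerializedEvents(int srcRank, MPI_Comm comm)
{
    std::uint64_t size = 0;
    MPI_Recv(&size, 1, MPI_UINT64_T, srcRank, /*sizeTag=*/100, comm, MPI_STATUS_IGNORE);
    std::vector<char> payload(size);
    if (size > 0)
    {
        MPI_Recv(payload.data(), static_cast<int>(size), MPI_CHAR, srcRank, /*dataTag=*/101, comm, MPI_STATUS_IGNORE);
    }
    return payload;
}
// ---------------------------------------------------------------------------------------------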
mEvents.insert(mEvents.end(), events.begin() + std::max(0, elementsToRemove), events.end()); + + // Notify the empty condition variable to wake up any waiting threads mEmptyCV.notify_one(); } } diff --git a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp index 4202ba348ac..c032c80757c 100644 --- a/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp @@ -504,8 +504,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si std::optional const& tempAttentionWindowInputs, nvinfer1::DataType dtype, SizeType32 sinkBubbleLength, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mNumLayers{static_cast(numKvHeadsPerLayer.size())} , mTokensPerBlock{tokensPerBlock} , mEventManager{std::move(eventManager)} @@ -530,7 +529,7 @@ BlockManager::BlockManager(std::vector const& numKvHeadsPerLayer, Si TLLM_CHECK(allottedPrimaryBlocks > 0); // You can't have a model with negative primary blocks... mWindowBlockManagers.try_emplace(windowSize, dtype, windowSize, layersWithWindowSize, numKvHeadsPerLayer, sizePerHead, tokensPerBlock, allottedPrimaryBlocks, allottedSecondaryBlocks, maxNumSequences, stream, - onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enableHashKey, enablePartialReuse, + onboardBlocks, cacheType, secondaryOffloadMinPriority, mEventManager, enablePartialReuse, copyOnPartialReuse); } @@ -573,8 +572,7 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind SizeType32 sizePerHead, SizeType32 tokensPerBlock, SizeType32 blocksInPrimaryPool, SizeType32 blocksInSecondaryPool, SizeType32 maxNumSequences, std::shared_ptr stream, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mDataType{dtype} , mWindowSize{windowSize} , mNumPrimaryBlocks{blocksInPrimaryPool} @@ -596,7 +594,6 @@ WindowBlockManager::WindowBlockManager(nvinfer1::DataType dtype, SizeType32 wind , mLogPrefix{tensorrt_llm::common::fmtstr("BlockManager[windowSize=%u]", mWindowSize)} , mReusedTokens{0.0} , mTotalInputTokens{0.0} - , mEnableHashKey{enableHashKey} , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} { @@ -920,50 +917,6 @@ void BlockManager::setOffsets(tk::KVCacheIndex* offsetsPtr, nvinfer1::Dims const mWindowBlockManagers.at(windowSize).setOffsets(offsetsPtr, offsetsShape, beamIdx, blockIdx, blockId); } -void WindowBlockManager::addBlockToHashMap(BlockPtr const& block) -{ - if (!mEnableHashKey) - { - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - // TODO: change to assert when reused block is added only once - TLLM_LOG_TRACE( - "Block %d by %zx exists", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - return; - } - } - TLLM_LOG_TRACE( - "Add block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - mContextBlocksByHash.emplace(block->getHash(), std::move(block)); -} - -void 
WindowBlockManager::removeBlockFromHashMap(BlockPtr const& block) -{ - if (mContextBlocksByHash.empty() || block->getBlockKey().uniqueTokens.empty()) - { - // Hash key not enabled / Empty block - return; - } - auto range = mContextBlocksByHash.equal_range(block->getHash()); - TLLM_LOG_TRACE( - "Remove block %d by %zx, block n = %zu", block->getBlockId(), block->getHash(), mContextBlocksByHash.size()); - for (auto it = range.first; it != range.second; ++it) - { - if (it->second == block) - { - mContextBlocksByHash.erase(it); - return; - } - } - // TODO: should be unreachable - TLLM_LOG_DEBUG("Trying to remove block %d by %zx that is not in hash map", block->getBlockId(), block->getHash()); -} - void BlockManager::onboardBlock(BlockPtr const& offloadBlock, SizeType32 windowSize) { mWindowBlockManagers.at(windowSize).onboardBlock(offloadBlock); @@ -1104,7 +1057,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Reused partially filled block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); } searchRoot = nullptr; // no matching needed for following blocks } @@ -1114,7 +1066,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& mEvictionPolicy->claimBlock( matchingBlock, perBlockRetentions[bi].retentionPriority, perBlockRetentions[bi].durationMs); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Matched full block %d", mLogPrefix.c_str(), matchingBlockId); - addBlockToHashMap(matchingBlock); searchRoot = matchingBlock; } onboardBlock(matchingBlock); @@ -1145,7 +1096,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); ++mMissedBlocks; } } @@ -1169,7 +1119,6 @@ SizeType32 WindowBlockManager::loadOrAllocateBlocks(std::vector const& ++blockItr; } freeBlock->setHash(); - addBlockToHashMap(freeBlock); TLLM_LOG_DEBUG("%s::loadOrAllocateBlocks - Beam %d. 
Allocated non-shared block %d for bi %d", mLogPrefix.c_str(), beamIdx, freeBlock->getBlockId(), bi); } @@ -1369,9 +1318,7 @@ void WindowBlockManager::storeBlocks( if (oldHash != newHash) { TLLM_LOG_DEBUG("#%d block hash %zx -> %zx", block->getBlockId(), oldHash, newHash); - removeBlockFromHashMap(block); block->setHash(newHash); - addBlockToHashMap(block); } searchRoot = block; } @@ -1408,7 +1355,6 @@ void WindowBlockManager::replaceSharedBlock(GenerationRequest& sequence, SizeTyp if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } @@ -1473,7 +1419,6 @@ void WindowBlockManager::releaseLastBlock(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block, true); - removeBlockFromHashMap(block); } // Remove block from allocated blocks allocatedBlocks.pop_back(); @@ -1616,7 +1561,6 @@ void WindowBlockManager::releaseBlocks(GenerationRequest& sequence) if (!block->hasRefs()) { mEvictionPolicy->releaseBlock(block); - removeBlockFromHashMap(block); } } // Remove stored block ids in sequence @@ -1654,8 +1598,7 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::make_shared(reinterpret_cast(stream)), maxSequenceLength, - enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, false, enablePartialReuse, - copyOnPartialReuse) + enableBlockReuse, onboardBlocks, cacheType, std::nullopt, nullptr, enablePartialReuse, copyOnPartialReuse) { } @@ -1682,8 +1625,7 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : mMaxBeamWidth(maxBeamWidth) , mDataType(dtype) , mMaxAttentionWindow(*std::max_element(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end())) @@ -1693,10 +1635,9 @@ KVCacheManager::KVCacheManager(std::vector const& numKvHeadsPerLayer , mBlockManager(numKvHeadsPerLayer, sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, std::move(stream), maxSequenceLength, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, mSinkBubbleLength, onboardBlocks, cacheType, secondaryOffloadMinPriority, std::move(eventManager), - enableHashKey, enablePartialReuse, copyOnPartialReuse) + enablePartialReuse, copyOnPartialReuse) // disable block reuse for sink bubble since chopVectorIntoBlocks does not match KV cache blocks in this case , mEnableBlockReuse{mSinkBubbleLength > 0 ? 
false : enableBlockReuse} - , mEnableHashKey{enableHashKey} { TLLM_CHECK_DEBUG(std::find(maxAttentionWindowVec.begin(), maxAttentionWindowVec.end(), mMaxAttentionWindow) != maxAttentionWindowVec.end()); @@ -1716,12 +1657,11 @@ KVCacheManager::KVCacheManager(SizeType32 numLayers, SizeType32 numKvHeads, Size SizeType32 sinkTokenLength, CudaStreamPtr stream, std::optional maxSequenceLength, bool enableBlockReuse, bool onboardBlocks, CacheType cacheType, std::optional secondaryOffloadMinPriority, - std::shared_ptr eventManager, bool enableHashKey, bool enablePartialReuse, - bool copyOnPartialReuse) + std::shared_ptr eventManager, bool enablePartialReuse, bool copyOnPartialReuse) : KVCacheManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow, maxNumSequences, maxBeamWidth, maxAttentionWindowVec, tempAttentionWindowInputs, dtype, sinkTokenLength, std::move(stream), maxSequenceLength, enableBlockReuse, onboardBlocks, cacheType, secondaryOffloadMinPriority, - std::move(eventManager), enableHashKey, enablePartialReuse, copyOnPartialReuse) + std::move(eventManager), enablePartialReuse, copyOnPartialReuse) { } @@ -2085,30 +2025,6 @@ void KVCacheManager::addSequence( llmRequest->mRequestId); } mBlockManager.addSequence(sequence, numContextBlocks, unsharedBlockIdx, windowSize); - if (mEnableHashKey && llmRequest.has_value() && beamWidth == 1) - { - constexpr SizeType32 beamIdx = 0; - auto const& blockIds = sequence.getCacheBlockIds(windowSize).at(beamIdx); - auto const& uniqueTokens = llmRequest->getUniqueTokens(beamIdx); - auto blockedUniqueTokens = chopVectorIntoBlocks( - uniqueTokens, uniqueTokens.size() - 1, getTokensPerBlock(), true); - auto blockKeys = buildBlockKeys(blockedUniqueTokens, *llmRequest); - auto tokensPerBlock = static_cast(getTokensPerBlock()); - for (size_t i = 0; i < blockIds.size(); i++) - { - auto const& block = mBlockManager.getBlockById(blockIds[i], windowSize); - if (i < blockKeys.size()) - { - block->setBlockKey(blockKeys[i], blockKeys[i].uniqueTokens.size() == tokensPerBlock); - } - else - { - block->setBlockKey({}, false); - } - block->setHash(); - mBlockManager.addBlockToHashMap(block, windowSize); - } - } } cacheBlockOffsets(sequence, windowSize); } diff --git a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp index a9a4aec5dfc..dcebc9c3ac6 100644 --- a/cpp/tensorrt_llm/batch_manager/llmRequest.cpp +++ b/cpp/tensorrt_llm/batch_manager/llmRequest.cpp @@ -365,4 +365,10 @@ void LlmRequest::moveLoraWeightsToGpu(runtime::BufferManager const& manager) mLoraWeights = gpuLoraWeights; } +void LlmRequest::removeLoraTensors() +{ + mLoraWeights.reset(); + mLoraConfig.reset(); +} + } // namespace tensorrt_llm::batch_manager diff --git a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp index 824a31129f8..22c90095e77 100644 --- a/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp +++ b/cpp/tensorrt_llm/batch_manager/mlaCacheFormatter.cpp @@ -45,12 +45,10 @@ std::vector MLACacheFormatter::pickRecvConnections( auto targetInfo = executor::kv_cache::targetIRanks(destConfig, selfConfig, selfIdx); TLLM_CHECK(numConnections == targetInfo.mIRanks.size()); std::vector ret; - // targetInfo , mRanks [tpranks, ppranks] - int dpRank = selfConfig.getParallelConfig().mEnableAttentionDP ? 
selfConfig.getParallelConfig().mDPrank : 0; - + // targetInfo , mRanks [tpranks, dpranks] for (int i = 0; i < targetInfo.mDomainPPSize; i++) { - ret.push_back(i + (dpRank % (targetInfo.mDomainTPSize)) * targetInfo.mDomainPPSize); + ret.push_back(i); } return ret; } @@ -60,24 +58,19 @@ bool MLACacheFormatter::needSendCache( { int selfTpRank = selfIdx % selfConfig.getParallelConfig().mTensorParallelism; - int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP - ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize - : destConfig.getParallelConfig().mTensorParallelism; - int destDPRank = destConfig.getParallelConfig().mEnableAttentionDP ? destConfig.getParallelConfig().mDPrank : 0; - if (selfConfig.getParallelConfig().mEnableAttentionDP) { int selfTPNumInDPGroup = selfConfig.getParallelConfig().mTensorParallelism / selfConfig.getParallelConfig().mDPsize; - + int destTPNumInDPGroup = destConfig.getParallelConfig().mEnableAttentionDP + ? destConfig.getParallelConfig().mTensorParallelism / destConfig.getParallelConfig().mDPsize + : destConfig.getParallelConfig().mTensorParallelism; int selfTPrankINDPGroup = selfTpRank % selfTPNumInDPGroup; if (selfTPNumInDPGroup <= destTPNumInDPGroup) { return true; } - - int dupHeadFactor = selfTPNumInDPGroup / destTPNumInDPGroup; - return selfTPrankINDPGroup % dupHeadFactor == destDPRank; + return selfTPrankINDPGroup % (selfTPNumInDPGroup / destTPNumInDPGroup) == 0; } int destTPNum = destConfig.getParallelConfig().mEnableAttentionDP @@ -88,8 +81,7 @@ bool MLACacheFormatter::needSendCache( { return true; } - int dupHeadFactor = selfTPNum / destTPNum; - return selfTpRank % dupHeadFactor == destDPRank; + return selfTpRank % (selfTPNum / destTPNum) == 0; } void MLACacheFormatter::format(TransferSession& session) @@ -244,7 +236,7 @@ void MLACacheFormatter::format(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(llmRequest.mRequestId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (connections.size() > 1) @@ -441,7 +433,7 @@ void MLACacheFormatter::unformat(TransferSession& session) } double cacheTransferTime = std::max(0.0, std::chrono::duration(endTime - startTime).count()); - kvCacheMeasureHelper.appendKVCacheTransfer(ctxReqId, delay, cacheTransferTime, size); + session.appendMeasure(delay, cacheTransferTime, size); }; if (pickUpConnections.size() > 1) diff --git a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp index f513f2a3a10..cc62bd3eb04 100644 --- a/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp +++ b/cpp/tensorrt_llm/batch_manager/peftCacheManager.cpp @@ -591,10 +591,9 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr llmRe TLLM_LOG_DEBUG("%s start", __PRETTY_FUNCTION__); if (llmRequest->getLoraTaskId().has_value()) { - auto taskId = llmRequest->getLoraTaskId().value(); try { - return mHostLoraCache->determineNumPages(taskId); + return mHostLoraCache->determineNumPages(llmRequest->getLoraTaskId().value()); } catch (std::runtime_error& e) { @@ -602,16 +601,6 @@ SizeType32 PeftCacheManager::determineNumPages(std::shared_ptr llmRe { return mHostLoraCache->determineNumPages(llmRequest->getLoraConfig().value()); } - if (!llmRequest->getLoraWeights().has_value()) - { - auto const reqId = llmRequest->mRequestId; - std::string errMsg - = "Request ID " + 
std::to_string(reqId) + " has no LoRA adapter weights while configured with LoRA task " - + std::to_string(taskId) + " that's not found in LoRA CPU cache." - " Note that currently a request with LoRA task that was already loaded is sent without its LoRA weights to save its serialization, copy and deserialization," - " so if this LoRA task was evicted from LoRA CPU cache, then its reuse is currently not supported."; - throw PeftTaskNotCachedException(errMsg); - } throw; } } diff --git a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp index 4a5ddb89286..d42d798f68b 100644 --- a/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp +++ b/cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp @@ -693,7 +693,7 @@ std::unique_ptr TrtGptModelInflightBatching::c kvCacheConfig.getEventBufferMaxSize() > 0 ? std::make_unique(kvCacheConfig.getEventBufferMaxSize()) : nullptr, - false, kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); + kvCacheConfig.getEnablePartialReuse(), kvCacheConfig.getCopyOnPartialReuse()); reshapeKvTensors(kvCacheManager->getOffsetTableDimensions()); diff --git a/cpp/tensorrt_llm/common/attentionOp.cpp b/cpp/tensorrt_llm/common/attentionOp.cpp index 6e1498ba713..03d03eca3af 100644 --- a/cpp/tensorrt_llm/common/attentionOp.cpp +++ b/cpp/tensorrt_llm/common/attentionOp.cpp @@ -55,6 +55,7 @@ struct FusedQKVMaskedAttentionDispatchParams T const* qkv_bias; T const* relative_attention_bias; bool const* attention_mask; + float const* attention_sinks; float const* logn_scaling_ptr; int const* cache_indir; void* context_buf; @@ -71,6 +72,7 @@ struct FusedQKVMaskedAttentionDispatchParams RotaryScalingType rotary_embedding_scale_type; float rotary_embedding_scale; float const* rotary_embedding_inv_freq_cache; + float2 const* rotary_embedding_cos_sin_cache; float rotary_embedding_short_m_scale; float rotary_embedding_long_m_scale; int rotary_embedding_max_positions; @@ -225,6 +227,7 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.output = generationsParams.context_buf; xqaParams.qkv = generationsParams.attention_input; xqaParams.cache_indir = generationsParams.cache_indir; + xqaParams.attention_sinks = generationsParams.attention_sinks; xqaParams.kv_scale_orig_quant = generationsParams.kv_scale_orig_quant; xqaParams.kv_scale_quant_orig = generationsParams.kv_scale_quant_orig; xqaParams.host_past_key_value_lengths = generationsParams.host_past_key_value_lengths; @@ -275,7 +278,8 @@ bool AttentionOp::convertMMHAParamsToXQAParams(tensorrt_llm::kernels::XQAParams& xqaParams.logn_scaling_ptr = generationsParams.logn_scaling_ptr; xqaParams.total_num_input_tokens = mCpSize > 1 ? generationsParams.num_requests : generationsParams.num_tokens; xqaParams.is_fp8_output = mFP8ContextFMHA; - xqaParams.fp8_out_scale = (mFP8ContextFMHA ? generationsParams.attention_output_orig_quant : nullptr); + xqaParams.fp8_out_scale + = ((mFP8ContextFMHA || mFP8ContextMLA) ? generationsParams.attention_output_orig_quant : nullptr); // Parameters required for FP4 output. 
xqaParams.output_sf = generationsParams.context_buf_sf; xqaParams.fp4_out_sf_scale = generationsParams.attention_output_sf_scale; @@ -596,6 +600,7 @@ void fusedQKV_masked_attention_dispatch(Multihead_attention_paramsisSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + + int const num_total_qkv_elements + = max_num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + + size_t fp8_qkv_buffer_size = mFP8ContextFMHA && mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? max_num_tokens * size_t(local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } + size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; size_t const encoder_padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * max_num_tokens; // Each token holds (batch_idx, token_idx_in_seq) int2. @@ -1342,10 +1369,26 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea size_t const qk_buf_float_size = mEnableContextFMHA ? 0 : sizeof(float) * params.batch_size * mNumHeads * params.input_seq_length * kv_seq_length; - size_t const fp8_qkv_buffer_size - = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() + int const dim_q_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_k_per_head = (mMLAParams.qk_rope_head_dim + mMLAParams.qk_nope_head_dim); + int const dim_v_per_head = (mMLAParams.v_head_dim); + + // Total dimension per token across all heads for Q, K, and V components respectively + int const total_q_dim_all_heads = mNumAttnHeads * dim_q_per_head; + int const total_k_dim_all_heads + = mNumAttnHeads * dim_k_per_head; // Assuming effective num_kv_heads = head_num for layout + int const total_v_dim_all_heads + = mNumAttnHeads * dim_v_per_head; // Assuming effective num_kv_heads = head_num for layout + int const num_total_qkv_elements + = params.num_tokens * (total_q_dim_all_heads + total_k_dim_all_heads + total_v_dim_all_heads); + size_t fp8_qkv_buffer_size = mEnableContextFMHA && mFP8ContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? params.num_tokens * (local_hidden_units_qo + 2 * local_hidden_units_kv) : 0; + if (mFP8ContextMLA) + { + fp8_qkv_buffer_size + = mEnableContextFMHA && !mFmhaDispatcher->isSeparateQAndKvInput() ? num_total_qkv_elements : 0; + } size_t const padding_offset_size = mEnableContextFMHA ? 0 : sizeof(int) * params.batch_size * params.input_seq_length; size_t const encoder_padding_offset_size @@ -1353,8 +1396,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // Each token holds (batch_idx, token_idx_in_seq) int2. size_t const tokens_info_size = sizeof(int2) * params.num_tokens; size_t const fmha_scheduler_counter = mEnableContextFMHA ? 
sizeof(uint32_t) : 0; - size_t const fmha_bmm1_scale_size = mFP8ContextFMHA ? sizeof(float) * 2 : 0; - size_t const fmha_bmm2_scale_size = mFP8ContextFMHA ? sizeof(float) : 0; + size_t const fmha_bmm1_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) * 2 : 0; + size_t const fmha_bmm2_scale_size = (mFP8ContextFMHA || mFP8ContextMLA) ? sizeof(float) : 0; // cp workspace size upper bound size_t const cpMaxPadedSequenceLength = params.num_tokens + params.batch_size * (mCpSize - 1); @@ -1601,6 +1644,15 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea params.mla_param->cache_type = cache_type; params.mla_param->cu_q_seqlens = cu_q_seqlens; params.mla_param->quant_scale_kv = params.kv_scale_orig_quant; + // Set BMM scales for FP8 context computation + params.mla_param->bmm1_scale = fmha_bmm1_scale_ptr; + params.mla_param->bmm2_scale = fmha_bmm2_scale_ptr; + params.mla_param->host_bmm1_scale = decoder_params.fmhaHostBmm1Scale; + params.mla_param->quant_attention_input_buf = mFP8ContextMLA ? fp8_qkv_buffer : nullptr; + // Set additional scales for context phase + params.mla_param->quant_scale_o = params.attention_output_orig_quant; + params.mla_param->dequant_scale_q = params.kv_scale_quant_orig; + params.mla_param->dequant_scale_kv = params.kv_scale_quant_orig; if (mPagedContextFMHA && mPagedKVCache) { TLLM_CHECK_WITH_INFO(params.mla_param->context_paged_kv_ptr != nullptr, @@ -1679,8 +1731,8 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea // TODO: set it correctly for contiguous kv buffer (cross-attention). fmhaParams.totalKvSeqLen = isCrossAttention() ? params.num_encoder_tokens : params.num_tokens; // Device buffer pointers. - fmhaParams.qkvPtr = mFP8ContextFMHA ? reinterpret_cast(fp8_qkv_buffer) - : reinterpret_cast(attention_input); + fmhaParams.qkvPtr = (mFP8ContextFMHA || mFP8ContextMLA) ? reinterpret_cast(fp8_qkv_buffer) + : reinterpret_cast(attention_input); fmhaParams.qPtr = reinterpret_cast(q_buf_2_); // TODO: add contiguous kv buffer (cross-attention). fmhaParams.kvPtr = nullptr; @@ -1691,6 +1743,7 @@ int AttentionOp::enqueueContext(EnqueueContextParams const& params, cudaStrea fmhaParams.outputPtr = mCpSize > 1 ? gatherOutBuffer : params.context_buf; // only use [totalLength, h / cpSize, Dh] fmhaParams.outputSfPtr = params.context_buf_sf; + fmhaParams.attentionSinksPtr = params.attention_sinks; fmhaParams.packedMaskPtr = params.attention_packed_mask; if constexpr (std::is_same_v) { @@ -2220,6 +2273,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud dispatch_params.relative_attention_bias_stride = relative_attention_bias_stride; dispatch_params.attention_mask = params.attention_mask; dispatch_params.attention_mask_stride = params.attention_mask_stride; + dispatch_params.attention_sinks = params.attention_sinks; dispatch_params.max_distance = max_distance; dispatch_params.cache_indir = params.cache_indir; dispatch_params.context_buf = mCpSize > 1 ? 
mhaOutput : params.context_buf; // @@ -2267,6 +2321,7 @@ int AttentionOp::enqueueGeneration(EnqueueGenerationParams const& params, cud dispatch_params.rotary_embedding_scale_type = mRotaryEmbeddingScaleType; dispatch_params.rotary_embedding_scale = mRotaryEmbeddingScale; dispatch_params.rotary_embedding_inv_freq_cache = params.rotary_inv_freq; + dispatch_params.rotary_embedding_cos_sin_cache = params.rotary_cos_sin; dispatch_params.rotary_embedding_short_m_scale = mRotaryEmbeddingShortMscale; dispatch_params.rotary_embedding_long_m_scale = mRotaryEmbeddingLongMscale; dispatch_params.rotary_embedding_max_positions = mRotaryEmbeddingMaxPositions; @@ -2477,7 +2532,7 @@ int AttentionOp::initialize() noexcept } // FP8 FMHA should be used with fp8 workflow together. - if (mFP8ContextFMHA) + if (mFP8ContextFMHA || mFP8ContextMLA) { data_type = DATA_TYPE_E4M3; } @@ -2510,6 +2565,11 @@ int AttentionOp::initialize() noexcept fmhaParams.dataTypeOut = DATA_TYPE_BF16; fmhaParams.dataTypeKv = DATA_TYPE_BF16; } + if (mFP8ContextMLA && mKVCacheQuantMode.hasFp8KvCache()) + { + fmhaParams.dataTypeKv = DATA_TYPE_E4M3; + fmhaParams.dataTypeOut = DATA_TYPE_BF16; + } // TODO: remove forceFp32Acc from MHARunnerFixedParams after adding host_runtime_perf_knobs to // bertAttentionPlugin input tensors, so that we can change mLaunchParams.force_fp32_acc value in runtime. fmhaParams.forceFp32Acc = false; @@ -2563,7 +2623,7 @@ int AttentionOp::initialize() noexcept // Deepseek-V2 Generation needs a differ fmha with different argumments if (mIsMLAEnabled) { - mEnableXQA = (mSM == kSM_120); + mEnableXQA = (mSM == kSM_120) && mIsGenerationMLA; if (mUseTllmGen) { Data_type qDataType = DATA_TYPE_FP32; @@ -2826,6 +2886,7 @@ std::string AttentionOp::toString() const ss << "mPosShiftEnabled: " << std::boolalpha << mPosShiftEnabled << std::endl; ss << "mPagedContextFMHA: " << std::boolalpha << mPagedContextFMHA << std::endl; ss << "mFP8ContextFMHA: " << std::boolalpha << mFP8ContextFMHA << std::endl; + ss << "mFP8ContextMLA: " << std::boolalpha << mFP8ContextMLA << std::endl; ss << "mDenseContextFMHA: " << std::boolalpha << mDenseContextFMHA << std::endl; ss << "mEnableContextFMHA: " << std::boolalpha << mEnableContextFMHA << std::endl; ss << "mFMHAForceFP32Acc: " << std::boolalpha << mFMHAForceFP32Acc << std::endl; diff --git a/cpp/tensorrt_llm/common/attentionOp.h b/cpp/tensorrt_llm/common/attentionOp.h index fb71c06d57b..25d95dfea2b 100644 --- a/cpp/tensorrt_llm/common/attentionOp.h +++ b/cpp/tensorrt_llm/common/attentionOp.h @@ -65,6 +65,8 @@ class AttentionOp T const* qkv_bias = nullptr; // Attention mask input, which has shape of [batch_size, attention_mask_stride]. bool const* attention_mask = nullptr; + // Attention sinks with shape of [num_heads_q] float. + float const* attention_sinks = nullptr; // Rotary inv_freq cache buffer to avoid re-computing. float const* rotary_inv_freq = nullptr; // Rotary cos sin cache buffer to avoid re-computing. 
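// ---------------------------------------------------------------------------------------------
// Worked sizing example (illustrative numbers, not taken from any particular model config) for
// the fp8_qkv_buffer sizing in enqueueContext above (and the matching workspace-size computation)
// when mFP8ContextMLA is set: each element is one E4M3 byte, so the element count is also the
// byte count.
#include <cstddef>
#include <cstdio>

int main()
{
    int const qkNopeHeadDim = 128, qkRopeHeadDim = 64, vHeadDim = 128; // assumed MLA head dims
    int const numAttnHeads = 128;                                      // assumed mNumAttnHeads
    std::size_t const numTokens = 8192;                                // assumed max_num_tokens

    int const dimQPerHead = qkNopeHeadDim + qkRopeHeadDim; // 192
    int const dimKPerHead = qkNopeHeadDim + qkRopeHeadDim; // 192
    int const dimVPerHead = vHeadDim;                      // 128

    std::size_t const perTokenElems
        = static_cast<std::size_t>(numAttnHeads) * (dimQPerHead + dimKPerHead + dimVPerHead); // 65536
    std::size_t const totalElems = numTokens * perTokenElems; // num_total_qkv_elements
    std::printf("fp8_qkv_buffer: %zu elements (~%zu MiB as E4M3)\n", totalElems, totalElems >> 20);
    return 0;
}
// ---------------------------------------------------------------------------------------------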
@@ -386,6 +388,7 @@ class AttentionOp bool mPosShiftEnabled = false; bool mPagedContextFMHA = false; bool mFP8ContextFMHA = false; + bool mFP8ContextMLA = false; bool mFP8GenerationMLA = false; bool mDenseContextFMHA = false; bool mHasFullAttentionMask = false; diff --git a/cpp/tensorrt_llm/common/envUtils.cpp b/cpp/tensorrt_llm/common/envUtils.cpp index f7480229410..e321a4b07b3 100644 --- a/cpp/tensorrt_llm/common/envUtils.cpp +++ b/cpp/tensorrt_llm/common/envUtils.cpp @@ -386,7 +386,7 @@ size_t getEnvAllReduceWorkspaceSize() return workspaceSize; } -std::string getEnvKVCacheTransferOutputPath() +std::string const& getEnvKVCacheTransferOutputPath() { static std::string outputPath = getStrEnv("TRTLLM_KVCACHE_TIME_OUTPUT_PATH").value_or(""); return outputPath; diff --git a/cpp/tensorrt_llm/common/envUtils.h b/cpp/tensorrt_llm/common/envUtils.h index 5e29dfaca71..b4921af40e9 100644 --- a/cpp/tensorrt_llm/common/envUtils.h +++ b/cpp/tensorrt_llm/common/envUtils.h @@ -76,7 +76,7 @@ bool getEnvDisableKVCacheTransferOverlap(); bool getEnvEnableReceiveKVCacheParallel(); -std::string getEnvKVCacheTransferOutputPath(); +std::string const& getEnvKVCacheTransferOutputPath(); bool getEnvTryZCopyForKVCacheTransfer(); diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp index c10df82d54c..53dc9e053ad 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/detail/collective/mixed_input_utils.hpp @@ -27,6 +27,78 @@ namespace cutlass::gemm::collective::detail { +using namespace cute; + +typedef uint32_t __nv_fp4x8_storage_t; +typedef uint32_t __nv_bf16x2_storage_t; +typedef cutlass::uint128_t __nv_bf16x8_storage_t; + +constexpr int int4_group_size = 128; +constexpr int mxfp4_group_size = 32; + +inline __device__ unsigned prmt(unsigned hi, unsigned lo, unsigned select_code) +{ + unsigned res = 0; + + asm volatile( + "{\n" + "prmt.b32 %0, %1, %2, %3;\n" + "}\n" + : "=r"(res) + : "r"(lo), "r"(hi), "r"(select_code)); + + return res; +} + +__device__ __inline__ __nv_fp8x4_storage_t cvt_lut_bf16(unsigned const index) +{ + const __nv_fp8x4_storage_t h4b_lut = 0x03020100U; // 7654 + const __nv_fp8x4_storage_t l4b_lut = 0xFFFEFC00U; // 3210 + + __nv_fp8x4_storage_t lut_res = prmt(h4b_lut, l4b_lut, index); + + return lut_res; +} + +__device__ __inline__ __nv_bf16x8_storage_t psx_cvt_lut_prmt_fp4x8_to_bf16x8(const __nv_fp4x8_storage_t fp4x8) +{ + __nv_bf16x8_storage_t bf16x8_raw = {0, 0}; + __nv_bf16x2_storage_t* bf16x2_raw = reinterpret_cast<__nv_bf16x2_storage_t*>(&bf16x8_raw); + + unsigned zero_padding = 0x00000000U; + + unsigned h4b_em_fp4x4 = (fp4x8 & 0x77770000U) >> 16U; + unsigned l4b_em_fp4x4 = (fp4x8 & 0x00007777U); + + __nv_fp8x4_storage_t h4b_2to9_bits = cvt_lut_bf16(h4b_em_fp4x4); // 7654 + __nv_fp8x4_storage_t l4b_2to9_bits = cvt_lut_bf16(l4b_em_fp4x4); // 3210 + + bf16x2_raw[0] = prmt(zero_padding, l4b_2to9_bits, 0x1707U) >> 2U; // 1 0 + bf16x2_raw[1] = prmt(zero_padding, l4b_2to9_bits, 0x3727U) >> 2U; // 3 2 + bf16x2_raw[2] = prmt(h4b_2to9_bits, zero_padding, 0x5040U) >> 2U; // 5 4 + bf16x2_raw[3] = prmt(h4b_2to9_bits, zero_padding, 0x7060U) >> 2U; // 7 6 + + __nv_bf16x2_storage_t bf16x2_0to1_bits; + + __nv_fp8x4_storage_t h_fp8x2_0to1_bits = (fp4x8 & 0x0000C0C0U); // 3 1 + __nv_fp8x4_storage_t 
l_fp8x2_0to1_bits = (fp4x8 & 0x00000C0CU) << 4U; // 2 0 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x4707U); // 1 0 + bf16x2_raw[0] = bf16x2_raw[0] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x5717U); // 3 2 + bf16x2_raw[1] = bf16x2_raw[1] | bf16x2_0to1_bits; + + h_fp8x2_0to1_bits = (fp4x8 & 0xC0C00000U); // 7 5 + l_fp8x2_0to1_bits = (fp4x8 & 0x0C0C0000U) << 4U; // 6 4 + + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x6020U); // 5 4 + bf16x2_raw[2] = bf16x2_raw[2] | bf16x2_0to1_bits; + bf16x2_0to1_bits = prmt(h_fp8x2_0to1_bits, l_fp8x2_0to1_bits, 0x7030U); // 7 6 + bf16x2_raw[3] = bf16x2_raw[3] | bf16x2_0to1_bits; + + return bf16x8_raw; +} + template struct MixedGroupedGemmInputUtils { @@ -46,6 +118,7 @@ struct MixedGroupedGemmInputUtils static constexpr auto KernelConversionMode = Collective::KernelConversionMode; static constexpr auto ModeHasScales = Collective::ModeHasScales; static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + static constexpr auto UseFP4ToBF16LookupTable = Collective::UseFP4ToBF16LookupTable; public: static constexpr auto elements_per_smem_scale() @@ -239,6 +312,27 @@ struct MixedGroupedGemmInputUtils } } + // The core converter uses a lookup table to converts i4 -> 8 bit value. + template + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( // Accept mutable temporaries + Tensor const& src, Tensor&& dst) + { + fp4tobf16_lookup_table_convert(src, dst); + } + + template + CUTLASS_DEVICE static void fp4tobf16_lookup_table_convert( + Tensor const& src, Tensor& dst) + { + + // View the input as reg + auto&& src_ = cute::recast<__nv_fp4x8_storage_t>(src)(0); + auto&& dst_ = cute::recast<__nv_bf16x8_storage_t>(dst)(0); + + dst_ = psx_cvt_lut_prmt_fp4x8_to_bf16x8(src_); + } + /// Utilities to dequantize A. template CUTLASS_DEVICE static void static_check_scale(Layout const& tensor) @@ -253,7 +347,6 @@ struct MixedGroupedGemmInputUtils static_check_scale(flatten(Layout{})); } - // dequantize_A_kblock is here!!! 
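// ---------------------------------------------------------------------------------------------
// Scalar reference (illustrative only) for what the prmt-based fp4x8 -> bf16x8 conversion above
// computes: each packed 4-bit code maps to one of 16 values. This assumes the standard E2M1
// value set {0, 0.5, 1, 1.5, 2, 3, 4, 6} plus negatives; the device path performs the same
// mapping for eight nibbles at once with byte-permute lookups instead of a memory table.
#include <array>
#include <cstdint>
#include <cstdio>

int main()
{
    constexpr std::array<float, 16> kE2M1 = {
        0.0f, 0.5f, 1.0f, 1.5f, 2.0f, 3.0f, 4.0f, 6.0f,          // sign bit clear
        -0.0f, -0.5f, -1.0f, -1.5f, -2.0f, -3.0f, -4.0f, -6.0f}; // sign bit set

    std::uint32_t const fp4x8 = 0x76543210u; // eight packed E2M1 codes, one per nibble
    for (int i = 0; i < 8; ++i)
    {
        std::uint32_t const code = (fp4x8 >> (4 * i)) & 0xFu;
        std::printf("nibble %d: code %u -> %g\n", i, static_cast<unsigned>(code), kE2M1[code]);
    }
    return 0;
}
// ---------------------------------------------------------------------------------------------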
template CUTLASS_DEVICE static void dequantize_A_kblock(Tensor const& tCrA_load, Tensor& tCrA_mma, cute::tuple& partitioned_extra_info, int const k_block) @@ -288,8 +381,6 @@ struct MixedGroupedGemmInputUtils } else if constexpr (UseScaleLookupTable) { - // this path - constexpr int num_elements = decltype(size(src))::value; static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); @@ -424,7 +515,6 @@ struct MixedGroupedGemmInputUtils static_assert(size_v == cosize_v); static_assert(size_v == cosize_v); using SrcType = typename EngineIn::value_type; - using DstType = typename EngineOut::value_type; Tensor src = tCrA_load(_, _, k_block); Tensor dst = tCrA_mma(_, _, k_block); @@ -441,7 +531,14 @@ struct MixedGroupedGemmInputUtils CUTLASS_PRAGMA_UNROLL for (int i = 0; i < size<1>(dst_vm); ++i) { - LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + if constexpr (UseFP4ToBF16LookupTable) + { + fp4tobf16_lookup_table_convert(src_vm(_, i), dst_vm(_, i)); + } + else + { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } } } diff --git a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp index 1ee109fd648..2332950629f 100644 --- a/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp +++ b/cpp/tensorrt_llm/cutlass_extensions/include/cutlass_extensions/gemm/collective/sm90_mma_array_tma_gmma_rs_warpspecialized_mixed_input_.hpp @@ -30,37 +30,12 @@ #include "cute/atom/mma_atom.hpp" #include "cute/numeric/arithmetic_tuple.hpp" -#define GROUP_SIZE 128 - ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass::gemm::collective { using namespace cute; -template -CUTE_HOST_DEVICE void warpgroup_wait_() -{ -#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) - cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N); - asm volatile("wgmma.wait_group.sync.aligned %0;\n" ::"n"(N) : "memory"); -#else - CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED"); -#endif -} - -CUTLASS_DEVICE void warpgroup_wait_dispatch(int onthefly_count) -{ - switch (onthefly_count) - { - case 0: warpgroup_wait_<0>(); break; - case 4: warpgroup_wait_<4>(); break; - case 8: warpgroup_wait_<8>(); break; - case 12: warpgroup_wait_<12>(); break; - default: assert(false && "Invalid onthefly_count value"); - } -} - ///////////////////////////////////////////////////////////////////////////////////////////////// // WarpSpecialized Mainloop @@ -91,7 +66,7 @@ struct CollectiveMmaArrayMixedInput< private: template friend struct detail::MixedGroupedGemmInputUtils; - using CollectiveType = CollectiveMma; using Utils = detail::MixedGroupedGemmInputUtils; @@ -146,6 +121,11 @@ struct CollectiveMmaArrayMixedInput< static_assert(cutlass::gemm::detail::is_mn_major(), "Scale must be MN major [Col Major if A is scaled, Row Major if B is scaled]."); + static constexpr bool IsMXFP4 = cute::is_same_v; + // Group size 128 for int4 weights + // Group size 32 for mxfp4 weights + static constexpr int ScalingGroupSize = IsMXFP4 ? 
detail::mxfp4_group_size : detail::int4_group_size; + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); using TiledMma = TiledMma_; using ElementAccumulator = typename TiledMma::ValTypeC; @@ -268,6 +248,8 @@ struct CollectiveMmaArrayMixedInput< || KernelConversionMode == ConversionMode::ConvertAndScaleWithZero; static constexpr bool UseScaleLookupTable = KernelConversionMode == ConversionMode::ConvertAndScale && cutlass::detail::is_Array_v; + static constexpr bool UseFP4ToBF16LookupTable = KernelConversionMode == ConversionMode::ConvertAndScale + && cute::is_same_v && cute::is_same_v; static constexpr size_t SmemAlignmentA = cutlass::detail::alignment_for_swizzle(SmemLayoutA{}); static constexpr size_t SmemAlignmentB = cutlass::detail::alignment_for_swizzle(SmemLayoutB{}); static constexpr size_t SmemAlignmentScale = cute::max(SmemAlignmentA, SmemAlignmentB); @@ -705,7 +687,7 @@ struct CollectiveMmaArrayMixedInput< { // The real scale_k that actually works // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor mS_mkl = mainloop_params.tma_load_scale.get_tma_tensor(make_shape(M, scale_k, L)); // (m,scale_k,l) Tensor gS_mkl = local_tile(mS_mkl, ScaleTileShape{}, make_coord(_, _)); // (BLK_M,BLK_Scale_K,m,scale_k,l) @@ -872,7 +854,6 @@ struct CollectiveMmaArrayMixedInput< } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { - // zero copy auto tZgZ = get<2>(extra_input_partitions); auto tZsZ = get<3>(extra_input_partitions); if (cute::elect_one_sync()) @@ -979,7 +960,8 @@ struct CollectiveMmaArrayMixedInput< return make_tensor_like(tCsA(_, _, _, Int<0>{})); } }(); - Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCsB = mma_warpgroup_slice.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + // tCrB is just a view of the tensor tCsB Tensor tCrB = mma_warpgroup_slice.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) // @@ -1013,8 +995,8 @@ struct CollectiveMmaArrayMixedInput< multiply_add fma; - constexpr int NumMMAsPerChunk = GROUP_SIZE / cute::get<0, 1>(tCsB.shape())(); - constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / GROUP_SIZE; + constexpr int NumMMAsPerChunk = ScalingGroupSize / cute::get<0, 1>(tCsB.shape())(); + constexpr int NumChunksPerTileK = cute::size<1>(sA.shape())() / ScalingGroupSize; cute::array intermediate_array; constexpr int K_BLOCK_MAX = size<2>(tCrA_load); @@ -1045,8 +1027,6 @@ struct CollectiveMmaArrayMixedInput< // src: tCrA_load, dst: tCrA_mma Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); - tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; - // Unroll the K mode manually to set scale D to 1 CUTLASS_PRAGMA_UNROLL for (int chunk_id = 0; chunk_id < NumChunksPerTileK; ++chunk_id) @@ -1079,10 +1059,11 @@ struct CollectiveMmaArrayMixedInput< } } + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1129,7 +1110,6 @@ struct CollectiveMmaArrayMixedInput< Utils::copy_tensors_MK(smem_tiled_copy_A, tCsA, tCrA_copy_view, partitioned_extra_info, copy_partitions_extra_info, 1, smem_pipe_read.index()); - warpgroup_wait(); Utils::convert_A_kblock(tCrA_load, tCrA_mma, 0); } } @@ -1169,8 +1149,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; 
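// ---------------------------------------------------------------------------------------------
// Quick illustration (K is an assumed example extent) of what the ScalingGroupSize switch above
// changes: the number of group-wise scales per row that the mainloop loads and applies becomes
// K / 128 for int4 weights but K / 32 for mxfp4 weights.
#include <cstdio>

int main()
{
    int const K = 4096;            // assumed GEMM K extent
    int const int4GroupSize = 128; // detail::int4_group_size
    int const mxfp4GroupSize = 32; // detail::mxfp4_group_size
    std::printf("int4:  scale_k = %d\n", K / int4GroupSize);  // 32 scale groups per row
    std::printf("mxfp4: scale_k = %d\n", K / mxfp4GroupSize); // 128 scale groups per row
    return 0;
}
// ---------------------------------------------------------------------------------------------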
warpgroup_commit_batch(); - warpgroup_wait(); // We have K_BLOCK_MAX - 1 GMMA instructions pending for this stage, - // so we can release prior barrier if (k_block == K_BLOCK_MAX - 1) { pipeline.consumer_release( @@ -1187,10 +1165,11 @@ struct CollectiveMmaArrayMixedInput< { // The last k_block + warpgroup_wait<0>(); + CUTLASS_PRAGMA_UNROLL for (int chunk_id_ = 0; chunk_id_ < NumChunksPerTileK; ++chunk_id_) { - warpgroup_wait_dispatch((NumChunksPerTileK - chunk_id_ - 1) * NumMMAsPerChunk); warpgroup_fence_operand(intermediate_array[chunk_id_]); // Apply the group-wise scaling @@ -1257,7 +1236,6 @@ struct CollectiveMmaArrayMixedInput< tiled_mma.accumulate_ = GMMA::ScaleOut::One; warpgroup_commit_batch(); - warpgroup_wait(); if (k_block == K_BLOCK_MAX - 1) { // release prior barrier @@ -1318,7 +1296,7 @@ struct CollectiveMmaArrayMixedInput< smem_pipe_release.advance(k_tile_count); // Wait on all GMMAs to complete - warpgroup_wait<0>(); + // warpgroup_wait<0>(); for (int count = 0; count < prologue_mma_count; ++count) { @@ -1462,7 +1440,7 @@ struct CollectiveMmaArrayMixedInput< { NonVoidElementScale const* ptr_S = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_scale = make_tensor( detail::get_logical_ptr(ptr_S), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( @@ -1472,7 +1450,7 @@ struct CollectiveMmaArrayMixedInput< { ElementZero const* ptr_Z = nullptr; // auto scale_k = K / mainloop_params.chunk_size; - auto scale_k = K / GROUP_SIZE; + auto scale_k = K / ScalingGroupSize; Tensor tensor_zero = make_tensor( detail::get_logical_ptr(ptr_Z), make_shape(M, scale_k, Int<1>{}), mainloop_params.dS[next_group]); cute::detail::fill_tma_gmem_shape_stride( diff --git a/cpp/tensorrt_llm/executor/executor.cpp b/cpp/tensorrt_llm/executor/executor.cpp index 70ca2be41ab..091bb512823 100644 --- a/cpp/tensorrt_llm/executor/executor.cpp +++ b/cpp/tensorrt_llm/executor/executor.cpp @@ -132,10 +132,12 @@ std::optional> Executor::getKVCacheEventMan return mImpl->getKVCacheEventManager(); } -KVCacheEvent::KVCacheEvent(size_t eventId, KVCacheEventData data, SizeType32 windowSize) +KVCacheEvent::KVCacheEvent( + size_t eventId, KVCacheEventData data, SizeType32 windowSize, std::optional attentionDpRank) : eventId{eventId} , data{std::move(data)} , windowSize{windowSize} + , attentionDpRank{attentionDpRank} { } diff --git a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp index 51b047ebd27..21cf314c875 100644 --- a/cpp/tensorrt_llm/executor/kvCacheConfig.cpp +++ b/cpp/tensorrt_llm/executor/kvCacheConfig.cpp @@ -27,6 +27,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co std::optional const& hostCacheSize, bool onboardBlocks, std::optional const& crossKvCacheFraction, std::optional secondaryOffloadMinPriority, size_t eventBufferMaxSize, bool enablePartialReuse, bool copyOnPartialReuse, bool useUvm, + SizeType32 attentionDpEventsGatherPeriodMs, std::optional const& runtimeDefaults) : mEnableBlockReuse(enableBlockReuse) , mHostCacheSize(hostCacheSize) @@ -36,6 +37,7 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, std::optional co , mEnablePartialReuse{enablePartialReuse} , mCopyOnPartialReuse{copyOnPartialReuse} , mUseUvm{useUvm} + , mAttentionDpEventsGatherPeriodMs(attentionDpEventsGatherPeriodMs) { if (maxTokens) { @@ -61,6 +63,8 @@ KvCacheConfig::KvCacheConfig(bool enableBlockReuse, 
std::optional co { fillEmptyFieldsFromRuntimeDefaults(runtimeDefaults.value()); } + TLLM_CHECK_WITH_INFO( + mAttentionDpEventsGatherPeriodMs > 0, "Attention DP events gather period must be greater than 0"); } bool KvCacheConfig::getEnableBlockReuse() const @@ -128,6 +132,11 @@ bool KvCacheConfig::getUseUvm() const return mUseUvm; } +SizeType32 KvCacheConfig::getAttentionDpEventsGatherPeriodMs() const +{ + return mAttentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::setEnableBlockReuse(bool enableBlockReuse) { mEnableBlockReuse = enableBlockReuse; @@ -204,6 +213,12 @@ void KvCacheConfig::setUseUvm(bool useUvm) mUseUvm = useUvm; } +void KvCacheConfig::setAttentionDpEventsGatherPeriodMs(SizeType32 attentionDpEventsGatherPeriodMs) +{ + TLLM_CHECK(attentionDpEventsGatherPeriodMs > 0); + mAttentionDpEventsGatherPeriodMs = attentionDpEventsGatherPeriodMs; +} + void KvCacheConfig::fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults) { if (!mMaxAttentionWindowVec && runtimeDefaults.maxAttentionWindowVec) diff --git a/cpp/tensorrt_llm/executor/loraConfig.cpp b/cpp/tensorrt_llm/executor/loraConfig.cpp index 058b1a86710..c8499f36d4d 100644 --- a/cpp/tensorrt_llm/executor/loraConfig.cpp +++ b/cpp/tensorrt_llm/executor/loraConfig.cpp @@ -27,26 +27,29 @@ LoraConfig::LoraConfig(IdType taskId, std::optional weights, std::option , mWeights(std::move(weights)) , mConfig(std::move(config)) { - if (mWeights.has_value() || mConfig.has_value()) + if (mConfig.has_value()) { - TLLM_CHECK_WITH_INFO(mWeights.has_value() && mConfig.has_value(), - "Request for LoRA inference must have both lora weights and lora config"); - - SizeType32 constexpr expectedWeightsDims = 2; SizeType32 constexpr expectedConfigDims = 2; - - TLLM_CHECK_WITH_INFO( - mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); TLLM_CHECK_WITH_INFO( mConfig.value().getShape().size() == expectedConfigDims, "Expected config tensor to have 2 dimensions"); - TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU - && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getMemoryType() != MemoryType::kGPU && mConfig.value().getMemoryType() != MemoryType::kUNKNOWN, - "Expected lora weights to be in CPU memory"); + "Expected lora config to be in CPU memory"); TLLM_CHECK_WITH_INFO( mConfig.value().getDataType() == DataType::kINT32, "Expected lora config tensor to have type kINT32"); + } + if (mWeights.has_value()) + { + SizeType32 constexpr expectedWeightsDims = 2; + TLLM_CHECK_WITH_INFO( + mConfig.has_value(), "Request for LoRA inference with lora weights must also have lora config"); + + TLLM_CHECK_WITH_INFO( + mWeights.value().getShape().size() == expectedWeightsDims, "Expected weights tensor to have 2 dimensions"); + + TLLM_CHECK_WITH_INFO(mWeights.value().getMemoryType() != MemoryType::kGPU + && mWeights.value().getMemoryType() != MemoryType::kUNKNOWN, + "Expected lora weights to be in CPU memory"); TLLM_CHECK_WITH_INFO(mConfig.value().getShape()[0] == mWeights.value().getShape()[0], "Expected dim 0 of lora weights and lora config to have the same size"); diff --git a/cpp/tensorrt_llm/executor/serialization.cpp b/cpp/tensorrt_llm/executor/serialization.cpp index 65718f0405d..38256edbc75 100644 --- a/cpp/tensorrt_llm/executor/serialization.cpp +++ b/cpp/tensorrt_llm/executor/serialization.cpp @@ -23,6 +23,7 @@ #include 
"tensorrt_llm/executor/serializeUtils.h" #include "tensorrt_llm/executor/types.h" #include "tensorrt_llm/runtime/cudaStream.h" +#include #include #include #include @@ -1162,10 +1163,11 @@ KvCacheConfig Serialization::deserializeKvCacheConfig(std::istream& is) auto secondaryOffloadMinPriority = su::deserialize>(is); auto eventBufferMaxSize = su::deserialize(is); auto useUvm = su::deserialize(is); + auto attentionDpEventsGatherPeriodMs = su::deserialize(is); return KvCacheConfig{enableBlockReuse, maxTokens, maxAttentionWindowVec, sinkTokenLength, freeGpuMemoryFraction, hostCacheSize, onboardBlocks, crossKvCacheFraction, secondaryOffloadMinPriority, eventBufferMaxSize, - enablePartialReuse, copyOnPartialReuse, useUvm}; + enablePartialReuse, copyOnPartialReuse, useUvm, attentionDpEventsGatherPeriodMs}; } void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& os) @@ -1183,6 +1185,7 @@ void Serialization::serialize(KvCacheConfig const& kvCacheConfig, std::ostream& su::serialize(kvCacheConfig.getSecondaryOffloadMinPriority(), os); su::serialize(kvCacheConfig.getEventBufferMaxSize(), os); su::serialize(kvCacheConfig.getUseUvm(), os); + su::serialize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs(), os); } size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) @@ -1202,6 +1205,7 @@ size_t Serialization::serializedSize(KvCacheConfig const& kvCacheConfig) totalSize += su::serializedSize(kvCacheConfig.getSecondaryOffloadMinPriority()); totalSize += su::serializedSize(kvCacheConfig.getEventBufferMaxSize()); totalSize += su::serializedSize(kvCacheConfig.getUseUvm()); + totalSize += su::serializedSize(kvCacheConfig.getAttentionDpEventsGatherPeriodMs()); return totalSize; } @@ -2181,6 +2185,237 @@ std::vector Serialization::deserializeRequestStatsPerI return iterRequestStatsVec; } +// KVCacheEvents deque +std::vector Serialization::serialize(std::deque const& eventQueue) +{ + // Compute the size of serialized buffer + size_t totalSize = 0; + totalSize += sizeof(size_t); + for (auto const& event : eventQueue) + { + totalSize += su::serializedSize(event); + } + + std::vector buffer(totalSize); + std::stringbuf strbuf(std::ios_base::out | std::ios_base::in); + strbuf.pubsetbuf(buffer.data(), buffer.size()); + std::ostream os(&strbuf); + + su::serialize(eventQueue.size(), os); + for (auto const& event : eventQueue) + { + su::serialize(event, os); + } + return buffer; +} + +std::deque Serialization::deserializeKVCacheEvents(std::vector& buffer) +{ + std::deque kvCacheEvents; + su::VectorWrapBuf strbuf(buffer); + std::istream is(&strbuf); + auto numEvents = su::deserialize(is); + for (std::size_t event = 0; event < numEvents; ++event) + { + kvCacheEvents.emplace_back(Serialization::deserializeKVCacheEvent(is)); + } + return kvCacheEvents; +} + +// KVCacheEvent +size_t Serialization::serializedSize(KVCacheEvent const& event) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(event.eventId); + totalSize += su::serializedSize(event.data); + totalSize += su::serializedSize(event.windowSize); + totalSize += su::serializedSize(event.attentionDpRank); + return totalSize; +} + +void Serialization::serialize(KVCacheEvent const& event, std::ostream& os) +{ + su::serialize(event.eventId, os); + su::serialize(event.data, os); + su::serialize(event.windowSize, os); + su::serialize(event.attentionDpRank, os); +} + +KVCacheEvent Serialization::deserializeKVCacheEvent(std::istream& is) +{ + auto eventId = su::deserialize(is); + auto data = su::deserialize(is); + auto 
windowSize = su::deserialize(is); + auto attentionDpRank = su::deserialize>(is); + + return KVCacheEvent{eventId, data, windowSize, attentionDpRank}; +} + +// KVCacheCreatedData +size_t Serialization::serializedSize(KVCacheCreatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.numBlocksPerCacheLevel); + return totalSize; +} + +void Serialization::serialize(KVCacheCreatedData const& data, std::ostream& os) +{ + su::serialize(data.numBlocksPerCacheLevel, os); +} + +KVCacheCreatedData Serialization::deserializeKVCacheCreatedData(std::istream& is) +{ + auto numBlocksPerCacheLevel = su::deserialize>(is); + return KVCacheCreatedData{numBlocksPerCacheLevel}; +} + +// KVCacheStoredData +size_t Serialization::serializedSize(KVCacheStoredData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.parentHash); + totalSize += su::serializedSize(data.blocks); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredData const& data, std::ostream& os) +{ + su::serialize(data.parentHash, os); + su::serialize(data.blocks, os); +} + +KVCacheStoredData Serialization::deserializeKVCacheStoredData(std::istream& is) +{ + auto parentHash = su::deserialize>(is); + auto blocks = su::deserialize>(is); + return KVCacheStoredData{parentHash, blocks}; +} + +// KVCacheStoredBlockData +size_t Serialization::serializedSize(KVCacheStoredBlockData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.tokens); + totalSize += su::serializedSize(data.loraId); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheStoredBlockData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.tokens, os); + su::serialize(data.loraId, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheStoredBlockData Serialization::deserializeKVCacheStoredBlockData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto tokens = su::deserialize(is); + auto loraId = su::deserialize>(is); + auto cacheLevel = su::deserialize(is); + auto priority = su::deserialize(is); + + return KVCacheStoredBlockData{blockHash, tokens, loraId, cacheLevel, priority}; +} + +// KVcacheRemovedData + +size_t Serialization::serializedSize(KVCacheRemovedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHashes); + return totalSize; +} + +void Serialization::serialize(KVCacheRemovedData const& data, std::ostream& os) +{ + su::serialize(data.blockHashes, os); +} + +KVCacheRemovedData Serialization::deserializeKVCacheRemovedData(std::istream& is) +{ + auto blockHashes = su::deserialize>(is); + return KVCacheRemovedData{blockHashes}; +} + +// KVCacheEventDiff +template +size_t Serialization::serializedSize(KVCacheEventDiff const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.oldValue); + totalSize += su::serializedSize(data.newValue); + return totalSize; +} + +template +void Serialization::serialize(KVCacheEventDiff const& data, std::ostream& os) +{ + su::serialize(data.oldValue, os); + su::serialize(data.newValue, os); +} + +template +KVCacheEventDiff Serialization::deserializeKVCacheEventDiff(std::istream& is) +{ + auto oldValue = su::deserialize(is); + auto newValue = su::deserialize(is); + return KVCacheEventDiff{oldValue, newValue}; +} + +// KVCacheUpdatedData +size_t 
Serialization::serializedSize(KVCacheUpdatedData const& data) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(data.blockHash); + totalSize += su::serializedSize(data.cacheLevel); + totalSize += su::serializedSize(data.priority); + return totalSize; +} + +void Serialization::serialize(KVCacheUpdatedData const& data, std::ostream& os) +{ + su::serialize(data.blockHash, os); + su::serialize(data.cacheLevel, os); + su::serialize(data.priority, os); +} + +KVCacheUpdatedData Serialization::deserializeKVCacheUpdatedData(std::istream& is) +{ + auto blockHash = su::deserialize(is); + auto cacheLevel = su::deserialize>>(is); + auto priority = su::deserialize>>(is); + return KVCacheUpdatedData{blockHash, cacheLevel, priority}; +} + +// UniqueToken +size_t Serialization::serializedSize(tensorrt_llm::runtime::UniqueToken const& token) +{ + size_t totalSize = 0; + totalSize += su::serializedSize(token.tokenId); + totalSize += su::serializedSize(token.tokenExtraId); + return totalSize; +} + +void Serialization::serialize(tensorrt_llm::runtime::UniqueToken const& token, std::ostream& os) +{ + su::serialize(token.tokenId, os); + su::serialize(token.tokenExtraId, os); +} + +tensorrt_llm::runtime::UniqueToken Serialization::deserializeUniqueToken(std::istream& is) +{ + auto tokenId = su::deserialize(is); + auto tokenExtraId = su::deserialize(is); + return tensorrt_llm::runtime::UniqueToken{tokenId, tokenExtraId}; +} + // String std::string Serialization::deserializeString(std::istream& is) { diff --git a/cpp/tensorrt_llm/executor/serializeUtils.h b/cpp/tensorrt_llm/executor/serializeUtils.h index 8f26c58d622..40b50f92309 100644 --- a/cpp/tensorrt_llm/executor/serializeUtils.h +++ b/cpp/tensorrt_llm/executor/serializeUtils.h @@ -122,6 +122,14 @@ static_assert(hasSerializedSize(size_t())); static_assert(!hasSerializedSize(size_t())); static_assert(!hasSerializedSize>(size_t())); static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize>(size_t())); +static_assert(hasSerializedSize(size_t())); +static_assert(hasSerializedSize(size_t())); template size_t serializedSize(T const& data) @@ -219,6 +227,14 @@ static_assert(hasSerialize(nullptr)); static_assert(!hasSerialize(nullptr)); static_assert(!hasSerialize>(nullptr)); static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize>(nullptr)); +static_assert(hasSerialize(nullptr)); +static_assert(hasSerialize(nullptr)); template void serialize(T const& data, std::ostream& os) @@ -291,6 +307,22 @@ struct get_variant_alternative_type } }; +template +T deserialize(std::istream& is); + +// Helper function to deserialize variant by index using template recursion +template +T deserializeVariantByIndex(std::istream& is, std::size_t index, std::index_sequence /*indices*/) +{ + T result; + bool found = ((Is == index ? 
(result = deserialize>(is), true) : false) || ...); + if (!found) + { + TLLM_THROW("Invalid variant index during deserialization: " + std::to_string(index)); + } + return result; +} + // Deserialize template T deserialize(std::istream& is) @@ -511,6 +543,38 @@ T deserialize(std::istream& is) { return Serialization::deserializeCacheTransceiverConfig(is); } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheEvent(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheCreatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheStoredBlockData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheRemovedData(is); + } + else if constexpr (std::is_same_v>) + { + return Serialization::deserializeKVCacheEventDiff(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeKVCacheUpdatedData(is); + } + else if constexpr (std::is_same_v) + { + return Serialization::deserializeUniqueToken(is); + } // Optional else if constexpr (std::is_same_v::type>>) { @@ -547,23 +611,7 @@ T deserialize(std::istream& is) std::size_t index = 0; is.read(reinterpret_cast(&index), sizeof(index)); - // TODO: Is there a better way to implement this? - T data; - if (index == 0) - { - using U = std::variant_alternative_t<0, T>; - data = deserialize(is); - } - else if (index == 1) - { - using U = std::variant_alternative_t<1, T>; - data = deserialize(is); - } - else - { - TLLM_THROW("Serialization of variant of size > 2 is not supported."); - } - return data; + return deserializeVariantByIndex(is, index, std::make_index_sequence>{}); } else { diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu index 27d041618e7..84710a96365 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.cu @@ -256,9 +256,9 @@ public: constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec; PackedVec pack_val = *reinterpret_cast(&val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt, token_id, - m_access_id_in_token, std::nullopt, m_params.hidden_dim, - reinterpret_cast(m_params.scale_out), m_params.layout); + auto sf_out = cvt_quant_get_sf_out_offset(std::nullopt, token_id, m_access_id_in_token, + std::nullopt, m_params.hidden_dim / SF_VEC_SIZE, reinterpret_cast(m_params.scale_out), + m_params.layout); reinterpret_cast(m_params.quant_out)[m_access_id] = cvt_warp_fp16_to_fp4(pack_val, m_scale_factor, sf_out); } diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h index dbf45ebe1cc..52487b25d4e 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/allReduceFusionKernels.h @@ -132,7 +132,7 @@ struct AllReduceFusionParams float rms_eps; float* scale_factor; bool use_oneshot; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; AllReduceFusionPattern pattern; bool trigger_completion_at_end = true; diff --git 
a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu index 2176ba759f4..c38abd95785 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.cu @@ -99,15 +99,15 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint32_t* offset_access_ptr; uint32_t* buffer_flags; - __device__ explicit LamportFlags(uint32_t* buffer_flags) + __device__ explicit LamportFlags(uint32_t* buffer_flags, uint32_t buffer_size) : offset_access_ptr(&buffer_flags[4]) , buffer_flags(buffer_flags) + , buffer_size(buffer_size) { uint4 flag = reinterpret_cast(buffer_flags)[0]; - buffer_size = flag.z; input_offset = flag.x * (buffer_size << 1U); clear_offset = flag.y * (buffer_size << 1U); - num_tokens_prev = flag.w; + num_tokens_prev = flag.z; } __device__ void cta_arrive() @@ -135,7 +135,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags uint4 flag = reinterpret_cast(buffer_flags)[0]; buffer_flags[0] = (flag.x + 1) % 3; buffer_flags[1] = (flag.y + 1) % 3; - buffer_flags[3] = num_tokens; + buffer_flags[2] = num_tokens; *(offset_access_ptr) = 0; } } @@ -144,7 +144,7 @@ __device__ struct __attribute__((aligned(32))) LamportFlags template __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ptrs, T* mcast_ptr, int num_tokens, - int buffer_M, int token_dim, int rank, uint32_t* buffer_flags, bool wait_for_results) + int buffer_M, int token_dim, int rank, uint32_t buffer_size, uint32_t* buffer_flags, bool wait_for_results) { int elt = blockIdx.y * blockDim.x + threadIdx.x; if (elt >= token_dim) @@ -155,7 +155,7 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ cudaGridDependencySynchronize(); #endif - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); // Capture the number of tokens in previous iteration so that we can properly clear the buffer // The scatter stage will use the buffer in WORLD_SIZE granularity, thus we need to round up @@ -217,15 +217,17 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) cudaTriggerProgrammaticLaunchCompletion(); #endif - - // Similarly clear broadcast buffer here - for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) + if (elt < token_dim) { - uint32_t clr_token_idx = token + clr_tok * gridDim.x; - if (clr_token_idx < buffer_M) + // Similarly clear broadcast buffer here + for (int clr_tok = 0; clr_tok < clr_toks_cta; clr_tok++) { - input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] - = fromFloat(-0.f); + uint32_t clr_token_idx = token + clr_tok * gridDim.x; + if (clr_token_idx < buffer_M) + { + input_ptrs[rank][flags.clear_offset + buffer_M * token_dim + clr_token_idx * token_dim + elt] + = fromFloat(-0.f); + } } } @@ -240,20 +242,24 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ // blockDim.x / ELTS_PER_LOAD should be at least the size of a warp (32) if (threadIdx.x < (blockDim.x / ELTS_PER_LOAD)) { - uint64_t current_pos = blockIdx.x * token_dim + blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; - - void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; - // We have 2 assumptions here: - // 1. 
The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B - // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) - float2 val = loadfloat2(lamport_ptr); - while (isNegZero(*(T*) &val)) - { - val = loadfloat2(lamport_ptr); - } - if (output_ptr) + uint64_t elt_load_offset = blockIdx.y * blockDim.x + threadIdx.x * ELTS_PER_LOAD; + if (elt_load_offset < token_dim) { - *((float2*) &output_ptr[current_pos]) = val; + uint64_t current_pos = blockIdx.x * token_dim + elt_load_offset; + + void* lamport_ptr = (void*) &input_ptrs[rank][flags.input_offset + buffer_M * token_dim + current_pos]; + // We have 2 assumptions here: + // 1. The write is atomic in 8B granularity -> Each buffer in the buffer group should be aligned to 8B + // 2. The num_token * token_dim is divisible by ELTS_PER_LOAD (4 for BF16 and 2 for FP32) + float2 val = loadfloat2(lamport_ptr); + while (isNegZero(*(T*) &val)) + { + val = loadfloat2(lamport_ptr); + } + if (output_ptr) + { + *((float2*) &output_ptr[current_pos]) = val; + } } } @@ -263,10 +269,11 @@ __global__ void twoshot_allreduce_kernel(T* output_ptr, T* shard_ptr, T** input_ } #define LAUNCH_ALL_REDUCE_KERNEL(WORLD_SIZE, T) \ - TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, \ - reinterpret_cast(params.output), reinterpret_cast(params.input), \ - reinterpret_cast(params.buffer_ptrs_dev), (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, \ - params.token_dim, params.rank, reinterpret_cast(params.buffer_flags), params.wait_for_results)); + TLLM_CUDA_CHECK( \ + cudaLaunchKernelEx(&config, &twoshot_allreduce_kernel, reinterpret_cast(params.output), \ + reinterpret_cast(params.input), reinterpret_cast(params.buffer_ptrs_dev), \ + (T*) params.multicast_ptr, params.num_tokens, params.buffer_M, params.token_dim, params.rank, \ + params.buffer_size, reinterpret_cast(params.buffer_flags), params.wait_for_results)); void twoshot_allreduce_op(AllReduceParams const& params) { @@ -369,20 +376,33 @@ inline __device__ T add(T a, T b) } #define FINAL_MASK 0xffffffff +#define WARP_SIZE 32 template __inline__ __device__ T warpReduceSum(T val) { + // Get the actual number of active threads in this warp + int active_warp_size = min(WARP_SIZE, blockDim.x - (threadIdx.x & ~(WARP_SIZE - 1))); + unsigned int mask = (1U << active_warp_size) - 1; + #pragma unroll - for (int mask = 16; mask > 0; mask >>= 1) - val = add(val, __shfl_xor_sync(FINAL_MASK, val, mask, 32)); //__shfl_sync bf16 return float when sm < 80 + for (int offset = 16; offset > 0; offset >>= 1) + { + if (offset < active_warp_size) + { + val = add(val, __shfl_xor_sync(mask, val, offset, WARP_SIZE)); + } + } return val; } inline __device__ float block_reduce_sum(float val) { - __shared__ float smem[32]; - int lane_id = threadIdx.x % 32, warp_id = threadIdx.x / 32, warp_num = blockDim.x / 32; + __shared__ float smem[WARP_SIZE]; + int lane_id = threadIdx.x % WARP_SIZE; + int warp_id = threadIdx.x / WARP_SIZE; + int warp_num = (blockDim.x + WARP_SIZE - 1) / WARP_SIZE; // Ceiling division to include partial warps + val = warpReduceSum(val); if (lane_id == 0) { @@ -391,6 +411,7 @@ inline __device__ float block_reduce_sum(float val) __syncthreads(); val = lane_id < warp_num ? 
smem[lane_id] : 0.f; val = warpReduceSum(val); + return val; } @@ -410,7 +431,7 @@ __device__ float4 loadfloat4(void const* ptr) template __global__ void __launch_bounds__(128, 1) RMSNorm(T_IN* input_plus_residual, T_OUT* output_norm, T_IN const* buffer_input, T_IN const* gamma, float epsilon, - T_IN const* residual, int batch_size, uint32_t* buffer_flags) + T_IN const* residual, int batch_size, uint32_t buffer_size, uint32_t* buffer_flags) { #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900)) static bool const LAMPORT = true; @@ -433,7 +454,7 @@ __global__ void __launch_bounds__(128, 1) int offsets[NUM_INPUTS][DIM / (1 * ELTS_PER_THREAD * NUM_THREADS)]; - LamportFlags flags(buffer_flags); + LamportFlags flags(buffer_flags, buffer_size); T_IN const* input = &buffer_input[flags.input_offset + flags.buffer_size]; cudaTriggerProgrammaticLaunchCompletion(); @@ -598,16 +619,15 @@ __global__ void __launch_bounds__(128, 1) #endif } -template +template void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T const* gamma, double epsilon, - T const* residual, uint32_t* buffer_flags, int batch, cudaStream_t stream) + T const* residual, uint32_t buffer_size, uint32_t* buffer_flags, int batch, cudaStream_t stream) { // input to rmsnorm is the buffer in the twoshot ar // We should use prenorm output to determine the actual used size float _epsilon{static_cast(epsilon)}; - static constexpr int NUM_THREADS = 128; static constexpr int CGA_THREADS = NUM_THREADS; constexpr int iters = H_DIM / CGA_THREADS; @@ -628,28 +648,34 @@ void twoshot_rmsnorm(T* prenorm_output, T* normed_output, T const* input, T cons &RMSNorm, cudaFuncAttributeMaxDynamicSharedMemorySize, shmem_size); config.dynamicSmemBytes = shmem_size; TLLM_CUDA_CHECK(cudaLaunchKernelEx(&config, &RMSNorm, prenorm_output, normed_output, - input, gamma, _epsilon, residual, batch, buffer_flags)); + input, gamma, _epsilon, residual, batch, buffer_size, buffer_flags)); } -#define LAUNCH_RMSNORM_KERNEL(T, H_DIM) \ - twoshot_rmsnorm(static_cast(params.residual_output), static_cast(params.output), \ +#define LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS) \ + twoshot_rmsnorm(static_cast(params.residual_output), static_cast(params.output), \ static_cast(params.input), static_cast(params.gamma), params.epsilon, \ - static_cast(params.residual), params.buffer_flags, params.batch, params.stream) + static_cast(params.residual), params.buffer_size, params.buffer_flags, params.batch, params.stream) void twoshot_rmsnorm_op(RMSNormParams const& params) { auto dtype = params.dtype; + +#define CASE_DISPATCH_RMSNORM(T, H_DIM, NUM_THREADS) \ + case H_DIM: LAUNCH_RMSNORM_KERNEL(T, H_DIM, NUM_THREADS); break; + +#define TYPE_DISPATCH_RMSNORM(T) \ + CASE_DISPATCH_RMSNORM(T, 2048, 128) \ + CASE_DISPATCH_RMSNORM(T, 2880, 120) \ + CASE_DISPATCH_RMSNORM(T, 4096, 128) \ + CASE_DISPATCH_RMSNORM(T, 5120, 128) \ + CASE_DISPATCH_RMSNORM(T, 7168, 128) \ + CASE_DISPATCH_RMSNORM(T, 8192, 128) + if (dtype == nvinfer1::DataType::kFLOAT) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(float, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(float, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(float, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(float, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(float, 8192); break; + TYPE_DISPATCH_RMSNORM(float); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -657,13 +683,7 @@ void 
twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_bfloat16, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_bfloat16); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -671,13 +691,7 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { switch (params.hidden_dim) { - case 2048: LAUNCH_RMSNORM_KERNEL(__nv_half, 2048); break; - case 4096: LAUNCH_RMSNORM_KERNEL(__nv_half, 4096); break; - // Llama-4 Hidden Dimension - case 5120: LAUNCH_RMSNORM_KERNEL(__nv_half, 5120); break; - // DeepSeek Hidden Dimension - case 7168: LAUNCH_RMSNORM_KERNEL(__nv_half, 7168); break; - case 8192: LAUNCH_RMSNORM_KERNEL(__nv_half, 8192); break; + TYPE_DISPATCH_RMSNORM(__nv_half); default: TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported hidden_dim."); } } @@ -685,6 +699,8 @@ void twoshot_rmsnorm_op(RMSNormParams const& params) { TLLM_CHECK_WITH_INFO(false, "[MNNVL TwoShot RMSNorm]: unsupported dtype."); } +#undef TYPE_DISPATCH_RMSNORM +#undef CASE_DISPATCH_RMSNORM } } // namespace tensorrt_llm::kernels::mnnvl diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h index ccca256b5a2..3a0fb753db2 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/mnnvlTwoShotAllreduceKernels.h @@ -30,6 +30,7 @@ struct AllReduceParams int buffer_M; int num_tokens; int token_dim; + uint32_t buffer_size; void** buffer_ptrs_dev; void* multicast_ptr; void* buffer_flags; @@ -50,6 +51,7 @@ struct RMSNormParams void const* gamma; double epsilon; void* residual; + uint32_t buffer_size; uint32_t* buffer_flags; int batch; int hidden_dim; diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu index 577f4b5ff4f..7bc9e326fb2 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.cu @@ -150,8 +150,8 @@ __device__ __forceinline__ void fused_op( constexpr int SF_VEC_SIZE = 16; using PackedVec = PackedVec; PackedVec pack_val = *reinterpret_cast(&norm_val); - auto sf_out = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, - token_id, access_id_in_token, std::nullopt /* numRows */, params.hidden_dim, + auto sf_out = cvt_quant_get_sf_out_offset(std::nullopt /* batchIdx */, token_id, + access_id_in_token, std::nullopt /* numRows */, params.hidden_dim / SF_VEC_SIZE, reinterpret_cast(params.scale_out), params.layout); reinterpret_cast(params.quant_out)[access_id] = cvt_warp_fp16_to_fp4(pack_val, *params.scale_factor, sf_out); diff --git a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h index 9ebc7de6509..4a35d14bf09 100644 --- a/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h +++ b/cpp/tensorrt_llm/kernels/communicationKernels/moeAllReduceFusionKernels.h @@ 
-55,7 +55,7 @@ struct AllReduceFusionParams void* rms_gamma; float rms_eps; float* scale_factor; - FP4QuantizationSFLayout layout = FP4QuantizationSFLayout::SWIZZLED; + QuantizationSFLayout layout = QuantizationSFLayout::SWIZZLED; cudaStream_t stream; }; diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 81208594d0f..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:d5bb139b12206a563daec9fa473dda422319bde5ae5f965d37cf5ca67d325c49 -size 1005546 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 7086ad9f485..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_bf16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:c4357a935656d47414a459939720b66311c67213f450168715e1cb0238653768 -size 1066324 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 0acae9aa71b..2ae91e52cd7 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a0671e7cbbed9f51dc0c47e4b970e2f72067d629ff6562c9d65f9cd55c68578 -size 361861 +oid sha256:c709dce149c0f4500539e495c90d1da2d86cec28c4187ee9494b015642e158cf +size 363441 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 4cb6bcd1c18..bce0c66bcf1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ec9817bebb07483ce29d8d91c45d35c2c05f0101bfa70146fba5a6576a6b825 -size 1091614 +oid sha256:b9170581da010aca67f4bafd9f6f59aaaf5fd1958a1fdd336aa208146599ac06 +size 1094770 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 470904148ad..caa735d5724 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0540cdb398818ec54a60c34b462c158e169347db73d244d633669d74211696ba -size 1467312 +oid sha256:2147a246067f7ea74ca382fbc8c02a26332479e5205ecfbe08fb84161a3a87ec +size 1483888 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index 281985341d5..0b584163a86 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69bdfba64f1faff30ed8389a28b7b9ef37c0d180b1df643722b280011c8f74e8 -size 692990 +oid sha256:279bd48b8ac53690bb4e37dffbe9060428db80c1417ff29c6f4d4a10ab35a7c9 +size 700094 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 8b8738474dd..496df695fcc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c8173308813999ab64ba8236016b23fbfd3f3f1501f61290bf71ea027ead2920 -size 642456 +oid sha256:db5d186ce70d7a94cae2b6619b3449ca557903944beba1ee738d2ee425792d74 +size 652718 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 6ca952af647..c6692932cdb 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f41ae066b01b2a9c3b5165535f743461a9a1d559f6fcd0a00a04c554f8a50962 -size 414757 +oid sha256:089a98cf8ab0bbd7530e69821c42220ea02578b740bff62a3e6e33de45209114 +size 416335 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a973c5d2e6..555f6268648 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ab0be8e667d459e13135f96469613f1c095e47187b24e5d40c7c57583351a076 -size 1194236 +oid sha256:1f0cc486ec5e9c1720f495a2a5e7c26d42e737694d307d4746a08b6ead5cc225 +size 1197394 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index 8faf85254d9..b5884bba556 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:03d86280f76994e2e01d43747cb5c811496b8340d031ebb0c3bdd46437422994 -size 1654394 +oid sha256:398965e34c1a4c747b42d8836c04934daaa43903b7931586ed12120e17a61f76 +size 1672548 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 53f3032a30e..696620f8791 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:35c5715bcb1a16c343f3a28be105fb6fee1bbca24cf832f71a7d0f20cf9a0b3e -size 365015 +oid sha256:77cbd7d45164d24be73e021bc0a8745b4f021e4369a254e216ee00b36d3c7263 +size 366593 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp index 89a4eaa580c..22a4ff75bf6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a3335a8d4b2c0ca63f006c3f957d57aa3f808ef06d4adda322c311a333286d84 +oid sha256:3a3f74fbe72ef54b9c028d957353c1ecbff1d20bcc9619ff17ee37471934a2ab size 1126352 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 9cb2eb33c23..e0b9335b45e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fdc0bf099862d352b3b765e117437240a82e4749d3efd104881647dd4ea14562 +oid sha256:b3af082c6742f385d0d2c96489ff1de314458eb992d6d5a251c737f8ec912e79 size 644092 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 153555cbe42..ec999849faf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ccd938df8f78af4eae306c6e9e669599c2baf6f095f956318470063c560fbd3c -size 1091610 +oid sha256:8e26f3b8cc173301b3cf07ba1ca7893b6f140432410b0b298361ecff597604c2 +size 1095556 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index cab205493aa..284e084f3df 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ce4d35ab4c7b65476f0dcec635db1791fcb718afd6b3531338712f5b2bc9aa84 -size 1460204 +oid sha256:32220d11bc3542e9edcc36d51b4866bf40044213114d7e237e003afc1fc7c464 +size 1478358 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp index ab21a448f54..69a3f4789c8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_q_paged_kv_64_sm86.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d088ce37b21d335ba1f92034cf97f78fc968d7fecaa0c4f9ec83a0d5165f1d99 +oid sha256:3ee5ae75df4866d848e90616562345d3740b17b68c90f06329dc074dba5217a9 size 482709 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp index 2fa6ba246ed..c19635d6887 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:40653ec672098e2cb1f94c473fa67852efcf6b49a6e8109e4fcf39422281acb4 
+oid sha256:817ae5c1eb8a8c6f22a76ab0b88075fd3391d06abb7dd6d9ab51206b809cd69d size 657930 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp index ebdb0563ef9..a625def240f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:96348957990518db6f51af7c681a71e625dede568cc8f8303dd2de8ad09bfc28 +oid sha256:680734da0abb1c3029dce32e892687f649c4219f66574acb15ab88471f508263 size 677218 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 7cd5b267e07..1691a77e1fe 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_bf16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4687df80ac2fa9454b0564b0a80d78cfaedc2c7796c8f3a1010dd7ebbf722c83 +oid sha256:c27e871dd680022920081c30c5e239613e53b42129680fdb1d17668b5c5ddd9a size 369401 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp index f4da9b9d86f..6e7098d6c73 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d8b9985065f5f2c62b74c05f8eed02b1909c96656b26fbd7779cc57a2146b037 -size 947140 +oid sha256:3e1ecaa635067924b692b665241d86e1d8c1d60a19290de7adde1ff2ca7dbeb0 +size 956612 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 8ffdb6589d9..c38c3b29fd6 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:23599e63b07ad966df921daf3cb97a9ed5cde27eeda0fd96ba5abd835b48f89a -size 590779 +oid sha256:d3018c622303f89c6f22f037ec99eaeaeea9cfe8911e22463b48a22c13116805 +size 592357 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1153714c7e1..5d286a73e53 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:cd1c452565583b20913d835de9b14c2f19c0cc431bc926ea6c92295362a85bca -size 1813864 +oid sha256:a7a381f2855236f418a40124a5254401c95001d5e15c074a704e22cc7ed89aa2 +size 1818600 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index b6383dcbd5c..5290f97cfb8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b20de2c6bb3081564ddfbf7ece80fb2c17e66f4e7ff0e0969da4e4655e90d1ec -size 2407418 +oid sha256:9bb49ace4dedc4faa3de2b9c22e09db0f3990129ce7ab4afb6419c38a5d48a16 +size 2427152 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 3713748af50..cb3d89f0704 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:33a0e8bb2391128e688e5c6356f09a5ed189ce5c1bcdeef4efc0ce0415dc2849 -size 555245 +oid sha256:9769d7cb9754718798be515c84c45ff48e43322573f3f12e31c2e42e99d8dbd4 +size 557613 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp index 795d4d68fc9..de925119b38 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_sage_64_64_256_output_bf16_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4b014f41b1cfdf6ed2729778841213a36440191eb3c087346a02c21510bd3f0e -size 665794 +oid sha256:134f4a73e0e6b02b717319ec49e3b3ea0a585cad385a1f300e6c5761f12de9d7 +size 671320 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 5c8dbe22b24..64bb52e0df9 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd77afeb7dcd1ff8d6be80788b20e92e4fbc8c3026ba12d1d522c99316754a7c -size 1740442 +oid sha256:7935b0f053a79a7e620c0efe274fa5b4c840fc9c6e439a381c4d380446e1cb68 +size 1744388 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp index ee1a46c9bc9..87d96af432a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_64_256_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b674707d02aac297b66d523de8b11618ca1598c49eeaf7ce9b1c9d516ce95c4b -size 2247958 +oid sha256:74ecbbaa19b2efe97a3b12c488f0e03c2102f16c460239df4bfc19976fc4365e +size 2266902 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp index 349c2efdfe3..15ad1d62a91 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7556f88488e05ee669e763b839afa1b7690060cfa9d8482d419c0ca336df9352 +oid sha256:813265d25709bd2d39982efbaf092c9163b124bd990fccab505b3c22134522aa size 595585 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp index 2ccc55f1447..4e62255a629 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ac9d879aa0c70967bb3a79cd7034998baf43a544c0dd4444ebddeb76e78df5ae +oid sha256:dd36195c01bf7c2a2013d5f31d2e74c2579c471385d7b45be7e35ea2f0652608 size 908162 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp index ec1ef8aae91..10ee7b3d8c4 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e781c0278fc46142f578ae51bfeb38767e89d9c25b92023215948f99dd1d3ed +oid sha256:31d4d6dca68c4632d1f435e9179582cfe2ad7a75ee0f7625ee67b0044c914f10 size 1371512 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp index d904de0acb2..407d34a6552 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d608e9e3ec460d2a38f43067a7d7a2dd408e068db690806bbafb11007e175336 +oid sha256:6570d3ee7b651dec797e82b31eb21fd3261c6e2639fb7c9b157f251bf98bb3bf size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp index 798e8482b41..d6b829a9a00 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9c1e1d300866c6425c2495e550230051debdca0a7eb85874ae33c0c2de8a81cb +oid sha256:88b972677c5436b90fe85870278e3b23d6f709608f99295bddf0be3861d95d1a size 1419662 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp index bbcce09e729..7cac9a83250 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_q_paged_kv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:132d83639e34af1b431abdcb3f09542d0389030b85752e18a3ae221ead7d24a3 +oid sha256:d975f605d62c3070d6cf72f6114d98642c520e66989ed2d2845c3213e921ebf7 size 1965880 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp index 83287a0376a..9dd7d6bf8e6 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_32_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4a96710f6c691580c2363c187a75fd436f5e6be732810a1a45182ce72dc52d1e +oid sha256:ef5a2728cbd3241f45f3d8285c91a818e11b2a9fedf322f343a9461d31a6ad30 size 1380182 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp index 00623779346..1b6d6cddf5e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_40_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:a6339f008f451d030aa36a6b3fac7179e7534f7f2474d641fa0ebfbf487074e7 +oid sha256:16b5f3d3f8760dabc0849217cf11edf18d19896dda475a5fc233bbfd444faf33 size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp index 0d719af97a3..90decb87938 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_48_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:57ebcae2b70fc28881f2b3969868d64c203ef4a9cbc9588a9e28051c5f5b6849 +oid sha256:cbacb235f39adaeabd68e2fc46c51aac6ca26cdf96293a6a7eb60b5be40640ef size 1401494 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp index ceab132d423..5628ced1f3e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_128_128_S_qkv_64_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5e2a4ce1b944feb2b3ed535943089a2d5968bf523b149885df78f7fa4bd7e835 +oid sha256:e6f3e068435339a64d47673f8018b66c202f6259d68e0a97a4a30acb7505a7fd size 1935872 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp index 2780675d9d0..552a78df4f2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:f5d456b30f89ad05ba5b852fabcffb3f8269913d83ef8c0e4e319f2243dee54d +oid sha256:7c2d7ab0692de5405b26d19a0c57d720285366ac12a8550bbabca1613cce7f0c size 305897 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp index 2aa3fd4b0a3..ca2d2a604da 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:85593d3c2fecb6842a72952c6dcbde19a70e6b26245829d279ca50bb391eb636 +oid sha256:91a26adfddc0bcaf8b42249f59f1a0b9f74be0f82c7378fe4b56f3a2fa3d4bf1 size 290109 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp index b050acbb5aa..da475b4a2d1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:69cd61bd8334d2109067ef0460a91b8dba4c2cb07392eb636d72d025ccb15bf9 +oid sha256:6ef79c9e2e2d8bba55d7803dc8dc147b5d8babc29e906a43407a8722bbd8d939 size 498507 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp index e741d50f4cd..09b401a0036 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0427b7729ce3cfa652a4595d04f936a947febec8f2c96ce33eed7cbaaa05613e +oid sha256:0eef025f8e8581868b02bcea37ff225afebcbb2966450fb29fb0e32ac54eccd4 size 668214 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp index eee064e2804..0c6a45eacc1 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:321bcd81b8965c8dfc08682f775508ae18e3ff711490ee8dff5fe56c20f74843 +oid sha256:abb2857ffb85cc36aae90ebb674635dffee2b2c5f7ad1ea81bb8002b65d5a0f8 size 711628 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp index 33f4d9cab3b..9ecb64bd23f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa77d3789c0ca314689125ec303a8af76554120a708a4b63395c69b7aad07f04 +oid sha256:49a3661535314b139e2794fe16f6f3e0a8d45742b68ea59ba99a9113068adf2c size 752698 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp index 31383430901..d836cccd03a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aa35aa70d0fa304c776c076a1a189d32a054d3f696dac5d99018085d1108c73b +oid sha256:d76fb6c4f8bb2de687bc5f9f275389356934119c1f0db9983dcf0ec7b68c6197 size 748726 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp index ca7815f7109..79e1e96e9bb 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d1a702d456b5acf279487dd810e3e33efdd1c7bd82530ceb5a32ad30ec30396c +oid sha256:be8ee89f4489c430d0ff6e9c6cf4e07379ac05abf468d47e34e084ad594b2037 size 946060 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp index 8bb9403c511..3c8b2528fc3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:558aa7d42de329c49361c94c4baef16738304b21b6adbe675d77c7819ef37660 +oid sha256:aa4be8ca2dd52e56c9a6af76b90ac353d217fad5fa931b21129ac5a811b5283a size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp index 0754f76695b..22fce024ea0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7b5baa6048e6c33e74c6d343eb7c76252ff2e534fe467b3189af12b5d64af37c +oid sha256:cb0482b768a40bc7f8a86fa23a84bab62fb82c205f3237ff60becda50cbafc90 size 489823 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp index 68de134acba..c02b557e7f4 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_q_paged_kv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:e17cb191ad092e6db255ea503e49ea883ed56322fc58ed8d68710f6687376c1f +oid sha256:95b1796f4e7c905eca82ed3691427025f68e765797440b962b0114a5ab32b1d7 size 500083 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp index 3ebcc110ecd..cbc081aae2c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_104_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfca5660a931e08941347f7a0aefa82c214940e8eaa6b6d89cfded621f34a490 +oid sha256:2d9f13977fc865e716f1f35dfdb222a38000b224ff7394134230ed5c88119947 size 496125 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp index c0c882331e1..cc613cc08d5 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fffd2cd799953808034d7e7b89a57d4fede24db124bfb0d3938188177acbdfeb +oid sha256:007e32a06fcac853159dc5786940447281c57ba70406d38beb6f089fd037053d size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp index 458aa250b4a..d8ba5241135 
100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:19ada3a5d449542f103077db8d193bc2293a8f48ccee201e366473964287314c +oid sha256:26241ea5909395116e1b1a0f19cadc448886f6a6ab2b3ba76c092b67cd0148f0 size 182023 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp index 65edc3e52ac..0206f719811 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b9c32124cd708aab7da30637d85437da0af9bf2157d163c19c6fe14498698cda +oid sha256:86e4ca60a459117c5e701631fbd3c67ca66e81d177c394c1fc9ad3b66396e69a size 661096 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp index 8213475b06f..3444d759b7f 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_160_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7f248fd42759509c61d20f912ae74dc3a85448a9c8386370ea92492ed9031e80 +oid sha256:770db1f4ec1c2d3c25767593b60cb095e49f7a6eb7abe054bbdec6e72db97f8d size 672936 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp index 75bd11ff6e7..b99affa0208 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:190fd946ddc7e1b5e9ca2172ec1de39c6288829773d9ce29fe98374256eff566 +oid sha256:0b6428cae2d0c8c813925be9589c94771098cfe5a6d0ff2036104d3e36384b81 size 721900 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp index ed5e241d9e9..e93db30f53a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_192_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b7cd5976c836bcd75c0cadfe968050ac60bf89b93df021ad6c1681e159c497c5 +oid sha256:36c6932301fe3dc29631c28fcb8cb6b08652103bc7a36fd74a03a8189a1c77e4 size 717928 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp index 44ce0c307f1..8f42d5a2769 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_256_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7c536d725e1d9ebd2cb836dfe3993edcc81101534db6b7f1943c8a9443838bf4 +oid sha256:d858f6dcaf3f49fb3fa18b1c8c20ee1b933e2c8ddd1a429c8d3b5b4d269fb875 size 927892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp index 0216db308c5..0cb2a134102 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_72_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:b5907da5a2f68c010d44bbbd0d780e097f9625be15b2f85e8dd1f00dd4c31ff9 +oid sha256:7dc92ab65ed0fc5f9d821f52a396a6d55ea9ae37e080eac7ff9e9c14eae741e7 size 631890 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp index c63b37264a5..648e3acb008 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9cf14c71134a89ed6ffc83c0b7db06ed10e22b55294dc15ddf7f016427f01033 +oid sha256:d66606a37cfe8eb78ccc3f548a231f770df9f46e70f6d3ba22fb8abe6216480e size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp index 7d1ac808673..6028cc1f326 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sage_64_32_32_output_fp16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version 
https://git-lfs.github.com/spec/v1 -oid sha256:f2b83c70dbc8ab0b3695dab3f4d2069b7ee7119e9140d7860b8c19f59a498589 +oid sha256:b723b296cff04602f64a5da9928e6f9b6a03c5cc608ba9ef7d8055f23f1f4ea2 size 159919 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp index 4041bfc97a4..b1ee67b880c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_80_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:fc8369f5701dceea91d429a713ddcbb4ecb0ad08d3c9042688557ead5f00e9da +oid sha256:d40578a5684262cd8136705367e2c98493ea9b9fcfc123c7efa3ead14017b5b8 size 483493 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp index f0afe3fcf10..4ce3d2dba50 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_32_S_qkv_96_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e9fffff2d13d49613e5f9334a010ca9bcde43b3bb55a792fd97fe2c867760dc +oid sha256:60cc82b9d11c53392de91a7c4c097263c20a56f9b346278c7c9af12ef2bb5fbf size 496123 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp index 03a4b33cefc..d24465ed9c8 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dd3041ba5a52263f7f02d64f1911c50e346151bf529e865c1abf22583abd3e21 +oid sha256:8f685b6b2a0a573953f31fad89fa37e949361db245de69c0c06ce0bbb14eacef size 443285 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp index 6984f3c1700..dc49a306271 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:12482099b086249163085e6e3421a61f6e304f865aaf56dd15382614be5e48e7 +oid sha256:834f0f3601c589893a21b957be2864df594f96b34b2cfd6018ada8319986aa21 size 441683 diff --git 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp index 2bb4cc25821..4763a29923c 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bfea1ea1627eaef7b614db08bad00bda8b611c8e466c858e050c0ce2aee2eafb +oid sha256:3d81a070e7ed49f1e1a322d38a757a3505186cf5cbded99814e950e07229a46a size 298049 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp index 7e76c5e13df..c8587a81d35 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_q_paged_kv_576x512_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:f828600699faa3a0474085cbbe88d2e0ac7c8e056c976b81a882c3a72682e527 +oid sha256:b9de5bc49d888699da1880d24ccf6a9cb6c0049d7a244d1ae9ab64b7365ecd5a size 296445 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp index 1c1f7bdc42f..7d299b87052 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_output_bf16_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2d4b297922065ecb79b4a1278d048b253b57601d011fc5833a32f9fc1b78e58e +oid sha256:e30ed0df4b0d0b1da1ace5831dc0a7a526e04001b25860f862345c78acff5a43 size 427485 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp index 68394c07c1c..47eeb69632b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_e4m3_fp32_64_64_S_qkv_192x128_sm89.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:3fd5305445c9856fbd5d9dfaffdd7f87b9014638f33fb63fb2cb4fce9893b20b +oid sha256:030015dc1811e3dc2ae36ed770f51063a3f46deae42ead5e1523c977b438a133 size 425883 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp index 51778ad0e9d..1a5b22eed8a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_128_128_S_q_paged_kv_64_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2b7fee97097f799830df2bcb1c782c7ea9018243cbd5cd0e0f47ec299b49db79 +oid sha256:6921a204892e1336cef2a308be38855f3c888e56bd6a16752d2806aa9e93c431 size 1524634 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index 537871847de..834fa7d1c0b 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:8ac2f9270988bc02329ce11ef3413395b2b8cdc55fcf4911d170536c6e618317 -size 403697 +oid sha256:200df98fb2fcc734e8fc012c98c5d78c2061e5718eef6ffd50c2358a3d664197 +size 406065 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 6bf814ac8a9..e085961e987 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1234cf31a3a6b84ed25fa0ad6c4df9b53f673f6bac2f639a66086ba50f8717ba -size 1120818 +oid sha256:430194fe07e526ad01a1e0fb43273b240c269215b132c9af248ba386dcbda23e +size 1124766 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 3bebbebcf15..2d56be2925e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0fff300932a16d30844e317ace515a178f159c483e436f6955983b96c5c424c6 -size 1549402 +oid sha256:53a07904a7bfbf82380c96af99c5e24bc86f77906c5d6fdc85ef9720639d76d2 +size 1569136 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index 
ef64a376820..6d074921cde 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ed10767ec913d314936fc5dbd1fd70c5381a622bf3fcf1590f837da6d3285bca -size 723774 +oid sha256:1ce4d27b11fee3e5f6489510b55613177e174660b6c7a6fb4efed862b62c50d7 +size 731668 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index d0bc52f1318..a6268993164 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:7e7a7a9653a9c4e4e9b0514fc1d70abbb4521c7edbede52568d17d0779d62ffb -size 671662 +oid sha256:3992d7bd34e72089c5cffc4fc6de3f70a3995145b989811f83b00b47c96b5159 +size 681924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index 3056a533d67..d95d392d536 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:1e18db0cd4de65e76e30f219d24ec00095fb16005882c43322182c5fa3f59032 -size 445541 +oid sha256:521417177fc0447809c07ff86b58725fedbf1a6b9412ace4c50268a20bc2680d +size 447119 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp index 50d7f1becef..c405f483aed 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_sm80.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:9aceb502c1a95f58f1eab515cf2aeac92be6d255ef405008a4fd871fd54e9ba6 +oid sha256:cb063c946558e6928faabb85df9775fecd2b9444b40b3e06cf0f863db80a5ad8 size 1242842 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 1a74df12889..e88a310b64b 100644 --- 
a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ec96248452f638bb9ca50d3630dd67caf71322c01b17aff301c4a98eb7e27974 -size 1215548 +oid sha256:31e6b7442b277f5206cc1d70fa6021f36170265b311106281e88b4611d1a5b6b +size 1220284 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index e03f7c2575c..0db1249a289 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:dabc44860e81532e9b7ecb35773d0ad409d45361e20c9510d24387039999a7c3 -size 1720698 +oid sha256:c1342769efa91794d5bd35ac623b3014738b075b2671441668e2f0d5c1eef78a +size 1739642 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index b1d87c1278f..4d68087ca12 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0d9c8d1fe282f46c12898ed4851a2640cb33ba5d75c5fe9da8a988f818a0e733 -size 407639 +oid sha256:a49dd8abcca57a64eb2ab4e00e4e0d26edf68488fb67086a4b466f8e6651522e +size 410007 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp index 2a12ddb7118..deb498b1a29 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:849a280994b3fa1f18ca6c3866a16a68a9b02831f134f8dfcf0d34502c1d6772 +oid sha256:a7013b1eea12719ebeaf47facc37ef730bb0d6af03ca2ad890724a25448616a9 size 1102672 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index a2c78e856df..4bf37280a0e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ 
b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:4e209b01409585433406f8392c77a7398270ee1b58446b728cf74faa6fe1bf9a +oid sha256:a16aeaf5d11a4c25461452b5f3145136b31861ef9c443d7ec82066565275d6f8 size 629884 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index 61bbc8d762e..0115c2c36f3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:0a22bb0202916831eced0a44acbab769d5647937155e0a2b5e6d0d0cb83c726f -size 1122394 +oid sha256:a7d4526887fe860e0d9c482fc7fe2cfe646c7a20bc8a0813ce33a01fd9cc733c +size 1125550 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index e0170f8db7f..5d1d2207551 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:582d17d48c7a751a345f74cc8c74f9b8c05278ddfc185da4906310a4973a9bdb -size 1547030 +oid sha256:b880e78ffc354edb541bd612e543dd894843fc4163f7bd65ce53282892381b8a +size 1566764 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp index 456d75f72fe..fbab68022c3 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:70f02b7329eef7ceeb73dd43c3bf8f6ea6132c593bba6dbbed720d8b8ff0c287 +oid sha256:de26acaa532f197e339b6d5b2a2dd8032d505c9e169fce38000b02b2a4188eff size 603809 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index 0c0712acaf1..8315c080842 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:f67d4e70c39bf379ed0f3ef73a3690ac64efaee1e7134c793a760924c270f046 +oid sha256:cef5bcfe63650bc924d9e45d2755b50940534999fb4fbad3a8abf0ba73b9245a size 329935 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp index f35d06ef066..c57602da24a 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:c2c284c6cb66207bd204bd1b6abe45aa8bf2e0c92631681861df237b8f849a46 -size 363451 +oid sha256:b332d4c6047c98b504cd3be72cc5028d240621c8e0a3260d64c17804982104db +size 365029 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp index 73d9547cf2d..a0fe210d9b0 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:d3bede327d80be420e7bf011ee1a4156365afff7020bbf5a8434da18cb19fb23 -size 1093202 +oid sha256:a16c23767a2e5efbd7330728ed87af2ec62a7731debe1da557705c6db6d3268e +size 1096360 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp index 998e46d1f16..3c10c481369 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_k_v_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5ee7695bd5bb0a03eafe29a497060d84caec96ca4d159e99e4f02b99977dd2a6 -size 1469690 +oid sha256:66950bc137b734d509f0574152bcf9cf7efcb17a7483450d5fdbf480e9f83001 +size 1486266 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp index a76bf3814f7..0b4847611fd 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_softmax_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:cecca7ad5c652989a3008c8219177811ab9c7d617adbbc9ed8548141803c66f5 -size 694578 +oid sha256:bba586d9fe487c49cef2abfbfb0a078dde907d28e04b4d2335018cdb7031879c +size 701682 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp index 71a5743dd98..fb1751942e2 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:bd6847c0e897eb794a9b1ff67e64358527fe64c3e01fc214545cf76ec60edc6d -size 644046 +oid sha256:d3e45ab30e471f4649807f5b7640512e2c6678cf623cadfcb26c93eb4ad60ec0 +size 654306 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp index ea50fb06310..ca8b31a0105 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:118cc6d4a5e3e12ce0f2727361fd1d52d1a49c67d0bd1837c24e528c064a0dd7 -size 415557 +oid sha256:1932937b7f4ad0370341c77a03db133dd676bdf844b13eb45ec10243d1dfd16b +size 417135 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp index 285c32ec70e..85d85fa4d99 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:36d6c97af5fb15f32cd1ff13f53dd98a7d670cb80ee766765f42cc453f730812 -size 1195826 +oid sha256:c11f5d464b0486023b78babfdfe9d2768e4b0d13caeb436d6f73110ede72498c +size 1198982 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp index bd266daa63a..465fcafeced 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_q_paged_kv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid 
sha256:7775bbc1b43487236cf7570d2ed900f1c9830eab70aac1fa9dc59c439cc0c687 -size 1657562 +oid sha256:3bac9b40302bbfc6ee5a49e5c45d3238f46cff45619acd1b098d90e758d3ce30 +size 1675716 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp index 2d3c2887bea..c65fa93d24e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_alibi_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:199b1ff3cc3d0ff04477ff8f1e6390dd62b3a7c9dd264cc73ce6c716af20a0f9 -size 366603 +oid sha256:26f09ab86b52c40b283652e555f677850f00902151d17e375e016b9a99a97794 +size 368183 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp index e0073c3730b..36bdbdda6bf 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:2e743b470f9607abcbc8b71e7ef67455e6104daf3a80d0bd012a96ecf90a8f18 +oid sha256:960c3f9e4fe46fc6390207ba0ed85ec25435045e2213b60c5d44ea9ab4fa56aa size 1128730 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp index 1553e77aee6..58a89a84a2e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:366aa4e9f3263f73c4e76c0ea8008c0449b6d89bcade761500af949912786e32 +oid sha256:ac167d89ea3150f7b65614645ef09f13e2543bdc0523c1eddce5bbd9cfd306ee size 644892 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp index cd0531dde0e..cd64d2fe381 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_softcapping_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:5b8a8d76e17a24afd7af1dc5e112828f98ace78e3f85a7efaadb0cf1937085cc -size 1093198 +oid 
sha256:9d0cf59a8114940070448d87d02d9e83d53bb371ca9915c3983e03626d17024e +size 1097144 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp index 54fd20f69c9..f3194ad186e 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_128_S_qkv_128_tma_ws_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:aeffa2db467fbae3ace85fae9f31e2b8a7c0923ab349ade42318ae6f55249ac8 -size 1462582 +oid sha256:ff1449b6795f5beda0b6a62e8a1171ce952b07c4e63b607c06f5fedddb2debe9 +size 1480736 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp index 673041f7af9..87c5afddecc 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:ffc92513e64631c33290f1e88e5666f5b85251506d527745c493f2e90da39de4 +oid sha256:cb14ae0271f8a83216f67c111530d3fe1be2231541ded5f992ff45226ae90e69 size 678808 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp index c39e7fa450e..dad37ebd422 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_flash_attention_fp16_fp32_64_32_S_qkv_128_softcapping_sm90.cubin.cpp @@ -1,3 +1,3 @@ version https://git-lfs.github.com/spec/v1 -oid sha256:faad8cb1e44f5e16f61720966d2a6c9e782461c209cd8000263b50d42093444d +oid sha256:46a0d8e0a9495e03f72526b4ee04fa3d2a2d87984057b44550cabf4ffa745ef4 size 370201 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index e2ee736b49d..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:dd930ed415b0303a973a37550ee33fa4975ad6be0cc58d461370b127f9a90f8e -size 1020542 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 95d9b2bf647..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ 
-version https://git-lfs.github.com/spec/v1 -oid sha256:4f2b243127e1ce00a850a10cca104ffc42512711f434fbdf8683eeeb49b8ce42 -size 1056062 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp deleted file mode 100644 index 0c093db643c..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_32_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:2ce9cc89b1db7f7e4b76b94cf1c3b04db49a2d86b529b1fc85b19057a99bc9fa -size 1007924 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp deleted file mode 100644 index c24e239dd0c..00000000000 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/cubin/fmha_v2_fp16_fp32_128_64_ldgsts_sm90.cubin.cpp +++ /dev/null @@ -1,3 +0,0 @@ -version https://git-lfs.github.com/spec/v1 -oid sha256:e176513fa0074d688620299dfca53adc3902491e97ea9b6938a4ceb2fcf17ef5 -size 1068702 diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp index a0197d8083a..29c9eea3391 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fmhaRunner.cpp @@ -238,6 +238,9 @@ void FusedMHARunnerV2::setupKernelParams(MHARunnerParams runnerParams) mKernelParams.packed_mask_ptr = runnerParams.packedMaskPtr; mKernelParams.cu_mask_rows = reinterpret_cast(runnerParams.cuMaskRowsPtr); } + TLLM_CHECK_WITH_INFO( + runnerParams.attentionSinksPtr == nullptr || mSM == kSM_90, "The attention sinks is only supported on SM90."); + mKernelParams.attention_sinks_ptr = runnerParams.attentionSinksPtr; mKernelParams.cu_q_seqlens = reinterpret_cast(runnerParams.cuQSeqLenPtr); mKernelParams.tile_id_counter_ptr = reinterpret_cast(runnerParams.tileCounterPtr); // TRT doesn't support host scales. Use device scales instead. diff --git a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h index 96435cca528..e9098866161 100644 --- a/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h +++ b/cpp/tensorrt_llm/kernels/contextFusedMultiHeadAttention/fused_multihead_attention_common.h @@ -263,6 +263,8 @@ struct MHARunnerParams void* outputSfPtr; // The softmax_status ptr for RingAttention. void* softmaxStatsPtr; + // The attention sinks ptr. + float const* attentionSinksPtr; // The packed mask ptr. void const* packedMaskPtr; // The cumulative Q sequence lengths. @@ -352,6 +354,8 @@ struct Fused_multihead_attention_params_v2 KVBlockArrayForContextFMHA paged_kv_cache; // The mask to implement drop-out. void const* packed_mask_ptr; + // The attention sinks. + float const* attention_sinks_ptr; // The O matrix (output). 
void* o_ptr; // The Softmax stats vector of layout [2, B, S, H], including softmax_sum and softmax_max diff --git a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h index 934679a944c..6758558e277 100644 --- a/cpp/tensorrt_llm/kernels/customAllReduceKernels.h +++ b/cpp/tensorrt_llm/kernels/customAllReduceKernels.h @@ -56,6 +56,8 @@ enum class AllReduceStrategyType : int8_t ONESHOT = 4, TWOSHOT = 5, LOWPRECISION = 6, + MNNVL = 7, + NCCL_SYMMETRIC = 8, }; enum class AllReduceStrategyConfig : int8_t diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt index 7a02cdee73f..fd89ae4a194 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/CMakeLists.txt @@ -218,6 +218,11 @@ if(USING_OSS_CUTLASS_MOE_GEMM) set(MOE_GEMM_SRC_CU_LAUNCHER ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_LAUNCHER EXCLUDE REGEX ".*moe_gemm_kernels_.*") list(FILTER MOE_GEMM_SRC_CU INCLUDE REGEX ".*moe_gemm_kernels_.*") + set(MOE_GEMM_SRC_CU_HOPPER_FP4 ${MOE_GEMM_SRC_CU}) + list(FILTER MOE_GEMM_SRC_CU_HOPPER_FP4 INCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") + list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX + ".*moe_gemm_kernels_(bf16|fp16)_fp4.*") set(MOE_GEMM_SRC_CU_FP4 ${MOE_GEMM_SRC_CU}) list(FILTER MOE_GEMM_SRC_CU_FP4 INCLUDE REGEX ".*fp4.*") list(FILTER MOE_GEMM_SRC_CU EXCLUDE REGEX ".*fp4.*") @@ -230,6 +235,10 @@ if(USING_OSS_CUTLASS_MOE_GEMM) add_library(_moe_gemm_launcher OBJECT ${MOE_GEMM_SRC_CU_LAUNCHER}) add_cuda_architectures(_moe_gemm_launcher 89) + add_library(_moe_gemm_hopper_fp4 OBJECT ${MOE_GEMM_SRC_CU_HOPPER_FP4}) + set_cuda_architectures(_moe_gemm_hopper_fp4 90) + process_target(_moe_gemm_hopper_fp4 true false) + add_library(_moe_gemm_fp4 OBJECT ${MOE_GEMM_SRC_CU_FP4}) set_cuda_architectures(_moe_gemm_fp4 100f 120f) process_target(_moe_gemm_fp4 false true) @@ -239,8 +248,9 @@ if(USING_OSS_CUTLASS_MOE_GEMM) process_target(_moe_gemm_fp8 true true) add_instantiations(moe_gemm_src ${INSTANTIATION_GENERATION_DIR}/gemm_grouped) - target_link_libraries(moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_fp4 - _moe_gemm_fp8) + target_link_libraries( + moe_gemm_src PRIVATE _moe_gemm_launcher _moe_gemm_hopper_fp4 _moe_gemm_fp4 + _moe_gemm_fp8) target_include_directories( moe_gemm_src PUBLIC ${PROJECT_SOURCE_DIR}/tensorrt_llm/cutlass_extensions/include) diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h index e6c3a6bbfa2..646be2575ca 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/common.h @@ -27,6 +27,7 @@ enum class ActivationType Silu, Swiglu, Geglu, + SwigluBias, Identity, InvalidType }; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h index 7ddd756e0d0..1237884d13c 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_gemm_kernels.h @@ -210,8 +210,10 @@ struct TmaWarpSpecializedGroupedGemmInput struct INT4GroupwiseParams { - constexpr static int group_size = 128; // Unused, hard-coded to 128 + constexpr static int int4_group_size = 128; + constexpr static int wfp4a16_group_size = 32; bool enabled = false; + bool use_wfp4a16 = false; using SFA = __nv_bfloat16; using SFB = __nv_bfloat16; 
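The use_wfp4a16 flag added to INT4GroupwiseParams above pairs with a compile-time trait of the same name in the runner classes further down. The trait's template arguments did not survive in this rendering of the diff, but judging from the new explicit instantiation MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16> (and the matching fp16 file), it presumably means FP4 weights paired with 16-bit activations. A plausible reconstruction, not the verbatim header (use_wfp4a16_v is an illustrative name):

#include <cuda_bf16.h>
#include <cuda_fp16.h>
#include <cuda_fp4.h> // __nv_fp4_e2m1; assumes a CUDA toolkit that ships this header
#include <type_traits>

// Reconstruction of the use_wfp4a16 trait (the angle brackets were lost in
// this copy of the diff): FP4 weights with half or bfloat16 activations.
template <typename T, typename WeightType>
inline constexpr bool use_wfp4a16_v = std::is_same_v<WeightType, __nv_fp4_e2m1>
    && (std::is_same_v<T, half> || std::is_same_v<T, __nv_bfloat16>);

// In the diff, use_w4_groupwise = use_w4afp8 || use_wfp4a16: both groupwise
// 4-bit weight paths share the mixed-dtype TMA warp-specialized pipeline.

static_assert(use_wfp4a16_v<__nv_bfloat16, __nv_fp4_e2m1>, "bf16 x fp4 is a wfp4a16 pairing");
static_assert(!use_wfp4a16_v<float, float>, "fp32 x fp32 is not");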
// Unused using ProblemShapeInt = cutlass::gemm::GroupProblemShape>; @@ -254,7 +256,8 @@ struct TmaWarpSpecializedGroupedGemmInput constexpr bool isGatedActivation(ActivationType activation_type) { - return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu; + return activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu + || activation_type == ActivationType::SwigluBias; } template && (std::is_same_v || std::is_same_v); +#else + static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v @@ -282,6 +291,7 @@ class MoeGemmRunner static constexpr bool use_w4afp8 = false; static constexpr bool use_wfp4afp4 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static constexpr bool use_fp4 = std::is_same_v; @@ -306,9 +316,9 @@ class MoeGemmRunner [[nodiscard]] bool isTmaWarpSpecialized(cutlass_extensions::CutlassGemmConfig gemm_config) const; [[nodiscard]] bool supportsTmaWarpSpecialized() const; - [[nodiscard]] bool isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const; - [[nodiscard]] bool supportsFusedGatedActivation(bool is_gated_activation, int gemm_n, int gemm_k) const; + [[nodiscard]] bool isFusedGatedActivation(cutlass_extensions::CutlassGemmConfig gemm_config, + ActivationType activation_type, int gemm_n, int gemm_k) const; + [[nodiscard]] bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k) const; size_t getMaxWorkspaceSize(int num_experts) const; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h index c7c9a55b959..ca256ae0d6b 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_kernels.h @@ -87,6 +87,62 @@ struct LoraParams namespace cutlass_kernels { +static inline size_t pad_to_multiple_of_16(size_t const& input) +{ + static constexpr int ALIGNMENT = 16; + return ALIGNMENT * ((input + ALIGNMENT - 1) / ALIGNMENT); +} + +class CubKeyValueSorter +{ +public: + CubKeyValueSorter(); + + CubKeyValueSorter(int const num_experts_per_node); + + void updateNumExperts(int const num_experts_per_node); + + static size_t getWorkspaceSize(size_t const num_key_value_pairs, int const num_experts_per_node); + + void run(void* workspace, size_t const workspace_size, int const* keys_in, int* keys_out, int const* values_in, + int* values_out, size_t const num_key_value_pairs, cudaStream_t stream); + +private: + static int expertsToBits(int experts); + int num_experts_; + int num_bits_; +}; + +struct ActivationParams +{ + ActivationType activation_type; + float const* swiglu_alpha = nullptr; + float const* swiglu_beta = nullptr; + float const* swiglu_limit = nullptr; + + explicit ActivationParams(ActivationType activation_type) + : activation_type(activation_type) + { + TLLM_CHECK_WITH_INFO(activation_type != ActivationType::SwigluBias, + "SwigluBias is not supported in ActivationParams without swiglu_alpha and swiglu_beta"); + } + + ActivationParams( + ActivationType activation_type, float const* swiglu_alpha, float const* swiglu_beta, float const* swiglu_limit) + : activation_type(activation_type) + , swiglu_alpha(swiglu_alpha) + , swiglu_beta(swiglu_beta) + , swiglu_limit(swiglu_limit) + { + } + + // TODO 
Port everything properly and get rid of these implicit conversions + operator ActivationType() const + { + return activation_type; + } +}; + /** * \brief Describes what parallelism mode the MoE is using * @@ -392,14 +448,14 @@ class CutlassMoeFCRunnerInterface = 0; virtual std::vector getTactics() = 0; - virtual void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) + virtual void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) = 0; // Aliases for profiling the gemms @@ -410,7 +466,7 @@ class CutlassMoeFCRunnerInterface float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) @@ -474,6 +530,13 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface = tensorrt_llm::kernels::fp8_blockscale_gemm::CutlassFp8BlockScaleGemmRunnerInterface; using ScaleBiasType = BackBoneType; using Self = CutlassMoeFCRunner; + +#if defined(ENABLE_BF16) + static constexpr bool use_wfp4a16 + = std::is_same_v && (std::is_same_v || std::is_same_v); +#else + static constexpr bool use_wfp4a16 = std::is_same_v && std::is_same_v; +#endif #if defined(ENABLE_FP8) static constexpr bool use_fp8 = (std::is_same_v || std::is_same_v) &&!std::is_same_v; @@ -485,6 +548,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface static constexpr bool use_fp8 = false; static constexpr bool use_w4afp8 = false; #endif + static constexpr bool use_w4_groupwise = use_w4afp8 || use_wfp4a16; #if defined(ENABLE_FP4) static 
constexpr bool act_fp4 = std::is_same_v; static constexpr bool weight_fp4 = std::is_same_v; @@ -539,14 +603,14 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface return RunnerType::getConfigs(sm); } - void runMoe(void const* input_activations, void const* input_sf, int const* token_selected_experts, - float const* token_final_scales, void const* fc1_expert_weights, void const* fc1_expert_biases, - ActivationType fc1_activation_type, void const* fc2_expert_weights, void const* fc2_expert_biases, - QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts, int const experts_per_token, char* workspace_ptr, void* final_output, - int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, bool const enable_alltoall, - bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, bool min_latency_mode, - MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; + void runMoe(void const* input_activations, void const* input_sf, bool const swizzled_input_sf, + int const* token_selected_experts, float const* token_final_scales, void const* fc1_expert_weights, + void const* fc1_expert_biases, ActivationParams fc1_activation_type, void const* fc2_expert_weights, + void const* fc2_expert_biases, QuantParams quant_params, int64_t const num_rows, int64_t const hidden_size, + int64_t const inter_size, int const num_experts, int const experts_per_token, char* workspace_ptr, + void* final_output, int* unpermuted_row_to_permuted_row, MOEParallelismConfig parallelism_config, + bool const enable_alltoall, bool use_lora, LoraParams& lora_params, bool use_deepseek_fp8_block_scale, + bool min_latency_mode, MoeMinLatencyParams& min_latency_params, cudaStream_t stream) override; // We make these GEMM1 & GEMM2 static because they need to be stateless for the profiler to work static void gemm1(MoeGemmRunner& gemm_runner, @@ -563,7 +627,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids); @@ -591,7 +655,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface float const* const fc2_fp8_quant, TmaWarpSpecializedGroupedGemmInput::ElementSF const* fc1_fp4_act_flat, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc2_fp4_act_flat, QuantParams quant_params, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, - int const num_experts_per_node, ActivationType fc1_activation_type, float const** alpha_scale_ptr_array, + int const num_experts_per_node, ActivationParams fc1_activation_type, float const** alpha_scale_ptr_array, bool bias_is_broadcast, bool use_deepseek_fp8_block_scale, cudaStream_t stream, cutlass_extensions::CutlassGemmConfig config, bool min_latency_mode, int* num_active_experts_per, int* active_expert_global_ids) override @@ -679,7 +743,7 @@ class 
CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface private: std::pair setupTmaWarpSpecializedInputs( - int64_t num_rows, int64_t expanded_num_rows, ActivationType fc1_activation_type, int64_t hidden_size, + int64_t num_rows, int64_t expanded_num_rows, ActivationParams fc1_activation_type, int64_t hidden_size, int64_t inter_size, int64_t num_experts_per_node, void const* input_activations_void, TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void* final_output, WeightType const* fc1_expert_weights, WeightType const* fc2_expert_weights, QuantParams quant_params, @@ -727,7 +791,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface bool mayHaveFinalizeFused() const { return moe_gemm_runner_.supportsTmaWarpSpecialized() && moe_gemm_runner_.getSM() == 90 - && !use_deterministic_hopper_reduce_ && !use_w4afp8; + && !use_deterministic_hopper_reduce_ && !use_w4_groupwise; } // TODO: This should eventually take the quant params to give more flexibility @@ -758,7 +822,7 @@ class CutlassMoeFCRunner : public CutlassMoeFCRunnerInterface WeightType const* const fc1_expert_weights, ScaleBiasType const* const fc1_expert_biases, float const* const fc2_fp8_quant, int64_t const num_rows, int64_t const expanded_num_rows, int64_t const hidden_size, int64_t const inter_size, int const num_experts_per_node, - ActivationType fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); + ActivationParams fc1_activation_type, QuantParams& quant_params, cudaStream_t stream); static void BlockScaleFC2(DeepSeekBlockScaleGemmRunner& gemm_runner, T const* const input, void* const gemm_output, OutputType* const final_output, int64_t const* const expert_first_token_offset, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h index b1676993ded..0b86afda684 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/include/moe_util_kernels.h @@ -58,7 +58,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream); + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, cudaStream_t stream); template void finalizeMoeRoutingKernelLauncher(GemmOutputType const* expanded_permuted_rows, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl index a0ebfbde343..651b7f14060 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/launchers/moe_gemm_tma_ws_mixed_input_launcher.inl @@ -85,15 +85,14 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput::type; - using ElementA = cutlass::float_e4m3_t; + using ElementA = typename TllmToCutlassTypeAdapter::type; using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand 
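For the ActivationParams struct and the runMoe/gemm1 signature change from ActivationType to ActivationParams a few hunks above, here is a self-contained usage sketch. The enum and struct below are simplified stand-ins so the snippet compiles on its own (assert replaces TLLM_CHECK_WITH_INFO); the point is that plain gated activations still flow through the one-argument constructor and the temporary implicit conversion, while SwigluBias must carry its alpha/beta/limit pointers.

#include <cassert>
#include <cstdio>

// Simplified stand-ins mirroring the ActivationType / ActivationParams hunks above.
enum class ActivationType { Silu, Swiglu, Geglu, SwigluBias, Identity, InvalidType };

struct ActivationParams
{
    ActivationType activation_type;
    float const* swiglu_alpha = nullptr;
    float const* swiglu_beta = nullptr;
    float const* swiglu_limit = nullptr;

    explicit ActivationParams(ActivationType t) : activation_type(t)
    {
        // The real constructor enforces this with TLLM_CHECK_WITH_INFO.
        assert(t != ActivationType::SwigluBias && "SwigluBias needs alpha/beta/limit pointers");
    }

    ActivationParams(ActivationType t, float const* a, float const* b, float const* l)
        : activation_type(t), swiglu_alpha(a), swiglu_beta(b), swiglu_limit(l) {}

    // Temporary implicit conversion kept for callers that still take ActivationType.
    operator ActivationType() const { return activation_type; }
};

int main()
{
    // Plain SwiGLU: the one-argument form is enough.
    ActivationParams plain(ActivationType::Swiglu);

    // SwigluBias must carry its alpha/beta/limit pointers
    // (host placeholders here; real code would pass device buffers).
    float alpha = 1.0f, beta = 0.0f, limit = 7.0f;
    ActivationParams biased(ActivationType::SwigluBias, &alpha, &beta, &limit);

    // Legacy call sites keep compiling through the implicit conversion.
    ActivationType legacy = biased;
    std::printf("legacy type is SwigluBias: %d\n", legacy == ActivationType::SwigluBias);
    return 0;
}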
constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) // B matrix configuration - // using ElementB = typename TllmToCutlassTypeAdapter::type; - using ElementB = typename cutlass::int4b_t; + using ElementB_ = typename TllmToCutlassTypeAdapter::type; + using ElementB = std::conditional_t, cutlass::int4b_t, ElementB_>; using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of @@ -108,9 +107,13 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput>; // Scale configuration - constexpr int PackedScalesNum = get<2>(CTAShape{}) / 128; - using ElementScalePacked - = cutlass::Array; + constexpr bool use_wfp4a16 = std::is_same_v; + constexpr int group_size = use_wfp4a16 ? cutlass::gemm::collective::detail::mxfp4_group_size + : cutlass::gemm::collective::detail::int4_group_size; + constexpr int PackedScalesNum = get<2>(CTAShape{}) / group_size; + using ElementScale = std::conditional_t; + using ElementScalePacked = cutlass::Array; using LayoutScale = cutlass::layout::RowMajor; // C/D matrix configuration @@ -170,20 +173,21 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), hopper_inputs.default_epilogue.stride_d}, @@ -205,7 +209,7 @@ void sm90_generic_mixed_moe_gemm_kernelLauncher(GroupedGemmInput(hopper_inputs.ptr_b), hopper_inputs.stride_b, reinterpret_cast(hopper_inputs.ptr_a), hopper_inputs.stride_a, reinterpret_cast(hopper_inputs.int4_groupwise_params.ptr_s_a), - hopper_inputs.int4_groupwise_params.stride_s_a, int(inputs.groupwise_quant_group_size)}, + hopper_inputs.int4_groupwise_params.stride_s_a, group_size}, {fusion_args, reinterpret_cast(hopper_inputs.ptr_c), hopper_inputs.stride_c, reinterpret_cast(hopper_inputs.default_epilogue.ptr_d), hopper_inputs.default_epilogue.stride_d}, diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu new file mode 100644 index 00000000000..be29019bc6a --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_bf16_fp4.cu @@ -0,0 +1,24 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +#ifdef ENABLE_BF16 +template class MoeGemmRunner<__nv_bfloat16, __nv_fp4_e2m1, __nv_bfloat16>; +#endif +} // namespace tensorrt_llm::kernels::cutlass_kernels diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu new file mode 100644 index 00000000000..f1a885ea77d --- /dev/null +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_kernels_fp16_fp4.cu @@ -0,0 +1,22 @@ +/* + * Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "moe_gemm_template_dispatch.h" + +namespace tensorrt_llm::kernels::cutlass_kernels +{ +template class MoeGemmRunner; +} diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h index ff582ec6e68..56a8299f18f 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch.h @@ -99,6 +99,7 @@ struct genericMoeGemmKernelLauncher static_assert(cutlass::platform::is_same::value || cutlass::platform::is_same::value + || cutlass::platform::is_same::value || cutlass::platform::is_same::value); static_assert(arch::kMinComputeCapability < 90, "Sm90+ architecture should use specialized kernels"); @@ -503,7 +504,8 @@ MoeGemmRunner::getAmpereConfigs(int sm auto config_type_param = static_cast( weight_only_flag | simt_only_flag | grouped_gemm_flag | enable_hopper | fp8_only_flag); - if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89)) + if (!kernels::cutlass_kernels::isValidAmpereMOESpecialisation() || (use_w4afp8 && sm != 89) + || use_wfp4a16) { return {}; } @@ -580,18 +582,19 @@ int MoeGemmRunner::getSM() const // currently support sm80 bf16/fp16 gate activation, only set predication tensor for m direction template bool MoeGemmRunner::supportsFusedGatedActivation( - bool is_gated_activation, int gemm_n, int gemm_k) const + ActivationType activation_type, int gemm_n, int gemm_k) const { constexpr bool ENABLE_FUSED_GATED_ACTIVATION = true; - return is_gated_activation && std::is_same_v && !std::is_same_v && !use_fp8 - && (this->getSM() >= 80) && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; + return (activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu) + && std::is_same_v && !std::is_same_v && !use_fp8 && (this->getSM() >= 80) + && (gemm_k % 64 == 0) && (gemm_n % 64 == 0) && ENABLE_FUSED_GATED_ACTIVATION; } template bool MoeGemmRunner::isFusedGatedActivation( - cutlass_extensions::CutlassGemmConfig gemm_config, bool is_gated_activation, int gemm_n, int gemm_k) const + cutlass_extensions::CutlassGemmConfig gemm_config, ActivationType 
activation_type, int gemm_n, int gemm_k) const { - return supportsFusedGatedActivation(is_gated_activation, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; + return supportsFusedGatedActivation(activation_type, gemm_n, gemm_k) && !gemm_config.is_tma_warp_specialized; } template @@ -623,26 +626,41 @@ void MoeGemmRunner::dispatchToArch( if (sm_ >= 75 && sm_ < 80) { - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + if constexpr (!std::is_same_v) + { + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } + else + { + TLLM_THROW("FP4 data type is not supported on SM < 90"); + } } else if (sm_ >= 80 && sm_ < 90) { - if constexpr (use_fp8 || use_w4afp8) + + if constexpr (!std::is_same_v) { + if constexpr (use_fp8 || use_w4afp8) + { #if defined(ENABLE_FP8) - static_assert(!std::is_same_v && !std::is_same_v, - "FP8 GEMM Output not supported"); + static_assert(!std::is_same_v && !std::is_same_v, + "FP8 GEMM Output not supported"); #endif - TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + TLLM_CHECK_WITH_INFO(sm_ == 89, "For sm >= 80 and < 90, fp8 is only supported with sm == 89"); + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } + else + { + dispatchMoeGemmToCutlass( + inputs, multi_processor_count_); + } } else { - dispatchMoeGemmToCutlass( - inputs, multi_processor_count_); + TLLM_THROW("FP4 data type is not supported on SM < 90"); } } else if (sm_ >= 90) @@ -659,7 +677,7 @@ void MoeGemmRunner::dispatchToArch( } if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() - && !use_w4afp8) + && !use_w4_groupwise) { // We allow both tma warp specialized and SM80 configurations to coexist because for some cases with small // numbers of tokens SM80 is faster. 
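On the supportsFusedGatedActivation change above: the predicate now receives the full ActivationType and returns true only for plain Swiglu and Geglu, so SwigluBias (which needs the extra alpha/beta/limit handling) presumably always takes the unfused path. A standalone restatement of the shape of that check, with the dtype and FP8 conditions dropped for brevity (the function below is a simplified sketch, not the member function itself):

#include <cstdio>

enum class ActivationType { Silu, Swiglu, Geglu, SwigluBias, Identity, InvalidType };

// Simplified shape of supportsFusedGatedActivation(): only the plain gated
// activations may use the fused SM80 epilogue, and N/K must be multiples of 64.
// The real function also requires matching input/backbone types and no FP8.
bool supportsFusedGatedActivation(ActivationType activation_type, int gemm_n, int gemm_k, int sm)
{
    bool const plainGated
        = activation_type == ActivationType::Swiglu || activation_type == ActivationType::Geglu;
    return plainGated && sm >= 80 && gemm_n % 64 == 0 && gemm_k % 64 == 0;
}

int main()
{
    std::printf("Swiglu,     SM80: %d\n", supportsFusedGatedActivation(ActivationType::Swiglu, 4096, 14336, 80));
    std::printf("SwigluBias, SM80: %d\n", supportsFusedGatedActivation(ActivationType::SwigluBias, 4096, 14336, 80));
    return 0;
}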
We check here to see which is selected @@ -701,33 +719,39 @@ void MoeGemmRunner::dispatchToArch( // Hopper finegrained INT4 WS grouped GEMM if constexpr (use_w4afp8) { - if (inputs.gemm_config.is_tma_warp_specialized) + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "w4afp8 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + if (inputs.k % 512 == 0) { - // EpilogueTag is ignored - if (inputs.k % 512 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 256 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else if (inputs.k % 128 == 0) - { - sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( - inputs, hopper_inputs, multi_processor_count_, nullptr); - } - else - { - TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); - } - return; - }; + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 256 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else if (inputs.k % 128 == 0) + { + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + } + else + { + TLLM_THROW("Invalid GEMM K size %d", (int) inputs.k); + } + return; + } + + if constexpr (use_wfp4a16) + { + TLLM_CHECK_WITH_INFO( + inputs.gemm_config.is_tma_warp_specialized, "wfp4a16 is only supported for TMA warp specialization"); + // EpilogueTag is ignored + sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass(inputs, hopper_inputs, multi_processor_count_, nullptr); + return; } #endif @@ -779,7 +803,7 @@ size_t MoeGemmRunner::getMaxWorkspaceS template size_t MoeGemmRunner::calcMaxWorkspaceSize(int num_experts) const { - if constexpr (use_w4afp8) + if constexpr (use_w4_groupwise) { return calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput( num_experts, multi_processor_count_); @@ -788,7 +812,8 @@ size_t MoeGemmRunner::calcMaxWorkspace { return 0; } - if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8) + if constexpr (kernels::cutlass_kernels::isValidTmaWarpSpecializedMOESpecialisation() && !use_w4afp8 + && !use_wfp4a16) { auto configs = getTmaWarpSpecializedConfigs(sm_); auto fpX_block_scaling_type = TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE; diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h index 9a9f2ebeb38..affa4d8c409 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_gemm_template_dispatch_tma_ws_mixed_dtype.h @@ -153,10 +153,13 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( // We also only instantiate configs here where threadblockShapeM == warpShapeM since those usually perform the best // for mixed type gemms. - constexpr int Ktile = 128 * PackedScalesNum / sizeof(T); - TLLM_CHECK(sizeof(T) == 1); + constexpr int Ntile = (std::is_same_v) ? 64 : 128; + constexpr int Ktile = (std::is_same_v) ? 128 : 128 * PackedScalesNum / sizeof(T); + TLLM_CHECK(sizeof(T) == (std::is_same_v) ? 
2 : 1); + using _Ntile = Int; using _Ktile = Int; + switch (inputs.gemm_config.tile_config_sm90) { case tkc::CutlassTileConfigSM90::CtaShape64x16x128B: @@ -172,8 +175,8 @@ void sm90_dispatch_moe_mixed_dtype_gemm_to_cutlass( inputs, hopper_inputs, sm_count_, workspace_size); break; case tkc::CutlassTileConfigSM90::CtaShape64x128x128B: - sm90_dispatch_moe_mixed_dtype_gemm_config>( - inputs, hopper_inputs, sm_count_, workspace_size); + sm90_dispatch_moe_mixed_dtype_gemm_config>(inputs, hopper_inputs, sm_count_, workspace_size); break; // case tkc::CutlassTileConfigSM90::CtaShape64x256x128B: // sm90_dispatch_moe_mixed_dtype_gemm_config size_t calcMaxWorkspaceSizeTmaWarpSpecializedMixedInput(int num_experts, int sm_count_) { size_t count = 0; + constexpr int Ktile = (std::is_same_v) ? 256 : 512; + using _Ktile = Int; + #ifdef COMPILE_HOPPER_TMA_GROUPED_GEMMS GroupedGemmInput inputs{}; inputs.num_experts = num_experts; sm90_generic_mixed_moe_gemm_kernelLauncher, Shape<_1, _1, _1>, + tensorrt_llm::cutlass_extensions::EpilogueOpDefault, Shape<_128, _64, _Ktile>, Shape<_1, _1, _1>, cutlass::gemm::KernelTmaWarpSpecializedCooperative, cutlass::epilogue::TmaWarpSpecializedCooperative, cutlass::WeightOnlyQuantOp::FINEGRAINED_SCALE_ONLY>( inputs, TmaWarpSpecializedGroupedGemmInput{}, sm_count_, &count); diff --git a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu index 0caf687b569..ae4c25f379f 100644 --- a/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu +++ b/cpp/tensorrt_llm/kernels/cutlass_kernels/moe_gemm/moe_kernels.cu @@ -997,12 +997,12 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType scaling_type) { constexpr bool is_fp8 = std::is_same_v; - static constexpr int NumThreadsPerSF = VecSize / CVT_FP4_ELTS_PER_THREAD; + static constexpr int NumThreadsPerSF = VecSize / CVT_ELTS_PER_THREAD; // Quantize the input to FP4 static_assert(std::is_same_v || std::is_same_v); - static_assert(ComputeElem::kElements == CVT_FP4_ELTS_PER_THREAD); + static_assert(ComputeElem::kElements == CVT_ELTS_PER_THREAD); PackedVec packed_vec{}; - for (int i = 0; i < CVT_FP4_ELTS_PER_THREAD / 2; i++) + for (int i = 0; i < CVT_ELTS_PER_THREAD / 2; i++) { packed_vec.elts[i].x = static_cast(post_act_val[i * 2 + 0]); packed_vec.elts[i].y = static_cast(post_act_val[i * 2 + 1]); @@ -1013,10 +1013,9 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s = act_sf_flat + getOffsetActivationSF(expert_id, num_tokens_before_expert, num_cols, scaling_type); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); // Do the conversion and set the output and scaling factor auto func = [&]() @@ -1043,7 +1042,7 @@ __device__ auto quantizePackedFPXValue(ComputeElem& post_act_val, float global_s template __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int64_t source_token_id, int64_t token_id, int64_t elem_idx, int64_t 
num_cols, TmaWarpSpecializedGroupedGemmInput::ElementSF* act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf = true) { static constexpr int NumThreadsPerSF = VecSize / ElementsPerThread; @@ -1055,20 +1054,31 @@ __device__ void writeSF(int64_t num_tokens_before_expert, int64_t expert_id, int : TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::MXFPX); // Use `token - num_tokens_before_expert` because we want this to be relative to the start of this expert - auto sf_out - = cvt_quant_to_fp4_get_sf_out_offset( - std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, - num_cols, act_sf_expert, FP4QuantizationSFLayout::SWIZZLED); + auto sf_out = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, token_id - num_tokens_before_expert, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, act_sf_expert, QuantizationSFLayout::SWIZZLED); if (sf_out) { if (input_sf) { - auto const sf_in - = cvt_quant_to_fp4_get_sf_out_offset(std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, - num_cols, const_cast(input_sf), - FP4QuantizationSFLayout::SWIZZLED); - *sf_out = *sf_in; + if (swizzled_input_sf) + { + auto const sf_in + = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::SWIZZLED); + *sf_out = *sf_in; + } + else + { + auto const sf_in + = cvt_quant_get_sf_out_offset( + std::nullopt /* batchIdx */, source_token_id, elem_idx, std::nullopt /* numRows */, + num_cols / VecSize, const_cast(input_sf), + QuantizationSFLayout::LINEAR); + *sf_out = *sf_in; + } } else { @@ -1162,7 +1172,12 @@ __device__ void computeTmaWarpSpecializedInputStrides( { layout_info.int4_groupwise_params.stride_s_a[out_idx] = cutlass::make_cute_packed_stride(TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::StrideSFA{}, - cute::make_shape(gemm_n, gemm_k / 128, 1)); + cute::make_shape(gemm_n, + gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size), + 1)); } } @@ -1185,8 +1200,13 @@ __device__ void computeTmaWarpSpecializedInputPointers(TmaWarpSpecializedGrouped } if (layout_info.int4_groupwise_params.enabled) { - layout_info.int4_groupwise_params.ptr_s_a[out_idx] - = safe_inc_ptr(w4a8_weight_scale, expert * (gemm_n * gemm_k / 128)); + // The group size of wfp4a16 is multiplied by 2 because each scale uses 1 byte instead of 2 bytes + layout_info.int4_groupwise_params.ptr_s_a[out_idx] = safe_inc_ptr(w4a8_weight_scale, + expert + * (gemm_n * gemm_k + / (layout_info.int4_groupwise_params.use_wfp4a16 + ? 
TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::wfp4a16_group_size * 2 + : TmaWarpSpecializedGroupedGemmInput::INT4GroupwiseParams::int4_group_size))); } } @@ -1452,8 +1472,8 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp int const* permuted_row_to_unpermuted_row, int64_t const num_tokens, int64_t const hidden_size, int64_t const k, float const* fc1_act_global_scale, bool use_per_expert_act_scale, int64_t const* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, int64_t const num_experts_per_node, - InputActivationsType const* prequant_scales = nullptr) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + int64_t const num_experts_per_node, InputActivationsType const* prequant_scales = nullptr) { static_assert(BlockScalingType == TmaWarpSpecializedGroupedGemmInput::FpXBlockScalingType::NONE || !PRE_QUANT_AWQ, "AWQ and Block Scaling are mutually exclusive"); @@ -1487,7 +1507,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp : TmaWarpSpecializedGroupedGemmInput::MXFPXBlockScaleVectorSize; constexpr int64_t ELEM_PER_THREAD - = (is_nvfp4 || is_mxfp8) ? CVT_FP4_ELTS_PER_THREAD : (128 / sizeof_bits::value); + = (is_nvfp4 || is_mxfp8) ? CVT_ELTS_PER_THREAD : (128 / sizeof_bits::value); // This should be VecSize * 4 elements // We assume at least VecSize alignment or the quantization will fail @@ -1555,7 +1575,7 @@ __global__ void expandInputRowsKernel(InputActivationsType const* unpermuted_inp { assert(act_scale_idx == 0 && "Cannot use per-expert act scale for pre-quantized activations"); writeSF(num_tokens_before_expert, expert, source_row, permuted_row, - elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf); + elem_index, padded_hidden_size, fc1_act_sf_flat, input_sf, swizzled_input_sf); dest_row_ptr[elem_index] = in_vec; } } @@ -1656,7 +1676,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int const* permuted_row_to_unpermuted_row, int64_t const num_rows, int64_t const hidden_size, int const k, int const num_experts_per_node, QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, cudaStream_t stream) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, + void const* prequant_scales, cudaStream_t stream) { #ifdef ENABLE_FP4 TLLM_CHECK_WITH_INFO( @@ -1732,8 +1753,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, config.attrs = attrs; cudaLaunchKernelEx(&config, func, unpermuted_input, permuted_output, unpermuted_scales, permuted_scales, permuted_row_to_unpermuted_row, num_rows, hidden_size, k, quant_params.fp4.fc1.act_global_scale, - use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, num_experts_per_node, - reinterpret_cast(prequant_scales)); + use_per_expert_act_scale, expert_first_token_offset, fc1_act_sf_flat, input_sf, swizzled_input_sf, + num_experts_per_node, reinterpret_cast(prequant_scales)); } #define INSTANTIATE_EXPAND_INPUT_ROWS(InputActivationsType, ExpandedActivationsType) \ @@ -1743,8 +1764,8 @@ void expandInputRowsKernelLauncher(InputActivationsType const* unpermuted_input, int64_t const num_rows, int64_t 
const hidden_size, int const k, int const num_experts_per_node, \ QuantParams const& quant_params, bool use_per_expert_act_scale, int64_t* expert_first_token_offset, \ TmaWarpSpecializedGroupedGemmInput::ElementSF* fc1_act_sf_flat, \ - TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, void const* prequant_scales, \ - cudaStream_t stream) + TmaWarpSpecializedGroupedGemmInput::ElementSF const* input_sf, bool const swizzled_input_sf, \ + void const* prequant_scales, cudaStream_t stream) // Instantiate the data types that are used by the external pytorch op INSTANTIATE_EXPAND_INPUT_ROWS(float, float); @@ -2007,16 +2028,67 @@ INSTANTIATE_FINALIZE_MOE_ROUTING(float, float, float); INSTANTIATE_FINALIZE_MOE_ROUTING(__nv_bfloat16, __nv_bfloat16, __nv_bfloat16); #endif +// ============================== Activation Adaptors ================================= +template
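One detail worth spelling out from the computeTmaWarpSpecializedInputPointers hunk above: the per-expert weight-scale pointer is advanced by gemm_n * gemm_k / group_size elements, and for wfp4a16 the group size is multiplied by 2 because those scales occupy one byte each while the pointer's element type is two bytes wide, so half as many pointer elements cover the same number of scales. A small worked sketch of that offset arithmetic (expertScaleOffset and the constants are illustrative names; the shape in main is hypothetical):

#include <cstdint>
#include <cstdio>

// Group sizes from INT4GroupwiseParams (earlier hunk).
constexpr int64_t kInt4GroupSize = 128;   // scales stored as 2-byte elements
constexpr int64_t kWfp4A16GroupSize = 32; // scales stored as 1-byte elements

// Offset, in units of the 2-byte scale pointer type, of one expert's block of
// weight scales. Doubling the wfp4a16 group size compensates for its scales
// being half as wide as the pointer's element type.
int64_t expertScaleOffset(int64_t expert, int64_t gemm_n, int64_t gemm_k, bool use_wfp4a16)
{
    int64_t const divisor = use_wfp4a16 ? kWfp4A16GroupSize * 2 : kInt4GroupSize;
    return expert * (gemm_n * gemm_k / divisor);
}

int main()
{
    // Hypothetical shape: N = 4096, K = 14336.
    std::printf("int4,    expert 1: offset %lld elements\n", (long long) expertScaleOffset(1, 4096, 14336, false));
    std::printf("wfp4a16, expert 1: offset %lld elements\n", (long long) expertScaleOffset(1, 4096, 14336, true));
    return 0;
}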