Skip to content
20 changes: 15 additions & 5 deletions cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include "tensorrt_llm/runtime/worldConfig.h"
#include <NvInferRuntime.h>

#include <array>
#include <cstdint>
#include <limits>
#include <list>
Expand Down Expand Up @@ -68,6 +69,9 @@ using VecUniqueTokens = tensorrt_llm::runtime::VecUniqueTokens;
using LoraTaskIdType = tensorrt_llm::runtime::LoraTaskIdType;
using BlocksPerWindow = std::map<SizeType32, std::tuple<SizeType32, SizeType32>>;

// Type alias for multimodal hash key (hash array + start offset)
using MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>;

template <typename T>
using OptionalRef = tensorrt_llm::common::OptionalRef<T>;

Expand Down Expand Up @@ -107,6 +111,10 @@ struct BlockKey
std::optional<LoraTaskIdType> loraTaskId = std::nullopt;
VecUniqueTokens uniqueTokens;

// Extra keys for multimodal data (similar to VLLM's approach)
// Each extra key is a pair of (mm_hash, start_offset_in_block)
std::vector<MmKey> extraKeys;

BlockKey() = default;

explicit BlockKey(VecTokens const& tokens, std::optional<LoraTaskIdType> loraTaskId = std::nullopt)
Expand All @@ -119,23 +127,25 @@ struct BlockKey
}
}

BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens)
: usesExtraIds(usesExtraIds)
explicit BlockKey(bool usesExtraIds, std::optional<LoraTaskIdType> loraTaskId, VecUniqueTokens uniqueTokens,
std::vector<MmKey> extraKeys = {})
: usesExtraIds{usesExtraIds}
, loraTaskId{loraTaskId}
, uniqueTokens{std::move(uniqueTokens)}
, extraKeys{std::move(extraKeys)}
{
}

bool operator==(BlockKey const& other) const noexcept
{
return (
usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId && uniqueTokens == other.uniqueTokens);
return (usesExtraIds == other.usesExtraIds && loraTaskId == other.loraTaskId
&& uniqueTokens == other.uniqueTokens && extraKeys == other.extraKeys);
}

int partialMatch(BlockKey const& other) const noexcept
{
SizeType32 numMatched{0};
if (loraTaskId == other.loraTaskId)
if (loraTaskId == other.loraTaskId && extraKeys == other.extraKeys)
{
auto [matchEnd, otherMatchEnd] = std::mismatch(
uniqueTokens.begin(), uniqueTokens.end(), other.uniqueTokens.begin(), other.uniqueTokens.end());
Expand Down
107 changes: 103 additions & 4 deletions cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -76,14 +76,82 @@ std::list<std::vector<T>> chopVectorIntoBlocks(
return blockedVectors;
}

inline uint8_t getNthByte(SizeType32 hashPart, uint8_t byteIdx) noexcept
{
return static_cast<uint8_t>((hashPart >> (24 - byteIdx * 8)) & 0xFF);
}

std::vector<MmKey> generateBlockHashExtraKeys(
tensorrt_llm::batch_manager::LlmRequest const& llmRequest, SizeType32 startTokenIdx, SizeType32 endTokenIdx)
{
auto const multimodalHashes = llmRequest.getMultimodalHashes();
auto const multimodalPositions = llmRequest.getMultimodalPositions();
auto const multimodalLengths = llmRequest.getMultimodalLengths();

if (!multimodalHashes || !multimodalPositions || !multimodalLengths || !(*multimodalHashes)
|| (*multimodalHashes)->empty() || !(*multimodalPositions) || (*multimodalPositions)->empty()
|| !(*multimodalLengths) || (*multimodalLengths)->empty())
{
return {};
}

if ((*multimodalHashes)->size() != (*multimodalPositions)->size()
|| (*multimodalPositions)->size() != (*multimodalLengths)->size())
{
TLLM_LOG_WARNING("Multimodal data arrays have mismatched sizes");
return {};
}

std::vector<MmKey> extraKeys; // MmKey = std::pair<std::array<uint8_t, 32>, SizeType32>
extraKeys.reserve((*multimodalPositions)->size());
std::array<uint8_t, 32> mmHashArray;

for (size_t i = 0; i < (*multimodalPositions)->size(); ++i)
{
auto const& startPos = (*(*multimodalPositions))[i];
auto const& length = (*(*multimodalLengths))[i];
auto const& mmHashVector = (*(*multimodalHashes))[i];

TLLM_CHECK_WITH_INFO(mmHashVector.size() == 8, "Multimodal hash vector has unexpected size: %zu (expected 8)",
mmHashVector.size());

// mmHashVector[j] comes from Python's int(hex_chunk, 16)
// where hex_chunk like "00010203" means 0x00 is MSB and 0x03 is LSB (big endian)
// Convert 8x 32-bit integers into a 32-byte array preserving Blake3 hash byte order
// Example: hashPart = 0x00010203 → mmHashArray[0:3] = [0x00, 0x01, 0x02, 0x03]
for (size_t j = 0; j < 8; ++j)
{
auto const& hashPart = mmHashVector[j];
for (uint8_t byteIdx = 0; byteIdx < 4; ++byteIdx)
{
mmHashArray[j * 4 + byteIdx] = getNthByte(hashPart, byteIdx);
}
}

// Check if this multimodal content overlaps with the current block
if (endTokenIdx > startPos && startTokenIdx < startPos + length)
{
SizeType32 mmStartInBlock = (startPos >= startTokenIdx) ? 0 : startTokenIdx - startPos;
extraKeys.emplace_back(mmHashArray, mmStartInBlock);
}
}

return extraKeys;
}

std::vector<BlockKey> buildBlockKeys(
std::list<VecUniqueTokens>& blockedUniqueTokens, tensorrt_llm::batch_manager::LlmRequest const& llmRequest)
{
std::vector<BlockKey> blockKeys;

SizeType32 currentTokenIdx = 0;
for (auto& uniqueTokens : blockedUniqueTokens)
{
blockKeys.emplace_back(
llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(), std::move(uniqueTokens));
auto extraKeys = generateBlockHashExtraKeys(llmRequest, currentTokenIdx, currentTokenIdx + uniqueTokens.size());
currentTokenIdx += uniqueTokens.size();

blockKeys.emplace_back(llmRequest.getInputTokensExtraIds().has_value(), llmRequest.getLoraTaskId(),
std::move(uniqueTokens), std::move(extraKeys));
}
return blockKeys;
}
Expand All @@ -92,9 +160,11 @@ std::vector<BlockKey> buildBlockKeys(

namespace tensorrt_llm::batch_manager::kv_cache_manager
{

size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) noexcept
{
// Hashing algorithm adapted from StackOverflow:
// https://stackoverflow.com/questions/664014/what-integer-hash-function-are-good-that-accepts-an-integer-hash-key
// Constants provide very good distribution - each input bit affects each output bit with ~50% probability.
size_t seed = blockKey.uniqueTokens.size() ^ parentHash * UINT64_C(0xbf58476d1ce4e5b9);

for (auto const& uniqueToken : blockKey.uniqueTokens)
Expand Down Expand Up @@ -122,7 +192,36 @@ size_t BlockKeyHasher::hash(BlockKey const& blockKey, std::size_t parentHash) no
c = c ^ (c >> 31);
seed ^= c + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
// TODO: support external hashes for multimodal

// Add extra keys for multimodal data mixing in external multimodal item hash and token offset within this sequence
// block
if (!blockKey.extraKeys.empty())
{
for (auto const& [mmHash, startOffset] : blockKey.extraKeys)
{
// Hash the multimodal hash array in 32-bit chunks (more efficient)
for (size_t i = 0; i < 32; i += 4)
{
// Combine 4 bytes into a 32-bit word (construct as little endian order)
uint32_t word = static_cast<uint32_t>(mmHash[i]) | (static_cast<uint32_t>(mmHash[i + 1]) << 8)
| (static_cast<uint32_t>(mmHash[i + 2]) << 16) | (static_cast<uint32_t>(mmHash[i + 3]) << 24);

// Mix the word into the seed
word = ((word >> 16) ^ word) * 0x45d9f3b;
word = ((word >> 16) ^ word) * 0x45d9f3b;
word = (word >> 16) ^ word;
seed ^= word + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}

// Hash the start offset
uint64_t e = static_cast<uint64_t>(startOffset);
e = (e ^ (e >> 30)) * UINT64_C(0xbf58476d1ce4e5b9);
e = (e ^ (e >> 27)) * UINT64_C(0x94d049bb133111eb);
e = e ^ (e >> 31);
seed ^= e + 0x9e3779b9 + (seed << 6) + (seed >> 2);
}
}

return seed;
}

Expand Down
176 changes: 176 additions & 0 deletions cpp/tests/unit_tests/batch_manager/kvCacheManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1034,6 +1034,182 @@ TEST_F(KVCacheManagerTest, BlockManagerReuseWithExtraIdTest)
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
}

TEST_F(KVCacheManagerTest, BlockManagerReuseWithMultimodalHashTest)
{
using VecTokenExtraIds = LlmRequest::VecTokenExtraIds;

auto constexpr numLayers = 12;
auto constexpr numKvHeads = 6;
auto constexpr sizePerHead = 16;
auto constexpr tokensPerBlock = 4;
auto constexpr maxBlocksPerSeq = 4;
auto constexpr blocksInPrimaryPool = 16;
auto constexpr blocksInSecondaryPool = 0;
auto constexpr maxNumSequences = 8;
auto const stream = std::make_shared<tr::CudaStream>();
auto constexpr onboardBlocks = true;
auto constexpr numReturnSequences = 1;
auto constexpr maxAttentionWindow = tokensPerBlock * maxBlocksPerSeq;
auto constexpr beamWidth = 1;

auto const blocksPerWindow = BlocksPerWindow{{maxAttentionWindow, {blocksInPrimaryPool, blocksInSecondaryPool}}};

BlockManager blockManager(std::vector(numLayers, numKvHeads), sizePerHead, tokensPerBlock, blocksPerWindow,
maxNumSequences, stream, maxAttentionWindow, beamWidth,
std::vector<BlockManager::SizeType32>{maxAttentionWindow}, std::nullopt, nvinfer1::DataType::kHALF, 0,
onboardBlocks);
blockManager.allocatePools(false);

EXPECT_EQ(blockManager.getTokensPerBlock(), tokensPerBlock);
EXPECT_EQ(blockManager.getMaxNumBlocks(), blocksInPrimaryPool);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);

SizeType32 constexpr maxNewTokens{0};
tr::SamplingConfig const samplingConfig{beamWidth};
bool constexpr isStreaming{false};

// Create multimodal hash data (256-bit hash = 8 int32 values)
auto multimodalHashes = std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
{0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666} // Hash 1
});
auto multimodalPositions
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2}); // Start at token 2
auto multimodalLengths = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4}); // Length 4 tokens
// assume prompt id starts from 100
auto inputTokens = std::make_shared<VecTokens>(VecTokens{100, 101, 102, 103, 104, 105, 0, 1, 2});
auto const inputLength = static_cast<SizeType32>(inputTokens->size());
LlmRequest::RequestIdType requestId{0};
auto llmRequest0 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);

GenerationRequest seq0{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};

///////////////////////////////////////////////////////////////////////////
// add request and then remove it
auto constexpr beamIdx = 0;
auto promptLen0 = llmRequest0->getNumTokens(beamIdx);
auto numContextBlocks0 = tc::ceilDiv(promptLen0, blockManager.getTokensPerBlock());
blockManager.addSequence(seq0, promptLen0, numContextBlocks0, *llmRequest0, maxAttentionWindow);
EXPECT_EQ(llmRequest0->getContextCurrentPosition(), 0);
EXPECT_THAT(seq0.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 2}));
llmRequest0->addNewToken(3, beamIdx);
llmRequest0->addNewToken(4, beamIdx);
auto numTokens = llmRequest0->getNumTokens(beamIdx);
auto numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
EXPECT_EQ(numBlocks, 3);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);

// Input: [100, 101, 102, 103, 104, 105, 0, 1, 2] (9 tokens)
// Multimodal: starts at token 2, length 4 → [102, 103, 104, 105]

// Block 0: [100, 101, 102, 103] ← Contains multimodal (102, 103)
// Block 1: [104, 105, 0, 1] ← Contains multimodal (104, 105)
// Block 2: [2, 3, 4] ← No multimodal
blockManager.releaseBlocks(seq0, llmRequest0);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);

///////////////////////////////////////////////////////////////////////////
// new request with same tokens and same multimodal hash - should reuse
requestId = 1;
auto llmRequest1 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes, multimodalPositions, multimodalLengths, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
GenerationRequest seq1{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};

// should reuse blocks 0, 1 and get new block 3
auto promptLen1 = llmRequest1->getNumTokens(beamIdx);
auto numContextBlocks1 = tc::ceilDiv(promptLen1, blockManager.getTokensPerBlock());
blockManager.addSequence(seq1, promptLen1, numContextBlocks1, *llmRequest1, maxAttentionWindow);
EXPECT_EQ(llmRequest1->getContextCurrentPosition(), 2 * tokensPerBlock);
EXPECT_THAT(seq1.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 1, 3}));
llmRequest1->addNewToken(3, beamIdx);
llmRequest1->addNewToken(4, beamIdx);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);
// block 3 matches block 2 and will be freed
blockManager.releaseBlocks(seq1, llmRequest1);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);

///////////////////////////////////////////////////////////////////////////
// Test Case 2: Different multimodal hash
requestId = 2;
auto multimodalHashes2
= std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
{0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2
});
auto multimodalPositions2
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2}); // Start at token 2
auto multimodalLengths2 = std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{4}); // Length 4 tokens
auto llmRequest2 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes2, multimodalPositions2, multimodalLengths2, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);

GenerationRequest seq2{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// no reuse, get new blocks 4, 5, 6
auto promptLen2 = llmRequest2->getNumTokens(beamIdx);
auto numContextBlocks2 = tc::ceilDiv(promptLen2, blockManager.getTokensPerBlock());
blockManager.addSequence(seq2, promptLen2, numContextBlocks2, *llmRequest2, maxAttentionWindow);
EXPECT_EQ(llmRequest2->getContextCurrentPosition(), 0);
EXPECT_THAT(seq2.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({4, 5, 6}));
llmRequest2->addNewToken(9, beamIdx);
numTokens = llmRequest2->getNumTokens(beamIdx);
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks);

///////////////////////////////////////////////////////////////////////////
// Test Case 3: Multiple multimodal hashes and partial reuse
requestId = 3;
auto multimodalHashes3
= std::make_shared<std::vector<std::vector<SizeType32>>>(std::vector<std::vector<SizeType32>>{
{0x12345678, -0x6F543211, 0x11111111, 0x22222222, 0x33333333, 0x44444444, 0x55555555, 0x66666666}, // Hash 1
{0x45678123, 0x23456789, 0x34567890, 0x12121212, 0x56565656, 0x78787878, 0x54545454, 0x67676767} // Hash 2
});
auto multimodalPositions3
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2, 4}); // Start at token 2 and 4
auto multimodalLengths3
= std::make_shared<std::vector<SizeType32>>(std::vector<SizeType32>{2, 2}); // Length 2 tokens

auto llmRequest3 = std::make_shared<LlmRequest>(requestId, maxNewTokens, inputTokens, samplingConfig, isStreaming,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt,
multimodalHashes3, multimodalPositions3, multimodalLengths3, std::nullopt, std::nullopt, std::nullopt,
std::nullopt, std::nullopt, std::nullopt, std::nullopt, std::nullopt, false, false, false, std::nullopt,
std::nullopt, false, std::nullopt, false, std::nullopt, false, std::nullopt, 0.5, std::nullopt, std::nullopt,
std::nullopt, LlmRequestType::LLMREQUEST_TYPE_CONTEXT_AND_GENERATION, std::nullopt, numReturnSequences);
GenerationRequest seq3{requestId, inputLength, beamWidth, blockManager.getWindowSizesMetadata()};
// reuse block 0, get new blocks 7, 8
auto promptLen3 = llmRequest3->getNumTokens(beamIdx);
auto numContextBlocks3 = tc::ceilDiv(promptLen3, blockManager.getTokensPerBlock());
blockManager.addSequence(seq3, promptLen3, numContextBlocks3, *llmRequest3, maxAttentionWindow);
EXPECT_EQ(llmRequest3->getContextCurrentPosition(),
tokensPerBlock); // only reuse block 0 [100, 101, 102, 103] with same hash/offset
EXPECT_THAT(seq3.getCacheBlockIds(maxAttentionWindow).at(beamIdx), ::testing::ElementsAreArray({0, 7, 8}));
llmRequest3->addNewToken(11, beamIdx);
numTokens = llmRequest3->getNumTokens(beamIdx);
numBlocks = tc::ceilDiv(numTokens, tokensPerBlock);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), numBlocks * 2);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool - numBlocks * 2);

// clean up
blockManager.releaseBlocks(seq2, llmRequest2);
blockManager.releaseBlocks(seq3, llmRequest3);
EXPECT_EQ(blockManager.getNumAllocatedBlocks(), 0);
EXPECT_EQ(blockManager.getNumFreeBlocks(), blocksInPrimaryPool);
}

TEST_F(KVCacheManagerTest, BlockManagerReuseWithLoraTaskIdTest)
{
// tc::Logger::getLogger()->setLevel(tc::Logger::Level::DEBUG);
Expand Down
Loading