Commit 6d4b045

refactor: Remove enforced sorted order of batch slots (#3502)

Signed-off-by: Robin Kobus <[email protected]>
1 parent f5f5be9 · commit 6d4b045
26 files changed: +103 −167 lines

cpp/include/tensorrt_llm/batch_manager/runtimeBuffers.h

Lines changed: 1 addition & 5 deletions

@@ -232,13 +232,9 @@ class RuntimeBuffers
 
     GenerationLogitsCache generationLogitsCache;
 
-    //! Helper for KV cache rewind
+    //! Mapping from batch idx to slot id
     TensorPtr seqSlots;
     TensorPtr seqSlotsDevice;
-    TensorPtr sortedSeqSlots;
-    //! For KV cache rewind
-    TensorPtr seqSlotRemappingHost;   // [numSequences]
-    TensorPtr seqSlotRemappingDevice; // [numSequences]
 
     //! Explicitly device-copy src offsets to reduce warp stalls in copy batch kernel invocation
     //! [mMaxNumRequests], on gpu
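
For orientation, here is a minimal, self-contained sketch of the mapping the new comment describes, with plain std::vector standing in for the TensorPtr buffers and made-up slot values: seqSlots[batchIdx] holds the sequence slot id of the request at batch position batchIdx, and after this commit the slot ids are no longer kept in sorted order.

```cpp
#include <cstdint>
#include <vector>

int main()
{
    // Hypothetical batch of three requests occupying slots 7, 2 and 5:
    // index = position in the current batch, value = sequence slot id.
    // The slot ids are in request order, not sorted.
    std::vector<int32_t> seqSlots{7, 2, 5};

    // Slot-major state (sized for the maximum number of slots) is addressed
    // through the mapping, independent of the order of the batch.
    std::vector<int32_t> tokensPerSlot(16, 0);
    for (std::size_t batchIdx = 0; batchIdx < seqSlots.size(); ++batchIdx)
    {
        tokensPerSlot[seqSlots[batchIdx]] += 1;
    }
    return 0;
}
```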

cpp/tensorrt_llm/batch_manager/assignReqSeqSlots.cpp

Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ void tensorrt_llm::batch_manager::AssignReqSeqSlots::operator()(SequenceSlotMan
             llmReq->setFirstScheduledTime();
         }
         auto const reqSeqSlot = seqSlotManager.getSequenceSlot(isReqNew, llmReq->mRequestId);
-        TLLM_CHECK_WITH_INFO(reqSeqSlot, "Unable to get batch slot for reqId");
+        TLLM_CHECK_WITH_INFO(reqSeqSlot, "Unable to get batch slot for request ID %lu", llmReq->mRequestId);
         llmReq->mSeqSlot = reqSeqSlot;
     }
 }
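
The only change here is that the failure message now carries the offending request id. The stand-in below is not the TensorRT-LLM macro, just an illustration of the printf-style pattern and why it helps when a check fires in a multi-request batch.

```cpp
#include <cstdio>
#include <cstdlib>

// Illustrative stand-in for a printf-style check macro; the real
// TLLM_CHECK_WITH_INFO lives in TensorRT-LLM and differs in detail.
#define CHECK_WITH_INFO(cond, fmt, ...)                                        \
    do                                                                         \
    {                                                                          \
        if (!(cond))                                                           \
        {                                                                      \
            std::fprintf(stderr, "Check failed: " fmt "\n", __VA_ARGS__);      \
            std::abort();                                                      \
        }                                                                      \
    } while (0)

int main()
{
    unsigned long const requestId = 42; // hypothetical request id
    bool const gotSlot = false;
    // The message now identifies which request could not get a slot.
    CHECK_WITH_INFO(gotSlot, "Unable to get batch slot for request ID %lu", requestId);
    return 0;
}
```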

cpp/tensorrt_llm/batch_manager/makeDecodingBatchInputOutput.cpp

Lines changed: 4 additions & 13 deletions

@@ -92,29 +92,20 @@ namespace
 std::pair<std::vector<SizeType32>, std::vector<SizeType32>> getActiveSlots(
     RequestVector const& contextRequests, RequestVector const& generationRequests)
 {
-    std::vector<std::pair<SizeType32, SizeType32>> slots;
+    std::vector<SizeType32> activeSlots;
+    std::vector<SizeType32> generationSteps;
     for (auto const& requests : {contextRequests, generationRequests})
     {
         for (auto const& llmReq : requests)
         {
             if (llmReq->isGenerationInProgressState() || llmReq->isLastContextChunk())
             {
-                slots.push_back({llmReq->mSeqSlot.value(), llmReq->getDecodingIter()});
+                activeSlots.push_back(llmReq->mSeqSlot.value());
+                generationSteps.push_back(llmReq->getDecodingIter());
             }
         }
     }
 
-    std::sort(slots.begin(), slots.end(),
-        [](std::pair<SizeType32, SizeType32> const& a, std::pair<SizeType32, SizeType32> const& b)
-        { return a.first < b.first; });
-
-    std::vector<SizeType32> activeSlots, generationSteps;
-    for (auto const& slot : slots)
-    {
-        activeSlots.push_back(slot.first);
-        generationSteps.push_back(slot.second);
-    }
-
     return {activeSlots, generationSteps};
 }
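
A standalone sketch of the simplified behavior, with a hypothetical FakeRequest type in place of the real request class and the state checks omitted: slots and generation steps are now emitted in request order, without the sort-by-slot pass this commit removes.

```cpp
#include <cstdint>
#include <iostream>
#include <utility>
#include <vector>

// Hypothetical stand-in for the real request type.
struct FakeRequest
{
    int32_t seqSlot;
    int32_t decodingIter;
};

// Collect active slots and generation steps in request order (no sorting).
std::pair<std::vector<int32_t>, std::vector<int32_t>> getActiveSlotsSketch(
    std::vector<FakeRequest> const& contextRequests, std::vector<FakeRequest> const& generationRequests)
{
    std::vector<int32_t> activeSlots;
    std::vector<int32_t> generationSteps;
    for (auto const* requests : {&contextRequests, &generationRequests})
    {
        for (auto const& req : *requests)
        {
            activeSlots.push_back(req.seqSlot);
            generationSteps.push_back(req.decodingIter);
        }
    }
    return {activeSlots, generationSteps};
}

int main()
{
    auto const [slots, steps] = getActiveSlotsSketch({{5, 0}, {2, 0}}, {{7, 3}});
    for (std::size_t i = 0; i < slots.size(); ++i)
    {
        std::cout << slots[i] << ' ' << steps[i] << '\n'; // 5 0, 2 0, 7 3: request order is preserved
    }
    return 0;
}
```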

cpp/tensorrt_llm/batch_manager/runtimeBuffers.cpp

Lines changed: 0 additions & 21 deletions

@@ -104,9 +104,6 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
         logits = manager.emptyTensor(MemoryType::kGPU, logitsType);
     }
 
-    seqSlotRemappingHost = manager.emptyTensor(MemoryType::kPINNEDPOOL, nvinfer1::DataType::kINT32);
-    seqSlotRemappingDevice = manager.emptyTensor(MemoryType::kGPU, nvinfer1::DataType::kINT32);
-
     // TODO: check which tensors can be allocated as pinned for max size
     requestTypes = manager.emptyTensor(MemoryType::kCPU, TRTDataType<runtime::RequestType>::value);
 
@@ -129,7 +126,6 @@ void RuntimeBuffers::create(SizeType32 maxBatchSize, SizeType32 maxBeamWidth,
     auto const maxBatchSizeShape = ITensor::makeShape({maxBatchSize});
     seqSlots = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT32);
     seqSlotsDevice = manager.gpu(maxBatchSizeShape, nvinfer1::DataType::kINT32);
-    sortedSeqSlots = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT32);
 
     cacheIndirDecoderIOBatchedCopySrcOffsets
         = tensorrt_llm::runtime::BufferManager::pinnedPool(maxBatchSizeShape, nvinfer1::DataType::kINT64);
@@ -383,9 +379,6 @@ void RuntimeBuffers::reshape(TllmRuntime const& runtime, ModelConfig const& mode
     auto const numRequestsShape = ITensor::makeShape({numRequests});
     seqSlots->reshape(numRequestsShape);
     seqSlotsDevice->reshape(numRequestsShape);
-    sortedSeqSlots->reshape(numRequestsShape);
-    seqSlotRemappingHost->reshape(numRequestsShape);
-    seqSlotRemappingDevice->reshape(numRequestsShape);
 
     auto const numTokens = getNumTokens();
     inputsIds->reshape(ITensor::makeShape({numTokens}));
@@ -740,20 +733,6 @@ void RuntimeBuffers::setFromInputs(RequestVector const& contextRequests, Request
             std::fill_n(sequenceLengthsHostPtr + numSequences, reqBeamWidth, sequenceLen);
             numSequences += reqBeamWidth;
         }
-        if (modelConfig.getSpeculativeDecodingMode().needsKVCacheRewind())
-        {
-            auto remappingSeqSlotIndices = BufferRange<SizeType32>(*seqSlotRemappingHost);
-            auto const* seqSlotIndices = bufferCast<SizeType32>(*seqSlots);
-
-            std::iota(remappingSeqSlotIndices.begin(), remappingSeqSlotIndices.end(), 0);
-            std::sort(remappingSeqSlotIndices.begin(), remappingSeqSlotIndices.end(),
-                [&seqSlotIndices](SizeType32 a, SizeType32 b) { return seqSlotIndices[a] < seqSlotIndices[b]; });
-            manager.copy(*seqSlotRemappingHost, *seqSlotRemappingDevice);
-
-            manager.copy(*seqSlots, *sortedSeqSlots);
-            auto sortedSeqSlotIndices = BufferRange<SizeType32>(*sortedSeqSlots);
-            std::sort(sortedSeqSlotIndices.begin(), sortedSeqSlotIndices.end());
-        }
        if (modelConfig.getSpeculativeDecodingMode().isLookaheadDecoding())
        {
            // copy from lookahead decoding buffer
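
The deleted block above built an argsort of seqSlots (a permutation mapping sorted-slot order back to batch positions) plus a sorted copy of the slots, purely so that downstream KV cache rewind kernels could assume sorted slots. The standalone sketch below, with plain std::vector in place of the pinned and device tensors, reproduces what that remapping computed; with the sorted-order requirement dropped, none of this bookkeeping is needed.

```cpp
#include <algorithm>
#include <numeric>
#include <vector>

int main()
{
    // Hypothetical slot ids in batch (request) order.
    std::vector<int> const seqSlots{7, 2, 5};

    // Argsort: remapping[i] is the batch position holding the i-th smallest slot id.
    std::vector<int> remapping(seqSlots.size());
    std::iota(remapping.begin(), remapping.end(), 0);
    std::sort(remapping.begin(), remapping.end(),
        [&seqSlots](int a, int b) { return seqSlots[a] < seqSlots[b]; });
    // remapping == {1, 2, 0}

    // Sorted copy of the slot ids.
    std::vector<int> sortedSeqSlots = seqSlots;
    std::sort(sortedSeqSlots.begin(), sortedSeqSlots.end());
    // sortedSeqSlots == {2, 5, 7}

    return 0;
}
```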

cpp/tensorrt_llm/batch_manager/sequenceSlotManager.cpp

Lines changed: 1 addition & 1 deletion

@@ -61,7 +61,7 @@ std::optional<SequenceSlotManager::SlotIdType> SequenceSlotManager::getSequenceS
     auto const it = mSequenceIdToSlot.find(sequenceId);
     if (it == mSequenceIdToSlot.end())
     {
-        TLLM_LOG_ERROR("Could not find sequence id in allocated sequence slots");
+        TLLM_LOG_ERROR("Could not find sequence id %lu in allocated sequence slots", sequenceId);
     }
     else
     {
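
As with the check above, this is purely a diagnostics change: the lookup itself is unchanged. A minimal sketch of the lookup pattern, with hypothetical names and key type rather than the repository's class, for context:

```cpp
#include <cstdint>
#include <cstdio>
#include <optional>
#include <unordered_map>

using SlotIdType = int32_t;

// Return the slot for a sequence id if it was allocated; otherwise log the id
// (the point of the change above) and return an empty optional.
std::optional<SlotIdType> findSlot(
    std::unordered_map<std::uint64_t, SlotIdType> const& sequenceIdToSlot, std::uint64_t sequenceId)
{
    auto const it = sequenceIdToSlot.find(sequenceId);
    if (it == sequenceIdToSlot.end())
    {
        std::fprintf(stderr, "Could not find sequence id %llu in allocated sequence slots\n",
            static_cast<unsigned long long>(sequenceId));
        return std::nullopt;
    }
    return it->second;
}

int main()
{
    std::unordered_map<std::uint64_t, SlotIdType> const slots{{1, 0}, {2, 1}};
    if (!findSlot(slots, 3))
    {
        // Slot not found: the log line above identifies id 3 directly.
    }
    return 0;
}
```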

cpp/tensorrt_llm/batch_manager/trtGptModelInflightBatching.cpp

Lines changed: 10 additions & 5 deletions

@@ -983,8 +983,10 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
     if (fittingRequests.empty() && fittingDisaggGenInitRequests.empty())
     {
         TLLM_LOG_WARNING(
-            "CapacityScheduler didn't schedule any requests, probably because of insufficient resources such as KV "
-            "cache, will try wait for KV cache transfer to complete");
+            "CapacityScheduler didn't schedule any requests in iteration %lu, "
+            "probably because of insufficient resources such as KV cache, "
+            "will try wait for KV cache transfer to complete",
+            mIterCounter);
         if (mCacheTransceiver)
         {
             mCacheTransceiver->checkContextTransferStatus(1);
@@ -1038,6 +1040,10 @@ void TrtGptModelInflightBatching::forwardAsync(RequestList const& activeRequests
         auto const contextBufferId = mCtxGenFusion ? getFusedBufferId() : getContextBufferId();
         setupDecoderStep(currRequests.contextRequests, *mBuffers.at(contextBufferId),
             mDecoderInputBuffers.at(getFusedBufferId()));
+        // WAR: Sync to ensure that the decoder setup is complete before the context phase starts.
+        // Without this, there may be a race condition between the decoder setup and the context phase
+        // which also leads to spurious test failure in trtGptModelRealDecoderTest.
+        mRuntime->getStream().synchronize();
     }
     else
     {
@@ -2432,9 +2438,8 @@ void TrtGptModelInflightBatching::rewindKVCacheBlocks(SizeType32 numSequences)
     tensorrt_llm::runtime::kernels::invokeUpdateKVBlockArrayDraftTokenLocation(
         *mDecoderState->getAcceptedLengthsCumSum(), *mDecoderState->getAcceptedPackedPaths(),
         *runtimeBuffers.sequenceLengthsDevice, pointerArrayPtr, offsetArrayPtr, localNbLayers, numSequences,
-        mRewindInputs.numKvHeads, sizeInBytesPerKVHead, commonRewindLen, rewindLens,
-        *runtimeBuffers.seqSlotRemappingDevice, *runtimeBuffers.sortedSeqSlots, getMaxAttentionWindow(),
-        mRewindInputs.maxBlocksPerSeq, tokensPerBlock, mRewindInputs.isUseOneMoreBlock,
+        mRewindInputs.numKvHeads, sizeInBytesPerKVHead, commonRewindLen, rewindLens, *runtimeBuffers.seqSlots,
+        getMaxAttentionWindow(), mRewindInputs.maxBlocksPerSeq, tokensPerBlock, mRewindInputs.isUseOneMoreBlock,
         mRuntime->getStreamPtr()->get());
 
     sync_check_cuda_error(mRuntime->getStream().get());
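
The WAR added in the second hunk is a plain host-side stream synchronization between the decoder setup and the work that depends on it. A generic illustration of the pattern using the CUDA runtime API (not the repository's runtime wrapper) is shown below.

```cpp
#include <cuda_runtime.h>

// Ensure that everything already enqueued on `stream` (e.g. decoder setup
// kernels and copies) has finished before the caller launches dependent work.
// This mirrors the intent of mRuntime->getStream().synchronize() above.
void synchronizeBeforeDependentWork(cudaStream_t stream)
{
    // Blocks the host until all work queued on `stream` has completed.
    cudaStreamSynchronize(stream);

    // ... it is now safe to enqueue the context phase ...
}
```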

cpp/tensorrt_llm/kernels/speculativeDecoding/common.cu

Lines changed: 6 additions & 22 deletions

@@ -40,8 +40,8 @@ namespace tensorrt_llm::kernels::speculative_decoding
 template <int32_t BLOCK_SIZE>
 __global__ void packAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32* pathsOffsets,
     SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds, SizeType32 const* paths,
-    SizeType32 const* batchSlots, runtime::SizeType32 const* seqSlots, SizeType32 batchSize, SizeType32 engineBatchSize,
-    SizeType32 numPaths, SizeType32 maxPathLen, bool isPathsSeqSlotIdx)
+    SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 engineBatchSize, SizeType32 numPaths,
+    SizeType32 maxPathLen, bool isPathsSeqSlotIdx)
 {
     // Specialize BlockScan for a 1D block of 128 threads of type int
     typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
@@ -81,22 +81,7 @@ __global__ void packAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32*
     }
     __syncthreads();
 
-    int32_t pathBatchIdx{batchSlot};
-    if (isPathsSeqSlotIdx)
-    {
-        // If paths tensor is the tensor arranged according to seq slot,
-        // we must find the position of the batchSlots index in the seq slot array.
-        // TODO optimize it.
-        for (int bi = 0; bi < batchSize; ++bi)
-        {
-            auto const seqSlot = seqSlots[bi];
-            if (batchSlot == seqSlot)
-            {
-                pathBatchIdx = bi;
-                break;
-            }
-        }
-    }
+    auto const pathBatchIdx = isPathsSeqSlotIdx ? bi : batchSlot;
 
     if (valid)
     {
@@ -117,13 +102,12 @@ __global__ void packAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32*
 
 void invokePackAcceptedPaths(SizeType32* acceptedLengthsCumSum, SizeType32* pathsOffsets,
     SizeType32 const* acceptedLengths, SizeType32 const* bestPathIds, SizeType32 const* paths,
-    SizeType32 const* batchSlots, runtime::SizeType32 const* seqSlots, SizeType32 batchSize, SizeType32 engineBatchSize,
-    SizeType32 numPaths, SizeType32 maxPathLen, bool isPathsLinearBatchIdx, cudaStream_t stream)
+    SizeType32 const* batchSlots, SizeType32 batchSize, SizeType32 engineBatchSize, SizeType32 numPaths,
+    SizeType32 maxPathLen, bool isPathsSeqSlotIdx, cudaStream_t stream)
 {
     constexpr SizeType32 BLOCK_SIZE = 1024;
     packAcceptedPaths<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(acceptedLengthsCumSum, pathsOffsets, acceptedLengths,
-        bestPathIds, paths, batchSlots, seqSlots, batchSize, engineBatchSize, numPaths, maxPathLen,
-        isPathsLinearBatchIdx);
+        bestPathIds, paths, batchSlots, batchSize, engineBatchSize, numPaths, maxPathLen, isPathsSeqSlotIdx);
 }
 
 namespace
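
The removed block did a linear search for the position of batchSlot inside seqSlots; with the enforced ordering gone, batchSlots and seqSlots are assumed to enumerate requests in the same order, so that position is simply the local index bi. A small host-side sketch (hypothetical values) of why the two formulations agree under that assumption:

```cpp
#include <cassert>
#include <vector>

int main()
{
    // Assumed: both arrays list the same slots in the same (request) order.
    std::vector<int> const seqSlots{7, 2, 5};
    std::vector<int> const batchSlots{7, 2, 5};

    for (int bi = 0; bi < static_cast<int>(batchSlots.size()); ++bi)
    {
        // Old formulation: search for batchSlots[bi] inside seqSlots.
        int searched = -1;
        for (int j = 0; j < static_cast<int>(seqSlots.size()); ++j)
        {
            if (seqSlots[j] == batchSlots[bi])
            {
                searched = j;
                break;
            }
        }
        // New formulation: use the local index directly.
        assert(searched == bi);
    }
    return 0;
}
```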

cpp/tensorrt_llm/kernels/speculativeDecoding/common.h

Lines changed: 3 additions & 7 deletions

@@ -38,10 +38,6 @@ namespace tensorrt_llm::kernels::speculative_decoding
 //! everything that is not path.
 //! \param batchSlots input buffer [engineBatchSize], address map from local index to
 //! global index [0, batchSize] -> [0, maxBatchSize].
-//! This is in the order of increasing order of the requests in the decoder.
-//! \param seqSlots input buffer [engineBatchSize], address map from local index to
-//! global index [0, batchSize] -> [0, maxBatchSize]
-//! These are the slots of the sequences in the runtime buffers.
 //! \param batchSize the number of sequences to be decoded
 //! \param engineBatchSize number of sequences processed in the engine.
 //! Includes chunked context reqs that are not in the last chunk.
@@ -52,9 +48,9 @@ namespace tensorrt_llm::kernels::speculative_decoding
 //! \param stream stream
 void invokePackAcceptedPaths(runtime::SizeType32* acceptedLengthsCumSum, runtime::SizeType32* pathsOffsets,
     runtime::SizeType32 const* acceptedLengths, runtime::SizeType32 const* bestPathIds,
-    runtime::SizeType32 const* paths, runtime::SizeType32 const* batchSlots, runtime::SizeType32 const* seqSlots,
-    runtime::SizeType32 batchSize, runtime::SizeType32 engineBatchSize, runtime::SizeType32 numPaths,
-    runtime::SizeType32 maxPathLen, bool isPathsSeqSlotIdx, cudaStream_t stream);
+    runtime::SizeType32 const* paths, runtime::SizeType32 const* batchSlots, runtime::SizeType32 batchSize,
+    runtime::SizeType32 engineBatchSize, runtime::SizeType32 numPaths, runtime::SizeType32 maxPathLen,
+    bool isPathsSeqSlotIdx, cudaStream_t stream);
 
 template <typename T>
 struct AcceptDraftTokensByIdsWithPathsParams

cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.cu

Lines changed: 11 additions & 30 deletions

@@ -1372,49 +1372,30 @@ void invokeGetPackedMaskFromPath(int32_t* specDecodingPackedMasks, SizeType32 co
 namespace
 {
 template <int BLOCK_SIZE>
-__global__ void augmentBatchSlotsKernel(SizeType32* augmentedSeqSlots, SizeType32* augmentedBatchSlots,
-    SizeType32 const* chunkedContextNextTokens, SizeType32 const* lastDraftLens, SizeType32 const* seqSlots,
-    SizeType32 const* batchSlots, SizeType32 actualBatchSize)
+__global__ void augmentBatchSlotsKernel(SizeType32* augmentedSeqSlots, SizeType32 const* chunkedContextNextTokens,
+    SizeType32 const* lastDraftLens, SizeType32 const* seqSlots, SizeType32 engineBatchSize)
 {
-    typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
-    __shared__ typename BlockScan::TempStorage tempStorage;
-
     auto const batchIdx = static_cast<SizeType32>(threadIdx.x);
-    auto const valid = batchIdx < actualBatchSize;
+    auto const valid = batchIdx < engineBatchSize;
 
-    bool needDecoding{false};
     if (valid)
     {
         auto const draftLen = lastDraftLens[batchIdx];
-        needDecoding = (draftLen == 0 && chunkedContextNextTokens[batchIdx] == -1) || (draftLen > 0);
-    }
-
-    SizeType32 originalIndex{0};
-    BlockScan(tempStorage).ExclusiveSum(needDecoding, originalIndex);
-
-    if (needDecoding)
-    {
-        augmentedSeqSlots[batchIdx] = seqSlots[batchIdx];
-        augmentedBatchSlots[batchIdx] = batchSlots[originalIndex];
-    }
-    else if (valid)
-    {
-        augmentedSeqSlots[batchIdx] = -1;
-        augmentedBatchSlots[batchIdx] = -1;
+        auto const needDecoding = (draftLen == 0 && chunkedContextNextTokens[batchIdx] == -1) || (draftLen > 0);
+        augmentedSeqSlots[batchIdx] = needDecoding ? seqSlots[batchIdx] : -1;
     }
 }
 } // namespace
 
-void invokeAugmentBatchSlots(SizeType32* augmentedSeqSlots, SizeType32* augmentedBatchSlots,
-    runtime::SizeType32 const* chunkedContextNextTokens, runtime::SizeType32 const* lastDraftLens,
-    SizeType32 const* seqSlots, SizeType32 const* batchSlots, SizeType32 actualBatchSize, SizeType32 batchSize,
-    cudaStream_t stream)
+void invokeAugmentBatchSlots(SizeType32* augmentedSeqSlots, runtime::SizeType32 const* chunkedContextNextTokens,
+    runtime::SizeType32 const* lastDraftLens, SizeType32 const* seqSlots, SizeType32 engineBatchSize,
+    SizeType32 batchSize, cudaStream_t stream)
 {
     SizeType32 constexpr BLOCK_SIZE = 512;
     TLLM_CHECK_WITH_INFO(
-        actualBatchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
-    augmentBatchSlotsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(augmentedSeqSlots, augmentedBatchSlots,
-        chunkedContextNextTokens, lastDraftLens, seqSlots, batchSlots, actualBatchSize);
+        engineBatchSize <= BLOCK_SIZE, "Batch size larger than %d is not supported for EAGLE yet", batchSize);
+    augmentBatchSlotsKernel<BLOCK_SIZE><<<1, BLOCK_SIZE, 0, stream>>>(
+        augmentedSeqSlots, chunkedContextNextTokens, lastDraftLens, seqSlots, engineBatchSize);
 }
 
 namespace
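
Since the block-wide scan and the second output array are gone, the kernel now performs a simple per-element selection. A host-side reference of what each thread writes (illustrative only, with made-up values in the usage example):

```cpp
#include <cstdint>
#include <vector>

// For each engine-batch entry, keep the sequence slot if the request needs
// decoding this iteration (last context chunk or a generation request with
// draft tokens), otherwise mark the entry with -1.
std::vector<int32_t> augmentSeqSlotsReference(std::vector<int32_t> const& seqSlots,
    std::vector<int32_t> const& chunkedContextNextTokens, std::vector<int32_t> const& lastDraftLens)
{
    std::vector<int32_t> augmented(seqSlots.size());
    for (std::size_t i = 0; i < seqSlots.size(); ++i)
    {
        auto const draftLen = lastDraftLens[i];
        bool const needDecoding = (draftLen == 0 && chunkedContextNextTokens[i] == -1) || (draftLen > 0);
        augmented[i] = needDecoding ? seqSlots[i] : -1;
    }
    return augmented;
}

int main()
{
    // Entry 1 is a non-last context chunk (next token != -1), so it becomes -1.
    auto const out = augmentSeqSlotsReference({7, 2, 5}, {-1, 9, -1}, {0, 0, 3});
    return out[1] == -1 ? 0 : 1;
}
```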

cpp/tensorrt_llm/kernels/speculativeDecoding/eagleDecodingKernels.h

Lines changed: 4 additions & 10 deletions

@@ -552,29 +552,23 @@ void invokeCopyOutputTokensIds(runtime::TokenIdType const* const* tmpOutputIdsPt
     runtime::SizeType32 const* inputPaths, runtime::SizeType32* outputPaths, runtime::SizeType32 maxPathLen,
     cudaStream_t stream);
 
-//! \brief Augment seq slots and batch slots from batchSize size to engineBatchSize size.
-//! For seqSlot sets -1 for non-last chunks (chunkedContextNextTokens != -1).
-//! For batchSlots sets -1 for non-last chunks. Copies actual batch slots to the last chunk or gen requests
-//! positions.
+//! \brief Augment seq slots so that non-last chunks are set to -1 (if chunkedContextNextTokens != -1).
 //!
 //! \param augmentedSeqSlots output buffer [engineBatchSize]
-//! \param augmentedBatchSlots output buffer [engineBatchSize]
 //! \param chunkedContextNextTokens input buffer [engineBatchSize], indicator of the not last chunk of the ctx
 //! requests. -1 for gen requests and last chunk, != -1 otherwise.
 //! \param lastDraftLens input buffer [engineBatchSize], number of draft tokens input to the current iteration.
 //! 0 for ctx requests and > 0 for gen requests.
 //! \param seqSlots input buffer [engineBatchSize], address map from local index to global index [0, batchSize]
 //! -> [0, maxBatchSize]
-//! \param batchSlots input buffer [engineBatchSize], address map from local index to global index [0, batchSize]
-//! -> [0, maxBatchSize]
 //! \param engineBatchSize number of sequences processed in the engine.
 //! Includes chunked context reqs that are not in the last chunk.
 //! \param batchSize the number of sequences to be decoded
 //! \param stream cuda stream.
-void invokeAugmentBatchSlots(runtime::SizeType32* augmentedSeqSlots, runtime::SizeType32* augmentedBatchSlots,
+void invokeAugmentBatchSlots(runtime::SizeType32* augmentedSeqSlots,
     runtime::SizeType32 const* chunkedContextNextTokens, runtime::SizeType32 const* lastDraftLens,
-    runtime::SizeType32 const* seqSlots, runtime::SizeType32 const* batchSlots, runtime::SizeType32 engineBatchSize,
-    runtime::SizeType32 batchSize, cudaStream_t stream);
+    runtime::SizeType32 const* seqSlots, runtime::SizeType32 engineBatchSize, runtime::SizeType32 batchSize,
+    cudaStream_t stream);
 
 //! \brief For Eagle-2, set topK tensor according to the max topK value for each request.
 //! And fill the batchSlots for the softMax kernel.
