@@ -1372,49 +1372,30 @@ void invokeGetPackedMaskFromPath(int32_t* specDecodingPackedMasks, SizeType32 co
13721372namespace
13731373{
13741374template <int BLOCK_SIZE>
1375- __global__ void augmentBatchSlotsKernel (SizeType32* augmentedSeqSlots, SizeType32* augmentedBatchSlots,
1376- SizeType32 const * chunkedContextNextTokens, SizeType32 const * lastDraftLens, SizeType32 const * seqSlots,
1377- SizeType32 const * batchSlots, SizeType32 actualBatchSize)
1375+ __global__ void augmentBatchSlotsKernel (SizeType32* augmentedSeqSlots, SizeType32 const * chunkedContextNextTokens,
1376+ SizeType32 const * lastDraftLens, SizeType32 const * seqSlots, SizeType32 engineBatchSize)
13781377{
1379- typedef cub::BlockScan<SizeType32, BLOCK_SIZE> BlockScan;
1380- __shared__ typename BlockScan::TempStorage tempStorage;
1381-
13821378 auto const batchIdx = static_cast <SizeType32>(threadIdx .x );
1383- auto const valid = batchIdx < actualBatchSize ;
1379+ auto const valid = batchIdx < engineBatchSize ;
13841380
1385- bool needDecoding{false };
13861381 if (valid)
13871382 {
13881383 auto const draftLen = lastDraftLens[batchIdx];
1389- needDecoding = (draftLen == 0 && chunkedContextNextTokens[batchIdx] == -1 ) || (draftLen > 0 );
1390- }
1391-
1392- SizeType32 originalIndex{0 };
1393- BlockScan (tempStorage).ExclusiveSum (needDecoding, originalIndex);
1394-
1395- if (needDecoding)
1396- {
1397- augmentedSeqSlots[batchIdx] = seqSlots[batchIdx];
1398- augmentedBatchSlots[batchIdx] = batchSlots[originalIndex];
1399- }
1400- else if (valid)
1401- {
1402- augmentedSeqSlots[batchIdx] = -1 ;
1403- augmentedBatchSlots[batchIdx] = -1 ;
1384+ auto const needDecoding = (draftLen == 0 && chunkedContextNextTokens[batchIdx] == -1 ) || (draftLen > 0 );
1385+ augmentedSeqSlots[batchIdx] = needDecoding ? seqSlots[batchIdx] : -1 ;
14041386 }
14051387}
14061388} // namespace
14071389
1408- void invokeAugmentBatchSlots (SizeType32* augmentedSeqSlots, SizeType32* augmentedBatchSlots,
1409- runtime::SizeType32 const * chunkedContextNextTokens, runtime::SizeType32 const * lastDraftLens,
1410- SizeType32 const * seqSlots, SizeType32 const * batchSlots, SizeType32 actualBatchSize, SizeType32 batchSize,
1411- cudaStream_t stream)
1390+ void invokeAugmentBatchSlots (SizeType32* augmentedSeqSlots, runtime::SizeType32 const * chunkedContextNextTokens,
1391+ runtime::SizeType32 const * lastDraftLens, SizeType32 const * seqSlots, SizeType32 engineBatchSize,
1392+ SizeType32 batchSize, cudaStream_t stream)
14121393{
14131394 SizeType32 constexpr BLOCK_SIZE = 512 ;
14141395 TLLM_CHECK_WITH_INFO (
1415- actualBatchSize <= BLOCK_SIZE, " Batch size larger than %d is not supported for EAGLE yet" , batchSize);
1416- augmentBatchSlotsKernel<BLOCK_SIZE><<<1 , BLOCK_SIZE, 0 , stream>>> (augmentedSeqSlots, augmentedBatchSlots,
1417- chunkedContextNextTokens, lastDraftLens, seqSlots, batchSlots, actualBatchSize );
1396+ engineBatchSize <= BLOCK_SIZE, " Batch size larger than %d is not supported for EAGLE yet" , batchSize);
1397+ augmentBatchSlotsKernel<BLOCK_SIZE><<<1 , BLOCK_SIZE, 0 , stream>>> (
1398+ augmentedSeqSlots, chunkedContextNextTokens, lastDraftLens, seqSlots, engineBatchSize );
14181399}
14191400
14201401namespace
0 commit comments