
Commit c282bf2

Funatiq authored and dominicshanshan committed
refactor: remove batch_manager::KvCacheConfig and use executor::KvCacheConfig instead (NVIDIA#5384)
Signed-off-by: Robin Kobus <[email protected]>
1 parent cb904f4 commit c282bf2

File tree

19 files changed, +146 -231 lines changed

cpp/include/tensorrt_llm/batch_manager/kvCacheConfig.h

Lines changed: 0 additions & 107 deletions
This file was deleted.

cpp/include/tensorrt_llm/batch_manager/kvCacheManager.h

Lines changed: 5 additions & 4 deletions
@@ -16,10 +16,11 @@
 
 #pragma once
 
-#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
 #include "tensorrt_llm/batch_manager/kvCacheEventManager.h"
+#include "tensorrt_llm/batch_manager/kvCacheType.h"
 #include "tensorrt_llm/batch_manager/llmRequest.h" // TODO forward declare
 #include "tensorrt_llm/common/optionalRef.h"
+#include "tensorrt_llm/executor/executor.h"
 #include "tensorrt_llm/kernels/kvCacheIndex.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/common.h"
@@ -1309,7 +1310,7 @@ class BaseKVCacheManager
     /// @param config KV cache configuration parameters
     /// @return Tuple containing the {.freePrimaryMemBytes, .freeSecondaryMemBytes}
     [[nodiscard]] static std::tuple<uint64_t, uint64_t> calculateFreeMemBytes(
-        runtime::BufferManager const& bufferManager, KvCacheConfig const& config);
+        runtime::BufferManager const& bufferManager, executor::KvCacheConfig const& config);
 
     /// @brief Calculate the maximum number of KV cache blocks that can be allocated based on available GPU memory.
     /// @details This function computes how many blocks each WindowBlockManager should receive based on the weighted
@@ -1327,8 +1328,8 @@ class BaseKVCacheManager
     /// @param extraCostMemory Additional memory cost to account for CacheTransBufferManager::preAllocBufferSize
     /// @param kvFactor Factor for KV cache size calculation (typically 2 for key+value)
     /// @return Map from window size to tuple of (primary blocks, secondary blocks)
-    [[nodiscard]] static BlocksPerWindow calculateMaxNumBlocks(KvCacheConfig const& config, bool isCrossAttention,
-        nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
+    [[nodiscard]] static BlocksPerWindow calculateMaxNumBlocks(executor::KvCacheConfig const& config,
+        bool isCrossAttention, nvinfer1::DataType dtype, tensorrt_llm::runtime::ModelConfig const& modelConfig,
         tensorrt_llm::runtime::WorldConfig const& worldConfig,
         std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
         uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor);
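
For orientation, a minimal sketch of what a call site looks like after this change, assuming the headers above and a valid runtime::BufferManager; the helper name and config values are illustrative and not part of the commit:

#include <cstdint>
#include <tuple>

#include "tensorrt_llm/batch_manager/kvCacheManager.h"

namespace tle = tensorrt_llm::executor;
namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

// Hypothetical helper: derive the KV cache memory budget from the executor-facing config.
std::tuple<uint64_t, uint64_t> planKvCacheMemory(tensorrt_llm::runtime::BufferManager const& bufferManager)
{
    tle::KvCacheConfig config;  // executor::KvCacheConfig now replaces batch_manager's KvCacheConfig
    config.setUseUvm(false);    // knob introduced by this commit

    // The static helper accepts the executor config directly.
    return kvcm::BaseKVCacheManager::calculateFreeMemBytes(bufferManager, config);
}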

cpp/include/tensorrt_llm/batch_manager/kvCacheType.h

Lines changed: 29 additions & 0 deletions

@@ -0,0 +1,29 @@
+/*
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION. All rights reserved.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ *     http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#pragma once
+
+namespace tensorrt_llm::batch_manager::kv_cache_manager
+{
+
+enum class CacheType
+{
+    kSELF = 0,
+    kCROSS = 1,
+    kSELFKONLY = 2,
+};
+
+} // namespace tensorrt_llm::batch_manager::kv_cache_manager
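
A small sketch of how the relocated enum might be consumed; the predicate below is hypothetical and only illustrates the fully qualified name and the intent of the enumerators:

#include "tensorrt_llm/batch_manager/kvCacheType.h"

using tensorrt_llm::batch_manager::kv_cache_manager::CacheType;

// kSELFKONLY names a key-only self cache, so it is the one variant without values.
constexpr bool holdsValues(CacheType cacheType)
{
    return cacheType != CacheType::kSELFKONLY;
}

static_assert(holdsValues(CacheType::kSELF));
static_assert(holdsValues(CacheType::kCROSS));
static_assert(!holdsValues(CacheType::kSELFKONLY));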

cpp/include/tensorrt_llm/batch_manager/transformerBuffers.h

Lines changed: 1 addition & 2 deletions
@@ -17,7 +17,7 @@
 #pragma once
 
 #include "tensorrt_llm/batch_manager/common.h"
-#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
+#include "tensorrt_llm/batch_manager/kvCacheType.h"
 #include "tensorrt_llm/runtime/bufferManager.h"
 #include "tensorrt_llm/runtime/iTensor.h"
 #include "tensorrt_llm/runtime/modelConfig.h"
@@ -43,7 +43,6 @@ class TransformerBuffers
     using SizeType32 = runtime::SizeType32;
     using TensorPtr = runtime::ITensor::SharedPtr;
     using TensorMap = runtime::StringPtrMap<runtime::ITensor>;
-    using KvCacheType = batch_manager::kv_cache_manager::CacheType;
 
     static constexpr auto kCrossAttentionMaskTensorName = "cross_attention_mask";
     static constexpr auto kCrossAttentionPackedMaskTensorName = "cross_attention_packed_mask";

cpp/include/tensorrt_llm/executor/executor.h

Lines changed: 11 additions & 3 deletions
@@ -991,15 +991,17 @@ class SchedulerConfig
 class KvCacheConfig
 {
 public:
+    static constexpr auto kDefaultGpuMemFraction = 0.9F;
+
     explicit KvCacheConfig(bool enableBlockReuse = true, std::optional<SizeType32> const& maxTokens = std::nullopt,
         std::optional<std::vector<SizeType32>> const& maxAttentionWindowVec = std::nullopt,
         std::optional<SizeType32> const& sinkTokenLength = std::nullopt,
         std::optional<FloatType> const& freeGpuMemoryFraction = std::nullopt,
         std::optional<size_t> const& hostCacheSize = std::nullopt, bool onboardBlocks = true,
         std::optional<FloatType> const& crossKvCacheFraction = std::nullopt,
         std::optional<RetentionPriority> secondaryOffloadMinPriority = std::nullopt, size_t eventBufferMaxSize = 0,
-        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt,
-        bool enablePartialReuse = true, bool copyOnPartialReuse = true);
+        bool enablePartialReuse = true, bool copyOnPartialReuse = true, bool useUvm = false,
+        std::optional<tensorrt_llm::runtime::RuntimeDefaults> const& runtimeDefaults = std::nullopt);
 
     [[nodiscard]] bool getEnableBlockReuse() const;
     [[nodiscard]] bool getEnablePartialReuse() const;
@@ -1013,6 +1015,7 @@ class KvCacheConfig
     [[nodiscard]] bool getOnboardBlocks() const;
     [[nodiscard]] std::optional<RetentionPriority> getSecondaryOffloadMinPriority() const;
     [[nodiscard]] size_t getEventBufferMaxSize() const;
+    [[nodiscard]] bool getUseUvm() const;
 
     void setEnableBlockReuse(bool enableBlockReuse);
     void setEnablePartialReuse(bool enablePartialReuse);
@@ -1026,7 +1029,9 @@ class KvCacheConfig
     void setOnboardBlocks(bool onboardBlocks);
     void setSecondaryOffloadMinPriority(std::optional<RetentionPriority> secondaryOffloadMinPriority);
     void setEventBufferMaxSize(size_t eventBufferMaxSize);
-    void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults runtimeDefaults);
+    void setUseUvm(bool useUvm);
+
+    void fillEmptyFieldsFromRuntimeDefaults(tensorrt_llm::runtime::RuntimeDefaults const& runtimeDefaults);
 
 private:
     friend class Serialization;
@@ -1077,6 +1082,9 @@ class KvCacheConfig
 
     /// @brief Whether partially matched blocks that are in use can be reused after copying them
     bool mCopyOnPartialReuse;
+
+    /// @brief Whether to use UVM for the KV cache.
+    bool mUseUvm;
 };
 
 /// @brief Configuration class for the runtime perf knobs
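
A rough usage sketch of the updated interface. Note that useUvm now sits before runtimeDefaults in the constructor, so callers that passed the old trailing arguments positionally need updating; everything below is illustrative, not from the commit:

#include "tensorrt_llm/executor/executor.h"

namespace tle = tensorrt_llm::executor;

tle::KvCacheConfig makeUvmKvCacheConfig()
{
    tle::KvCacheConfig config;
    config.setUseUvm(true); // new setter/getter pair added in this commit

    // kDefaultGpuMemFraction (0.9F) is now a public constant on the executor config,
    // taking over the role of the constant on the removed batch_manager::KvCacheConfig.
    auto const fraction
        = config.getFreeGpuMemoryFraction().value_or(tle::KvCacheConfig::kDefaultGpuMemFraction);
    (void) fraction;

    return config;
}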

cpp/tensorrt_llm/batch_manager/kvCacheManager.cpp

Lines changed: 13 additions & 12 deletions
@@ -2057,32 +2057,33 @@ std::map<SizeType32, std::vector<SizeType32>> BaseKVCacheManager::groupLayersByW
 }
 
 std::tuple<uint64_t, uint64_t> BaseKVCacheManager::calculateFreeMemBytes(
-    runtime::BufferManager const& bufferManager, KvCacheConfig const& config)
+    runtime::BufferManager const& bufferManager, executor::KvCacheConfig const& config)
 {
-    auto const freeMemFraction = config.freeGpuMemoryFraction.value_or(KvCacheConfig::kDefaultGpuMemFraction);
+    auto const freeMemFraction
+        = config.getFreeGpuMemoryFraction().value_or(executor::KvCacheConfig::kDefaultGpuMemFraction);
     TLLM_CHECK_WITH_INFO(freeMemFraction < 1.0F,
         "Invalid freeMemFraction, freeMemFraction (%f) must be smaller than 1.0f", freeMemFraction);
-    if (config.maxTokens.has_value())
+    if (config.getMaxTokens().has_value())
     {
-        if (config.freeGpuMemoryFraction.has_value())
+        if (config.getFreeGpuMemoryFraction().has_value())
         {
             TLLM_LOG_WARNING(
                 "Both freeGpuMemoryFraction (aka kv_cache_free_gpu_mem_fraction) "
                 "and maxTokens (aka max_tokens_in_paged_kv_cache) "
                 "are set (to %f and %ld, respectively). The smaller value will be used.",
-                freeMemFraction, (int64_t) config.maxTokens.value());
+                freeMemFraction, (int64_t) config.getMaxTokens().value());
         }
     }
 
     TLLM_CUDA_CHECK(::cudaDeviceSynchronize());
-    auto const [freeMem, totalMem] = tc::getDeviceMemoryInfo(config.useUvm);
+    auto const [freeMem, totalMem] = tc::getDeviceMemoryInfo(config.getUseUvm());
     auto const finalFreeMem = freeMem + bufferManager.memoryPoolFree();
     TLLM_LOG_INFO("Memory usage when calculating max tokens in paged kv cache: total: %0.2f GiB, available: %0.2f GiB",
         totalMem / static_cast<double>(1 << 30), finalFreeMem / static_cast<double>(1 << 30));
     TLLM_CHECK_WITH_INFO(finalFreeMem <= totalMem, "Free memory cannot exceed total memory");
 
     auto const freePrimaryMemBytes = static_cast<uint64_t>(finalFreeMem * freeMemFraction);
-    auto const freeSecondaryMemBytes = config.hostCacheSize.value_or(0);
+    auto const freeSecondaryMemBytes = config.getHostCacheSize().value_or(0);
 
     TLLM_LOG_DEBUG("Calculated free memory: {.freePrimaryMemBytes=%" PRIu64 ", .freeSecondaryMemBytes=%" PRIu64 "}",
         freePrimaryMemBytes, freeSecondaryMemBytes);
@@ -2120,7 +2121,7 @@ bool isSortedVectorIdenticalAcrossAllRanks(WorldConfig const& worldConfig, std::
 }
 } // namespace
 
-BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(KvCacheConfig const& config, bool isCrossAttention,
+BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(executor::KvCacheConfig const& config, bool isCrossAttention,
     nvinfer1::DataType dtype, ModelConfig const& modelConfig, WorldConfig const& worldConfig,
     std::map<SizeType32, std::vector<SizeType32>> const& windowSizeToLayers, uint64_t allottedPrimaryMemBytes,
     uint64_t allottedSecondaryMemBytes, size_t extraCostMemory, SizeType32 kvFactor)
@@ -2130,7 +2131,7 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(KvCacheConfig const& c
         isCrossAttention ? "Cross KvCacheManager" : "Self KvCacheManager", allottedPrimaryMemBytes,
         allottedSecondaryMemBytes);
 
-    if (config.maxTokens.has_value() && windowSizeToLayers.size() > 1)
+    if (config.getMaxTokens().has_value() && windowSizeToLayers.size() > 1)
     {
         TLLM_LOG_WARNING(
             "Setting maxTokens when using Variable Sliding Window Attention is a strange concept, as it limits "
@@ -2162,9 +2163,9 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(KvCacheConfig const& c
         TLLM_LOG_DEBUG("windowSizeShare: %f, cacheSizeBytesPerToken: %d", windowSizeShare, cacheSizeBytesPerToken);
         auto maxTokens = static_cast<uint64_t>(
             allottedPrimaryMemBytes * windowSizeShare / static_cast<double>(cacheSizeBytesPerToken));
-        if (config.maxTokens.has_value())
+        if (config.getMaxTokens().has_value())
        {
-            auto const maxTokensFromConfig = static_cast<uint64_t>(config.maxTokens.value());
+            auto const maxTokensFromConfig = static_cast<uint64_t>(config.getMaxTokens().value());
             TLLM_LOG_DEBUG("Maximum kv-cache token overridden by configuration as '%ld'.", maxTokensFromConfig);
             maxTokens = std::min(maxTokensFromConfig, maxTokens);
         }
@@ -2184,7 +2185,7 @@ BlocksPerWindow BaseKVCacheManager::calculateMaxNumBlocks(KvCacheConfig const& c
         TLLM_LOG_DEBUG(
             "Number of blocks in KV cache secondary pool for windowSize %d: %d, onboard blocks to primary memory "
             "before reuse: %s",
-            windowSize, blocksInSecondaryPool, config.onboardBlocks ? "true" : "false");
+            windowSize, blocksInSecondaryPool, config.getOnboardBlocks() ? "true" : "false");
         return blocksInSecondaryPool;
     };
 
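
The arithmetic is unchanged; only the accessors move to the executor config's getters. A standalone back-of-the-envelope version of the memory split, with made-up names and without the logging and CUDA queries:

#include <cstdint>
#include <optional>
#include <tuple>

std::tuple<std::uint64_t, std::uint64_t> splitFreeMem(std::uint64_t freeDeviceBytes,
    std::optional<double> freeGpuMemoryFraction, std::optional<std::uint64_t> hostCacheSize)
{
    // Mirrors executor::KvCacheConfig::kDefaultGpuMemFraction (0.9F in the header change above).
    constexpr double kDefaultGpuMemFraction = 0.9;
    auto const fraction = freeGpuMemoryFraction.value_or(kDefaultGpuMemFraction);

    // Primary pool: a fraction of free device memory; secondary pool: the optional host cache size.
    auto const freePrimaryMemBytes = static_cast<std::uint64_t>(freeDeviceBytes * fraction);
    auto const freeSecondaryMemBytes = hostCacheSize.value_or(0);
    return {freePrimaryMemBytes, freeSecondaryMemBytes};
}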

cpp/tensorrt_llm/batch_manager/transformerBuffers.cpp

Lines changed: 2 additions & 3 deletions
@@ -17,7 +17,6 @@
 
 #include "tensorrt_llm/batch_manager/transformerBuffers.h"
 
-#include "tensorrt_llm/batch_manager/kvCacheConfig.h"
 #include "tensorrt_llm/batch_manager/kvCacheManager.h"
 #include "tensorrt_llm/common/assert.h"
 #include "tensorrt_llm/common/logger.h"
@@ -221,7 +220,7 @@ void TransformerBuffers::reshapeKvTensors(SizeType32 maxBatchSize, SizeType32 ma
     TLLM_LOG_TRACE("%s start", __PRETTY_FUNCTION__);
 
     // allocate with max shape during init
-    if (kvCacheType == KvCacheType::kSELF)
+    if (kvCacheType == kv_cache_manager::CacheType::kSELF)
     {
         auto const cacheBlockOffsetsShape
             = ITensor::makeShape({numPools, maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq});
@@ -232,7 +231,7 @@ void TransformerBuffers::reshapeKvTensors(SizeType32 maxBatchSize, SizeType32 ma
         kvCacheBlockOffsetsDevice->reshape(cacheBlockOffsetsShape);
         manager.setZero(*kvCacheBlockOffsetsDevice);
     }
-    else if (kvCacheType == KvCacheType::kCROSS)
+    else if (kvCacheType == kv_cache_manager::CacheType::kCROSS)
     {
         auto const crossCacheBlockOffsetsShape
             = ITensor::makeShape({numPools, maxBatchSize * maxBeamWidth, 2, maxBlocksPerSeq});
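
With the local KvCacheType alias removed from TransformerBuffers, comparisons spell the enum out through its namespace. A minimal illustration with a hypothetical helper:

#include "tensorrt_llm/batch_manager/kvCacheType.h"

namespace kvcm = tensorrt_llm::batch_manager::kv_cache_manager;

char const* cacheTypeName(kvcm::CacheType kvCacheType)
{
    if (kvCacheType == kvcm::CacheType::kSELF)
    {
        return "self";
    }
    if (kvCacheType == kvcm::CacheType::kCROSS)
    {
        return "cross";
    }
    return "self, keys only"; // kSELFKONLY
}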
