dynamic decoding + separate scheduler
mzegla committed Oct 30, 2024
1 parent 5ff955e commit 6253f1a
Showing 6 changed files with 88 additions and 31 deletions.
41 changes: 31 additions & 10 deletions src/llm/apis/openai_completions.cpp
@@ -121,7 +121,7 @@ absl::Status OpenAIChatCompletionsHandler::parseChatCompletionsPart() {
return absl::OkStatus();
}

absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline) {
OVMS_PROFILE_FUNCTION();
// stream: bool; optional
if (!doc.IsObject())
@@ -349,14 +349,35 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLim
}

// Speculative decoding specific parameters

// num_assistant_tokens: uint; optional - defaults to 0
it = doc.FindMember("num_assistant_tokens");
if (it != doc.MemberEnd()) {
if (!it->value.IsUint()) {
return absl::InvalidArgumentError("num_assistant_tokens must be an unsigned integer");

auto numAssistantTokensIt = doc.FindMember("num_assistant_tokens");
auto assistantConfidenceThresholdIt = doc.FindMember("assistant_confidence_threshold");

if (isSpeculativePipeline) {
if (numAssistantTokensIt == doc.MemberEnd() && assistantConfidenceThresholdIt == doc.MemberEnd())
return absl::InvalidArgumentError("Speculative decoding requires either num_assistant_tokens or assistant_confidence_threshold to be set.");

if (numAssistantTokensIt != doc.MemberEnd() && assistantConfidenceThresholdIt != doc.MemberEnd())
return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are mutually exclusive and cannot both be set.");
} else if (numAssistantTokensIt != doc.MemberEnd() || assistantConfidenceThresholdIt != doc.MemberEnd()) {
return absl::InvalidArgumentError("num_assistant_tokens and assistant_confidence_threshold are only supported when speculative decoding is enabled.");
}
// num_assistant_tokens: uint;
if (numAssistantTokensIt != doc.MemberEnd()) {
if (!numAssistantTokensIt->value.IsUint() || numAssistantTokensIt->value.GetUint() == 0) {
return absl::InvalidArgumentError("num_assistant_tokens must be an unsigned integer greater than 0");
}
request.numAssistantTokens = numAssistantTokensIt->value.GetUint();
}
// assistant_confidence_threshold: float;
if (assistantConfidenceThresholdIt != doc.MemberEnd()) {
if (!assistantConfidenceThresholdIt->value.IsDouble() && !assistantConfidenceThresholdIt->value.IsInt()) {
return absl::InvalidArgumentError("assistant_confidence_threshold must be a positive number");
}
request.assistantConfidenceThreshold = assistantConfidenceThresholdIt->value.GetDouble();
if (request.assistantConfidenceThreshold <= 0.0) {
return absl::InvalidArgumentError("assistant_confidence_threshold must be greater than 0");
}
request.numAssistantTokens = it->value.GetUint();
}

// use_beam_search: bool; optional - defaults to false
@@ -401,8 +422,8 @@ ov::genai::GenerationConfig OpenAIChatCompletionsHandler::createGenerationConfig
return request.createGenerationConfig();
}

absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit) {
absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit);
absl::Status OpenAIChatCompletionsHandler::parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline) {
absl::Status status = parseCommonPart(maxTokensLimit, bestOfLimit, isSpeculativePipeline);

if (status != absl::OkStatus())
return status;
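For reference, the request-side contract introduced above is: with a speculative decoding pipeline, exactly one of num_assistant_tokens (a positive unsigned integer) or assistant_confidence_threshold (a positive number) must be present, and with a regular pipeline neither may appear. The following standalone sketch reproduces those checks with RapidJSON against a hypothetical chat-completions body (model name and values are made up); it illustrates the rules and is not OVMS code.

#include <iostream>
#include <string>

#include <rapidjson/document.h>

// Standalone illustration of the validation rules above; not OVMS code.
static std::string validateSpeculativeParams(const rapidjson::Document& doc, bool isSpeculativePipeline) {
    auto numIt = doc.FindMember("num_assistant_tokens");
    auto thrIt = doc.FindMember("assistant_confidence_threshold");

    if (isSpeculativePipeline) {
        if (numIt == doc.MemberEnd() && thrIt == doc.MemberEnd())
            return "speculative decoding requires num_assistant_tokens or assistant_confidence_threshold";
        if (numIt != doc.MemberEnd() && thrIt != doc.MemberEnd())
            return "num_assistant_tokens and assistant_confidence_threshold are mutually exclusive";
    } else if (numIt != doc.MemberEnd() || thrIt != doc.MemberEnd()) {
        return "these parameters are only supported by a speculative decoding pipeline";
    }
    if (numIt != doc.MemberEnd() && (!numIt->value.IsUint() || numIt->value.GetUint() == 0))
        return "num_assistant_tokens must be an unsigned integer greater than 0";
    if (thrIt != doc.MemberEnd()) {
        if (!thrIt->value.IsDouble() && !thrIt->value.IsInt())
            return "assistant_confidence_threshold must be a positive number";
        if (thrIt->value.GetDouble() <= 0.0)
            return "assistant_confidence_threshold must be greater than 0";
    }
    return "";  // empty string == request accepted
}

int main() {
    // Hypothetical request body using the fixed draft-length mode.
    const char* body = R"({"model": "some-model", "messages": [{"role": "user", "content": "hi"}], "num_assistant_tokens": 5})";
    rapidjson::Document doc;
    doc.Parse(body);
    if (doc.HasParseError())
        return 1;
    std::string error = validateSpeculativeParams(doc, /*isSpeculativePipeline=*/true);
    std::cout << (error.empty() ? "request accepted" : error) << std::endl;
    return 0;
}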
38 changes: 23 additions & 15 deletions src/llm/apis/openai_completions.hpp
@@ -60,30 +60,36 @@ struct CompletionUsageStatistics {

// Class that maps OpenAI request content and provides methods to create GenerationConfig from it.
struct OpenAIChatCompletionsRequest {
// Generic
chat_t messages;
std::optional<std::string> prompt{std::nullopt};
bool stream{false};
StreamOptions streamOptions;
std::string model;
std::optional<int> maxTokens{std::nullopt};
std::optional<int> numAssistantTokens{std::nullopt};
std::optional<float> frequencyPenalty{std::nullopt};
std::optional<float> presencePenalty{std::nullopt};
std::optional<float> diversityPenalty{std::nullopt};
std::optional<float> repetitionPenalty{std::nullopt};
std::optional<float> lengthPenalty{std::nullopt};
std::optional<int> numReturnSequences{std::nullopt};
bool logprobs = 0;
int logprobschat = false;
bool echo{false};
std::optional<bool> ignoreEOS{std::nullopt};
std::optional<std::set<std::string>> stop{std::nullopt};
std::optional<bool> includeStopStrInOutput{std::nullopt};
std::optional<int> numReturnSequences{std::nullopt}; // effective for beam search and multinomial decoding
// Multinomial decoding specific
std::optional<float> temperature{std::nullopt};
std::optional<float> topP{std::nullopt};
std::optional<int> topK{std::nullopt};
std::optional<int> seed{std::nullopt};
std::optional<std::set<std::string>> stop{std::nullopt};
std::optional<bool> includeStopStrInOutput{std::nullopt};
std::optional<float> frequencyPenalty{std::nullopt};
std::optional<float> presencePenalty{std::nullopt};
std::optional<float> repetitionPenalty{std::nullopt};
// Beam search specific
std::optional<int> bestOf{std::nullopt};
std::optional<bool> ignoreEOS{std::nullopt};
bool logprobs = 0;
int logprobschat = false;
bool echo{false};
std::optional<float> lengthPenalty{std::nullopt};
std::optional<float> diversityPenalty{std::nullopt};

// Speculative decoding specific (only with speculative decoding pipeline, see <docs> for reference)
std::optional<int> numAssistantTokens{std::nullopt};
std::optional<float> assistantConfidenceThreshold{std::nullopt};

OpenAIChatCompletionsRequest() = default;
~OpenAIChatCompletionsRequest() = default;
@@ -143,6 +149,8 @@ struct OpenAIChatCompletionsRequest {
// Speculative decoding specific
if (numAssistantTokens.has_value())
config.num_assistant_tokens = numAssistantTokens.value();
if (assistantConfidenceThreshold.has_value())
config.assistant_confidence_threshold = assistantConfidenceThreshold.value();

return config;
}
@@ -161,7 +169,7 @@ class OpenAIChatCompletionsHandler {

absl::Status parseCompletionsPart();
absl::Status parseChatCompletionsPart();
absl::Status parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit);
absl::Status parseCommonPart(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline);

public:
OpenAIChatCompletionsHandler(Document& doc, Endpoint endpoint, std::chrono::time_point<std::chrono::system_clock> creationTime,
@@ -184,7 +192,7 @@ class OpenAIChatCompletionsHandler {

ov::genai::GenerationConfig createGenerationConfig() const;

absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit);
absl::Status parseRequest(uint32_t maxTokensLimit, uint32_t bestOfLimit, bool isSpeculativePipeline);

std::string serializeUnaryResponse(const std::vector<ov::genai::GenerationOutput>& generationOutputs);
std::string serializeStreamingChunk(const std::string& chunkResponse, ov::genai::GenerationFinishReason finishReason);
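The two new optionals flow straight into ov::genai::GenerationConfig in createGenerationConfig(): num_assistant_tokens requests a fixed number of draft tokens per step, while assistant_confidence_threshold lets the draft length vary with the draft model's own confidence (the "dynamic decoding" from the commit title). A minimal sketch of both modes, assuming the stock OpenVINO GenAI generation_config.hpp header and illustrative values:

#include "openvino/genai/generation_config.hpp"

// Sketch only; the two assistant_* fields are the ones set by createGenerationConfig() above.
ov::genai::GenerationConfig makeSpeculativeConfig(bool dynamicDraftLength) {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;
    if (dynamicDraftLength) {
        // Dynamic mode: the draft model keeps proposing tokens only while its
        // own confidence stays above this threshold.
        config.assistant_confidence_threshold = 0.4f;
    } else {
        // Static mode: the draft model proposes a fixed number of tokens per step.
        config.num_assistant_tokens = 5;
    }
    return config;
}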
2 changes: 1 addition & 1 deletion src/llm/http_llm_calculator.cc
@@ -127,7 +127,7 @@ class HttpLLMCalculator : public CalculatorBase {
nodeResources->cbPipe->get_tokenizer());
this->client = payload.client;

auto status = this->apiHandler->parseRequest(nodeResources->maxTokensLimit, nodeResources->bestOfLimit);
auto status = this->apiHandler->parseRequest(nodeResources->maxTokensLimit, nodeResources->bestOfLimit, nodeResources->isSpeculativePipeline);
if (status != absl::OkStatus())
return status;

17 changes: 14 additions & 3 deletions src/llm/llm_calculator.proto
@@ -48,9 +48,20 @@ message LLMCalculatorOptions {

optional bool enable_prefix_caching = 12 [default = false];

// speculative decoding enablement

// speculative decoding - draft model config (ignore the fields below if you don't use speculative decoding)
// when draft_models_path is set, the pipeline will use speculative decoding
// the remaining draft_* values default to the corresponding main model settings, but can be overridden
optional string draft_models_path = 13;

optional string draft_models_device = 14 [default = "CPU"];
optional string draft_device = 14;

optional uint64 draft_max_num_batched_tokens = 15;

optional uint64 draft_cache_size = 16;

optional uint64 draft_block_size = 17;

optional uint64 draft_max_num_seqs = 18;

optional bool draft_dynamic_split_fuse = 19;
}
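Because every draft_* scheduler field is optional and falls back to the corresponding main-model setting, enabling speculative decoding only requires a non-empty draft_models_path; overrides are opt-in. A small sketch using the generated protobuf API (setter and has_ names follow standard protoc codegen; paths and numbers are illustrative):

#include "src/llm/llm_calculator.pb.h"

// Illustrative only: builds calculator options where the draft model inherits
// every scheduler setting from the main model except the cache size.
mediapipe::LLMCalculatorOptions makeSpeculativeOptions() {
    mediapipe::LLMCalculatorOptions opts;
    opts.set_cache_size(10);                         // main model KV cache budget
    opts.set_draft_models_path("/models/draft");     // non-empty path switches the pipeline to speculative decoding
    opts.set_draft_device("CPU");
    opts.set_draft_cache_size(2);                    // explicit override; other draft_* fields stay unset and inherit
    return opts;
}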
18 changes: 16 additions & 2 deletions src/llm/llmnoderesources.cpp
@@ -35,7 +35,6 @@
#pragma GCC diagnostic pop

#include "../mediapipe_internal/mediapipe_utils.hpp"
#include "src/llm/llm_calculator.pb.h"
#include "src/llm/llm_executor.hpp"
#include "src/llm/text_processor.hpp"

@@ -157,8 +156,12 @@ Status LLMNodeResources::initializeLLMNodeResources(std::shared_ptr<LLMNodeResou
nodeResources->device = nodeOptions.device();

if (!nodeOptions.draft_models_path().empty()) {
auto draftModelConfig = ov::genai::draft_model(nodeOptions.draft_models_path(), nodeOptions.draft_models_device());
auto draftSchedulerConfig = prepareDraftModelSchedulerConfig(nodeOptions);
std::cout << "#################### draft cache size: " << draftSchedulerConfig.cache_size << std::endl;
auto draftModelConfig = ov::genai::draft_model(nodeOptions.draft_models_path(), nodeOptions.draft_device(),
ov::genai::scheduler_config(draftSchedulerConfig));
nodeResources->pluginConfig.insert(draftModelConfig);
nodeResources->isSpeculativePipeline = true;
}

auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), nodeResources->pluginConfig);
@@ -215,4 +218,15 @@ std::unordered_map<std::string, std::string> LLMNodeResources::prepareLLMNodeIni
return LLMArguments;
}

ov::genai::SchedulerConfig LLMNodeResources::prepareDraftModelSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions) {
return {
.max_num_batched_tokens = nodeOptions.has_draft_max_num_batched_tokens() ? nodeOptions.draft_max_num_batched_tokens() : nodeOptions.max_num_batched_tokens(),
.cache_size = nodeOptions.has_draft_cache_size() ? nodeOptions.draft_cache_size() : nodeOptions.cache_size(),
.block_size = nodeOptions.has_draft_block_size() ? nodeOptions.draft_block_size() : nodeOptions.block_size(),
.dynamic_split_fuse = nodeOptions.has_draft_dynamic_split_fuse() ? nodeOptions.draft_dynamic_split_fuse() : nodeOptions.dynamic_split_fuse(),
.max_num_seqs = nodeOptions.has_draft_max_num_seqs() ? nodeOptions.draft_max_num_seqs() : nodeOptions.max_num_seqs(),
.enable_prefix_caching = nodeOptions.enable_prefix_caching(),
};
}

} // namespace ovms
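The draft model now gets a SchedulerConfig of its own, assembled by prepareDraftModelSchedulerConfig() from the draft_* overrides, and the result is attached to the main pipeline's plugin config as a draft_model property. A condensed sketch of that composition (the header choice is an assumption; paths, sizes, and the property plumbing mirror the diff above):

#include <map>

#include "openvino/genai/continuous_batching_pipeline.hpp"

// Sketch: packs a draft model plus its own scheduler settings into the plugin
// config map that the continuous batching pipeline later consumes.
std::map<std::string, ov::Any> buildDraftPluginConfig() {
    ov::genai::SchedulerConfig draftSchedulerConfig;
    draftSchedulerConfig.cache_size = 2;             // from draft_cache_size, or inherited from the main model
    draftSchedulerConfig.dynamic_split_fuse = true;

    std::map<std::string, ov::Any> pluginConfig;
    pluginConfig.insert(ov::genai::draft_model("/models/draft", "CPU",
                                               ov::genai::scheduler_config(draftSchedulerConfig)));
    return pluginConfig;
}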
3 changes: 3 additions & 0 deletions src/llm/llmnoderesources.hpp
@@ -37,6 +37,7 @@
#include "../logging.hpp"
#include "../stringutils.hpp"
#include "src/python/utils.hpp"
#include "src/llm/llm_calculator.pb.h"
#include "text_processor.hpp"

namespace ovms {
@@ -105,6 +106,7 @@ using plugin_config_t = std::map<std::string, ov::Any>;
struct LLMNodeResources {
public:
std::shared_ptr<ov::genai::ContinuousBatchingPipeline> cbPipe = nullptr;
bool isSpeculativePipeline{false};
std::string modelsPath;
std::string device;
plugin_config_t pluginConfig;
@@ -129,6 +131,7 @@ struct LLMNodeResources {
private:
std::unique_ptr<LLMExecutorWrapper> llmExecutorWrapper;
static std::unordered_map<std::string, std::string> prepareLLMNodeInitializeArguments(const ::mediapipe::CalculatorGraphConfig::Node& graphNodeConfig, std::string basePath);
static ov::genai::SchedulerConfig prepareDraftModelSchedulerConfig(const mediapipe::LLMCalculatorOptions& nodeOptions);

public:
virtual void initializeContinuousBatchingPipeline(
