Commit: init
mzegla committed Oct 29, 2024
1 parent 1d53546 commit 5ff955e
Showing 11 changed files with 33 additions and 25 deletions.
1 change: 0 additions & 1 deletion ci/lib_search.py
@@ -83,7 +83,6 @@ def check_dir(start_dir):
'__pycache__',
'add.xml',
'azure_sdk.patch',
'cb.patch',
'bazel-',
'check_coverage.bat',
'genhtml',
1 change: 0 additions & 1 deletion external/BUILD
@@ -47,5 +47,4 @@ exports_files([
"listen.patch",
"tf.patch",
"net_http.patch",
"cb.patch",
])
18 changes: 0 additions & 18 deletions external/cb.patch

This file was deleted.

1 change: 0 additions & 1 deletion spelling-whitelist.txt
@@ -1,7 +1,6 @@
client/common/resnet_labels.txt
demos/common/python/classes.py
demos/image_classification/go/labels.go
external/cb.patch
extras/nginx-mtls-auth/model_server.conf.template
release_files/thirdparty-licenses/boringssl.LICENSE.txt
src/shape.cpp:436: strIn
11 changes: 11 additions & 0 deletions src/llm/apis/openai_completions.cpp
@@ -348,6 +348,17 @@ absl::Status OpenAIChatCompletionsHandler::parseCommonPart(uint32_t maxTokensLimit
request.numReturnSequences = it->value.GetUint();
}

// Speculative decoding specific parameters

// num_assistant_tokens: uint; optional - defaults to 0
it = doc.FindMember("num_assistant_tokens");
if (it != doc.MemberEnd()) {
if (!it->value.IsUint()) {
return absl::InvalidArgumentError("num_assistant_tokens must be an unsigned integer");
}
request.numAssistantTokens = it->value.GetUint();
}

// use_beam_search: bool; optional - defaults to false
// Extension from vLLM, unsupported by OpenAI API, not available directly in CB lib
// Use best_of>1 to steer into beam search
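For reference, here is a minimal standalone sketch (not part of this commit) of the validation pattern the new num_assistant_tokens branch follows, assuming rapidjson as used by the handler above; the request body and model name are made-up examples.

#include <cstdint>
#include <iostream>
#include <optional>
#include <stdexcept>
#include <rapidjson/document.h>

// Mirrors the parseCommonPart() branch above: absent is fine, non-uint is an error.
std::optional<uint32_t> parseNumAssistantTokens(const rapidjson::Document& doc) {
    auto it = doc.FindMember("num_assistant_tokens");
    if (it == doc.MemberEnd())
        return std::nullopt;  // optional field - speculative decoding stays disabled
    if (!it->value.IsUint())
        throw std::invalid_argument("num_assistant_tokens must be an unsigned integer");
    return it->value.GetUint();
}

int main() {
    // Hypothetical chat completions request body; only num_assistant_tokens is new here.
    const char* body = R"({
        "model": "example-model",
        "messages": [{"role": "user", "content": "hello"}],
        "max_tokens": 64,
        "num_assistant_tokens": 5
    })";
    rapidjson::Document doc;
    doc.Parse(body);
    if (auto n = parseNumAssistantTokens(doc))
        std::cout << "speculative decoding requested with " << *n << " assistant tokens\n";
    return 0;
}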
7 changes: 6 additions & 1 deletion src/llm/apis/openai_completions.hpp
@@ -66,6 +66,7 @@ struct OpenAIChatCompletionsRequest {
StreamOptions streamOptions;
std::string model;
std::optional<int> maxTokens{std::nullopt};
std::optional<int> numAssistantTokens{std::nullopt};
std::optional<float> frequencyPenalty{std::nullopt};
std::optional<float> presencePenalty{std::nullopt};
std::optional<float> diversityPenalty{std::nullopt};
@@ -120,7 +121,7 @@ struct OpenAIChatCompletionsRequest {
// TODO: early_finish = ?
// TODO use_beam_search is unused ?

// Multinomial specific
// Multinomial sampling specific
if (temperature.has_value())
config.temperature = temperature.value();
if (topK.has_value())
@@ -139,6 +140,10 @@
config.presence_penalty = presencePenalty.value();
config.do_sample = config.temperature > 0.0f && config.num_beams == 1;

// Speculative decoding specific
if (numAssistantTokens.has_value())
config.num_assistant_tokens = numAssistantTokens.value();

return config;
}
};
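As a companion to the header change, a hedged sketch of how the new optional is meant to flow into ov::genai::GenerationConfig inside the request-to-config conversion shown above; field names follow the OpenVINO GenAI API already used in this struct, the values are illustrative, and the parameter is only expected to matter when the pipeline was built with a draft model (see the proto and llmnoderesources.cpp changes below).

#include <optional>
#include <openvino/genai/generation_config.hpp>

// Hedged sketch, not OVMS code: map the optional request field onto the
// GenAI generation config the same way the other optionals above are mapped.
ov::genai::GenerationConfig toGenerationConfig(std::optional<int> numAssistantTokens) {
    ov::genai::GenerationConfig config;
    config.max_new_tokens = 128;  // illustrative; OVMS derives this from max_tokens

    // Speculative decoding specific: only set when the client sent it.
    if (numAssistantTokens.has_value())
        config.num_assistant_tokens = numAssistantTokens.value();

    return config;
}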
2 changes: 2 additions & 0 deletions src/llm/http_llm_calculator.cc
@@ -208,7 +208,9 @@ class HttpLLMCalculator : public CalculatorBase {
if (this->generationHandle->get_status() == ov::genai::GenerationStatus::RUNNING || this->generationHandle->can_read()) {
// Subsequent iteration
OVMS_PROFILE_SCOPE("Generation of subsequent streaming response");
//SPDLOG_LOGGER_INFO(llm_calculator_logger, "Start read() ...");
ov::genai::GenerationOutputs generationOutputs = this->generationHandle->read();
//SPDLOG_LOGGER_INFO(llm_calculator_logger, "End read() ...");
RET_CHECK(generationOutputs.size() == 1); // TODO: Support multiple generations
this->apiHandler->incrementProcessedTokens(generationOutputs.begin()->second.generated_ids.size());

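For context, a hedged sketch of the streaming read pattern this hunk builds on, using only the GenerationHandle calls visible above (get_status, can_read, read). The real calculator performs a single read per MediaPipe Process() call rather than a drain loop, and error handling plus HTTP serialization are omitted.

#include <cstddef>
#include <openvino/genai/continuous_batching_pipeline.hpp>

// Drain everything currently available from a generation handle and count tokens,
// mirroring the loop condition and the generated_ids accounting used above.
size_t drainStream(ov::genai::GenerationHandle& handle) {
    size_t processedTokens = 0;
    while (handle->get_status() == ov::genai::GenerationStatus::RUNNING || handle->can_read()) {
        ov::genai::GenerationOutputs outputs = handle->read();
        for (const auto& kv : outputs) {
            // generated_ids holds the tokens produced since the previous read();
            // the calculator forwards this count to incrementProcessedTokens().
            processedTokens += kv.second.generated_ids.size();
        }
    }
    return processedTokens;
}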
6 changes: 6 additions & 0 deletions src/llm/llm_calculator.proto
@@ -47,4 +47,10 @@ message LLMCalculatorOptions {
optional uint32 max_tokens_limit = 11 [default = 4096];

optional bool enable_prefix_caching = 12 [default = false];

// speculative decoding enablement

optional string draft_models_path = 13;

optional string draft_models_device = 14 [default = "CPU"];
}
8 changes: 7 additions & 1 deletion src/llm/llmnoderesources.cpp
@@ -156,6 +156,11 @@ Status LLMNodeResources::initializeLLMNodeResources(std::shared_ptr<LLMNodeResou

nodeResources->device = nodeOptions.device();

if (!nodeOptions.draft_models_path().empty()) {
auto draftModelConfig = ov::genai::draft_model(nodeOptions.draft_models_path(), nodeOptions.draft_models_device());
nodeResources->pluginConfig.insert(draftModelConfig);
}

auto status = JsonParser::parsePluginConfig(nodeOptions.plugin_config(), nodeResources->pluginConfig);
if (!status.ok()) {
SPDLOG_ERROR("Error during llm node plugin_config option parsing to JSON: {}", nodeOptions.plugin_config());
@@ -164,7 +169,8 @@

try {
plugin_config_t tokenizerPluginConfig = {{"PERFORMANCE_HINT", "THROUGHPUT"}};
nodeResources->initializeContinuousBatchingPipeline(basePath, nodeResources->schedulerConfig, nodeResources->device, nodeResources->pluginConfig, tokenizerPluginConfig);
nodeResources->initializeContinuousBatchingPipeline(basePath, nodeResources->schedulerConfig, nodeResources->device,
nodeResources->pluginConfig, tokenizerPluginConfig);
} catch (const std::exception& e) {
SPDLOG_ERROR("Error during llm node initialization for models_path: {} exception: {}", basePath, e.what());
return StatusCode::LLM_NODE_RESOURCE_STATE_INITIALIZATION_FAILED;
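To make the new branch concrete, a hedged end-to-end sketch of the public OpenVINO GenAI calls this node wraps: ov::genai::draft_model() yields a property that is inserted into the plugin config and handed to the continuous batching pipeline, mirroring the initializeContinuousBatchingPipeline call above. Paths and devices are placeholders, and the exact constructor signature should be checked against the GenAI version in use.

#include <string>
#include <openvino/genai/continuous_batching_pipeline.hpp>

void initializeWithDraftModel() {
    const std::string mainModelPath = "/models/main_model";    // placeholder
    const std::string draftModelPath = "/models/draft_model";  // placeholder
    const std::string device = "CPU";                          // matches the proto default

    ov::genai::SchedulerConfig schedulerConfig;  // library defaults are enough for a sketch

    ov::AnyMap pluginConfig;
    // Equivalent of the nodeOptions.draft_models_path() branch above:
    // draft_model() returns a property entry that enables speculative decoding.
    pluginConfig.insert(ov::genai::draft_model(draftModelPath, device));

    ov::AnyMap tokenizerPluginConfig = {{"PERFORMANCE_HINT", "THROUGHPUT"}};

    ov::genai::ContinuousBatchingPipeline pipeline(
        mainModelPath, schedulerConfig, device, pluginConfig, tokenizerPluginConfig);

    // pipeline.add_request(...) / pipeline.step() would follow in a real serving loop.
}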
1 change: 1 addition & 0 deletions src/llm/llmnoderesources.hpp
@@ -113,6 +113,7 @@ struct LLMNodeResources {
int maxTokensLimit;
int bestOfLimit;


static Status initializeLLMNodeResources(std::shared_ptr<LLMNodeResources>& nodeResources, const ::mediapipe::CalculatorGraphConfig::Node& graphNode, std::string graphPath);
static void loadTextProcessor(std::shared_ptr<LLMNodeResources>& nodeResources, const std::string& chatTemplateDirectory);

2 changes: 0 additions & 2 deletions third_party/llm_engine/llm_engine.bzl
@@ -24,8 +24,6 @@ def llm_engine():
build_file = "@_llm_engine//:BUILD",
init_submodules = True,
recursive_init_submodules = True,
patch_args = ["-p1"],
patches = ["cb.patch"],
)
# when using local repository manually run: git submodule update --recursive
#native.new_local_repository(
