Improve trim tensor implementation #423

Merged · 23 commits · May 13, 2024
Changes from 11 commits
2 changes: 1 addition & 1 deletion .github/workflows/causal_lm_cpp.yml
@@ -374,7 +374,7 @@ jobs:
python -m pip install --upgrade-strategy eager "optimum>=1.14" -r ./llm_bench/python/requirements.txt "transformers<4.38" ./thirdparty/openvino_tokenizers/[transformers] --extra-index-url https://download.pytorch.org/whl/cpu
python ./llm_bench/python/convert.py --model_id TinyLlama/TinyLlama-1.1B-Chat-v1.0 --output_dir ./TinyLlama-1.1B-Chat-v1.0/ --precision FP16
convert_tokenizer ./TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/ --output ./TinyLlama-1.1B-Chat-v1.0/pytorch/dldt/FP16/ --with-detokenizer
-cmake -DCMAKE_BUILD_TYPE=Release -S ./text_generation/causal_lm/cpp/ -B ./build/
+cmake -DCMAKE_BUILD_TYPE=Release -DENABLE_SYSTEM_TBB=ON -S ./text_generation/causal_lm/cpp/ -B ./build/
cmake --build ./build/ --config Release -j
wait
- name: run and compare
6 changes: 5 additions & 1 deletion text_generation/causal_lm/cpp/CMakeLists.txt
@@ -1,7 +1,7 @@
# Copyright (C) 2023-2024 Intel Corporation
# SPDX-License-Identifier: Apache-2.0

-cmake_minimum_required(VERSION 3.15)
+cmake_minimum_required(VERSION 3.16)
project(causal_lm)

add_subdirectory(../../../thirdparty/openvino_tokenizers/ "${CMAKE_CURRENT_BINARY_DIR}/openvino_tokenizers/")
@@ -28,6 +28,8 @@ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
target_link_libraries(speculative_decoding_lm PRIVATE openvino::runtime)
set_target_properties(speculative_decoding_lm PROPERTIES CXX_STANDARD 17)
set_target_properties(speculative_decoding_lm PROPERTIES CXX_STANDARD_REQUIRED ON)
+find_package(TBB COMPONENTS tbb)
+target_link_libraries(speculative_decoding_lm PRIVATE TBB::tbb)

add_executable(prompt_lookup_decoding_lm prompt_lookup_decoding_lm.cpp)
target_compile_definitions(prompt_lookup_decoding_lm PRIVATE OPENVINO_TOKENIZERS_PATH=\"$<TARGET_FILE:openvino_tokenizers>\")
@@ -36,3 +38,5 @@ find_package(OpenVINO REQUIRED COMPONENTS Runtime)
target_link_libraries(prompt_lookup_decoding_lm PRIVATE openvino::runtime)
set_target_properties(prompt_lookup_decoding_lm PROPERTIES CXX_STANDARD 17)
set_target_properties(prompt_lookup_decoding_lm PROPERTIES CXX_STANDARD_REQUIRED ON)
+find_package(TBB COMPONENTS tbb)
+target_link_libraries(prompt_lookup_decoding_lm PRIVATE TBB::tbb)
10 changes: 6 additions & 4 deletions text_generation/causal_lm/cpp/prompt_lookup_decoding_lm.cpp
@@ -1,6 +1,7 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

+#include <openvino/core/parallel.hpp>
#include <openvino/openvino.hpp>

namespace {
@@ -94,10 +95,11 @@ ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_seq_len)

void update_kv_cache(ov::InferRequest request, uint64_t seq_len_axis, uint64_t new_seq_len) {
// trim kv_cache values up to the new_seq_len
-for (auto& state : request.query_state()) {
-ov::Tensor old_tensor = state.get_state();
-state.set_state(trimm_tensor(old_tensor, seq_len_axis, new_seq_len));
-}
+auto states = request.query_state();
+ov::parallel_for(states.size(), [&](size_t i) {
+ov::Tensor old_tensor = states.at(i).get_state();
+states.at(i).set_state(trimm_tensor(old_tensor, seq_len_axis, new_seq_len));
+});
}

class PromptLookupCandidateGenerator {
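A note on the ov::parallel_for call introduced above: it comes from <openvino/core/parallel.hpp> and runs the lambda over the index range on OpenVINO's threading backend, which is why this PR also links the samples against TBB::tbb. A minimal standalone sketch of the primitive (illustration only, not code from this PR; the vector and values are made up):

```cpp
// Minimal ov::parallel_for sketch, assuming an OpenVINO install and a TBB
// link dependency as added in the CMake changes above.
#include <openvino/core/parallel.hpp>

#include <iostream>
#include <vector>

int main() {
    std::vector<float> data(8, 1.0f);
    // Each index may run on a different worker thread; every index writes a
    // distinct element, so no synchronization is needed.
    ov::parallel_for(data.size(), [&](size_t i) {
        data[i] *= 2.0f;
    });
    for (float v : data)
        std::cout << v << ' ';  // prints: 2 2 2 2 2 2 2 2
    std::cout << '\n';
}
```

This matches the update_kv_cache change: each KV state is trimmed independently of the others, so the per-state work parallelizes safely.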
116 changes: 60 additions & 56 deletions text_generation/causal_lm/cpp/speculative_decoding_lm.cpp
@@ -1,13 +1,14 @@
// Copyright (C) 2023-2024 Intel Corporation
// SPDX-License-Identifier: Apache-2.0

-#include <openvino/openvino.hpp>
#include <cmath>
+#include <openvino/core/parallel.hpp>
+#include <openvino/openvino.hpp>
#include <random>

constexpr size_t BATCH_SIZE = 1;

// sequence length axis in key/values tensors, for most cases [BATCH_SIZE, num_kv_heads, seq_len, head_size],
// therefore usually SEQ_LEN_AXIS = 2
constexpr size_t SEQ_LEN_AXIS = 2;

@@ -43,7 +44,7 @@ struct TextStreamer {
std::cout << std::string_view{text.data() + print_len, text.size() - print_len};
token_cache.clear();
print_len = 0;
return;
}
if (text.size() >= 3 && text.compare(text.size() - 3, 3, "�") == 0) {
// Don't print incomplete text
@@ -60,54 +61,52 @@ struct TextStreamer {
print_len = 0;
}
};
}

ov::Tensor trimm_tensor(ov::Tensor& tensor, uint64_t seq_len_axis, uint64_t new_seq_len) {
// Copy elements from the old to a new tensor and return it.
// It's assumed that key/values tensor has a shape [BATCH_SIZE, num_kv_heads, seq_len, head_size] or [seq_len, ...],
// If that's not the case for your model, please implement your own trim method.
-OPENVINO_ASSERT(seq_len_axis == 2 || seq_len_axis == 0, "Cannot trim key/values with sequence length axis = ", seq_len_axis);

+OPENVINO_ASSERT(seq_len_axis == 2 || seq_len_axis == 0,
+"Cannot trim key/values with sequence length axis = ",
+seq_len_axis);

auto old_tensor_data = tensor.data<float>();
auto shape = tensor.get_shape();
size_t batch_size = shape[0];
size_t num_kv_heads = shape[1];
size_t old_seq_len = shape[2];
size_t head_size = shape[3];

OPENVINO_ASSERT(new_seq_len <= old_seq_len);

// if new_seq_len equal to old one no need to copy tensor, return as is
if (old_seq_len == new_seq_len)
return tensor;

if (seq_len_axis == 0) {
shape[0] = new_seq_len;
tensor.set_shape(shape);
return tensor;
}

// if seq_len_axis == 2, then data is not contiguous, in order to trim need to repack tensor
-auto new_tensor = ov::Tensor{ov::element::f32, {BATCH_SIZE, num_kv_heads, new_seq_len, head_size}};
-auto new_tensor_data = new_tensor.data<float>();
-for (size_t batch = 0; batch < BATCH_SIZE; ++batch){
-for (size_t i = 0; i < num_kv_heads; ++i) {
-for (size_t j = 0; j < new_seq_len; ++j) {
-auto dst_ptr = new_tensor_data + num_kv_heads * new_seq_len * head_size * batch + new_seq_len * head_size * i + head_size * j;
-auto src_ptr = old_tensor_data + num_kv_heads * new_seq_len * head_size * batch + old_seq_len * head_size * i + head_size * j;
-std::memcpy(dst_ptr, src_ptr, head_size * sizeof(float));
-}
-}
-}
+ov::Coordinate new_shape_begin{0, 0, 0, 0};
+ov::Coordinate new_shape_end{batch_size, num_kv_heads, new_seq_len, head_size};
+auto new_tensor = ov::Tensor(tensor, new_shape_begin, new_shape_end);

return new_tensor;
}

void update_kv_cache(ov::InferRequest request, uint64_t seq_len_axis, uint64_t new_seq_len) {
// trim kv_cache values up to the new_seq_len
-for (auto& state: request.query_state()) {
-ov::Tensor old_tensor = state.get_state();
-state.set_state(trimm_tensor(old_tensor, seq_len_axis, new_seq_len));
-}
+auto states = request.query_state();
+ov::parallel_for(states.size(), [&](size_t i) {
+ov::Tensor old_tensor = states.at(i).get_state();
+states.at(i).set_state(trimm_tensor(old_tensor, seq_len_axis, new_seq_len));
+});
}

} // namespace

int main(int argc, char* argv[]) try {
if (argc != 4) {
throw std::runtime_error(std::string{"Usage: "} + argv[0] + " <DRAFT MODEL_DIR> <MAIN MODEL_DIR> '<PROMPT>'");
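An aside on the trimm_tensor rewrite above, since it is the core of this PR: the old implementation repacked the KV cache with a triple nested loop of memcpy calls, while the new one constructs an ov::Tensor from (parent, begin, end) coordinates — a region-of-interest view that shares the parent's memory instead of copying it. A toy sketch of that constructor (made-up shapes, assuming only a plain OpenVINO runtime install; not code from this PR):

```cpp
// Toy demonstration of the ROI ov::Tensor constructor used by the new
// trimm_tensor. Shapes are invented for illustration.
#include <openvino/openvino.hpp>

#include <iostream>
#include <numeric>

int main() {
    // KV-cache-like tensor: [batch=1, num_kv_heads=2, seq_len=4, head_size=3].
    ov::Tensor kv{ov::element::f32, {1, 2, 4, 3}};
    std::iota(kv.data<float>(), kv.data<float>() + kv.get_size(), 0.0f);

    // Trim seq_len from 4 to 2 along axis 2 by building a view over the
    // parent tensor: no per-element copy happens here.
    ov::Coordinate begin{0, 0, 0, 0};
    ov::Coordinate end{1, 2, 2, 3};
    ov::Tensor trimmed{kv, begin, end};

    std::cout << trimmed.get_shape() << '\n';  // trimmed shape, e.g. [1,2,2,3]
    // The view keeps the parent's strides, so consumers that respect strides
    // read the correct rows without any repacking.
    std::cout << trimmed.get_strides() << '\n';
}
```

For seq_len_axis == 0 the trimmed region is a contiguous prefix, which is why the new code can simply call set_shape in that branch; the ROI view is only needed for the non-contiguous seq_len_axis == 2 case.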
@@ -118,26 +117,29 @@ int main(int argc, char* argv[]) try {
core.add_extension(OPENVINO_TOKENIZERS_PATH); // OPENVINO_TOKENIZERS_PATH is defined in CMakeLists.txt
auto tokenizer_model = core.read_model(std::string{argv[1]} + "/openvino_tokenizer.xml");
// tokenizer and detokenizer work on CPU only
-ov::InferRequest tokenizer = core.compile_model(
-tokenizer_model, "CPU").create_infer_request();
+ov::InferRequest tokenizer = core.compile_model(tokenizer_model, "CPU").create_infer_request();
auto [draft_input_ids, draft_attention_mask] = tokenize(tokenizer, argv[3]);
-ov::InferRequest detokenizer = core.compile_model(
-std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
+ov::InferRequest detokenizer =
+core.compile_model(std::string{argv[1]} + "/openvino_detokenizer.xml", "CPU").create_infer_request();
TextStreamer text_streamer{std::move(detokenizer)};

// draft model
-ov::InferRequest draft_model = core.compile_model(std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();
+ov::InferRequest draft_model =
+core.compile_model(std::string{argv[1]} + "/openvino_model.xml", "CPU").create_infer_request();

draft_model.set_tensor("input_ids", draft_input_ids);
draft_model.set_tensor("attention_mask", draft_attention_mask);

ov::Tensor draft_position_ids = draft_model.get_tensor("position_ids");
draft_position_ids.set_shape(draft_input_ids.get_shape());
-std::iota(draft_position_ids.data<int64_t>(), draft_position_ids.data<int64_t>() + draft_position_ids.get_size(), 0);
+std::iota(draft_position_ids.data<int64_t>(),
+draft_position_ids.data<int64_t>() + draft_position_ids.get_size(),
+0);
uint64_t seq_len = draft_input_ids.get_shape()[1];

// main model
-ov::InferRequest main_model = core.compile_model(std::string{argv[2]} + "/openvino_model.xml", "CPU").create_infer_request();
+ov::InferRequest main_model =
+core.compile_model(std::string{argv[2]} + "/openvino_model.xml", "CPU").create_infer_request();

// Input tensors for the main model should not be mixed with draft.
// Do not feed the same draft_position_ids to the main model, but copy input_ids from the draft_input_ids
@@ -152,7 +154,7 @@ int main(int argc, char* argv[]) try {
auto position_ids = main_model.get_tensor("position_ids");
position_ids.set_shape(draft_input_ids.get_shape());
std::iota(position_ids.data<int64_t>(), position_ids.data<int64_t>() + position_ids.get_size(), 0);

// set beam_idx for stateful model: no beam search is used and BATCH_SIZE = 1
draft_model.get_tensor("beam_idx").set_shape({BATCH_SIZE});
draft_model.get_tensor("beam_idx").data<int32_t>()[0] = 0;
@@ -164,17 +166,18 @@ int main(int argc, char* argv[]) try {
main_model.infer();

size_t vocab_size = draft_model.get_tensor("logits").get_shape().back();
-OPENVINO_ASSERT(vocab_size == main_model.get_tensor("logits").get_shape().back(), "vocab size should be the same for the both models");

+OPENVINO_ASSERT(vocab_size == main_model.get_tensor("logits").get_shape().back(),
+"vocab size should be the same for both models");

// logits shape is [BATCH_SIZE, seq_len, vocab_size]
auto logits = main_model.get_tensor("logits");
auto data_logits = logits.data<float>() + (seq_len - 1) * vocab_size;
int64_t out_token = std::max_element(data_logits, data_logits + vocab_size) - data_logits;

// the first token which is fed to both draft and main networks on each iteration
auto first_token = out_token;
text_streamer.put(out_token);

// run K infer requests on draft model and get next K prediction tokens on each iteration
uint64_t K = 5;
std::vector<int64_t> draft_tokens;
Expand All @@ -183,28 +186,28 @@ int main(int argc, char* argv[]) try {
draft_input_ids.set_shape({BATCH_SIZE, 1});
draft_position_ids.set_shape({BATCH_SIZE, 1});

-auto rt_info = tokenizer_model->get_rt_info(); //Get the runtime info for the model
+auto rt_info = tokenizer_model->get_rt_info();  // Get the runtime info for the model

-if (rt_info.count("eos_token_id") > 0) { //check if the runtime information has a valid EOS token ID
+if (rt_info.count("eos_token_id") > 0) {  // check if the runtime information has a valid EOS token ID
SPECIAL_EOS_TOKEN = rt_info["eos_token_id"].as<int64_t>();
} else {
throw std::runtime_error("EOS token ID not found in model's runtime information.");
}

/* Speculative decoding works the following way. The draft model predicts the next K
tokens one by one in an autoregressive manner, while the main model validates these
predictions and corrects them if necessary. We go through each predicted token, and
if a difference is detected between the draft and main model, we stop and keep the
last token predicted by the main model. Then the draft model gets the latest main
prediction and again tries to predict the next K tokens, repeating the cycle.

This approach reduces the need for multiple infer requests to the main model,
enhancing performance. For instance, in more predictable parts of text generation,
the draft model can, in best-case scenarios, generate the next K tokens that exactly
match the target. In that case they are validated in a single inference request to
the main model (which is bigger, more accurate, but slower) instead of running K
subsequent requests.
*/
int max_sequence_length = 100;
while (out_token != SPECIAL_EOS_TOKEN && seq_len < max_sequence_length) {
// infer the K next tokens with draft model
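To make the speculative-decoding comment above concrete, here is a self-contained toy sketch of one propose/validate cycle. draft_next and main_next are hypothetical stand-ins for the real draft/main infer requests and are not part of this PR:

```cpp
// Toy accept/reject cycle for speculative decoding; the "models" are
// deterministic dummy functions so the example runs anywhere.
#include <cstdint>
#include <iostream>
#include <vector>

int64_t draft_next(int64_t t) { return (t * 7 + 3) % 100; }  // dummy draft model
int64_t main_next(int64_t t) { return (t * 7 + 3) % 101; }   // dummy main model

int main() {
    const size_t K = 5;      // tokens the draft model speculates per cycle
    int64_t out_token = 42;  // last accepted token

    // 1. Draft model proposes K tokens autoregressively (K cheap calls).
    std::vector<int64_t> draft_tokens;
    int64_t t = out_token;
    for (size_t i = 0; i < K; ++i)
        draft_tokens.push_back(t = draft_next(t));

    // 2. Main model validates every position; in the real code this is a
    //    single infer request, since logits for all positions come back at once.
    size_t accepted = 0;
    t = out_token;
    while (accepted < K) {
        int64_t validated = main_next(t);  // main model's token at this position
        if (validated != draft_tokens[accepted]) {
            out_token = validated;  // disagreement: keep the main model's token
            break;
        }
        t = draft_tokens[accepted++];
    }
    if (accepted == K)
        out_token = main_next(t);  // all drafts accepted: take main's next token

    std::cout << "accepted " << accepted << " of " << K << " draft tokens\n";
    // The real code then rolls both KV caches back to the accepted length via
    // update_kv_cache(...), which is where trimm_tensor comes in.
}
```

In the best case all K draft tokens are accepted after a single main-model request; in the worst case the first draft token already disagrees and only one draft call is wasted per corrected token.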
@@ -248,8 +251,9 @@ int main(int argc, char* argv[]) try {
out_token = std::max_element(start, stop) - start;
text_streamer.put(out_token);

-disagree_idx = i;
-if (out_token != draft_tokens[i] || out_token == SPECIAL_EOS_TOKEN || seq_len + disagree_idx + 1 >= max_sequence_length)
+disagree_idx = i;
+if (out_token != draft_tokens[i] || out_token == SPECIAL_EOS_TOKEN ||
+seq_len + disagree_idx + 1 >= max_sequence_length)
break;
}

@@ -259,7 +263,7 @@ int main(int argc, char* argv[]) try {
seq_len += disagree_idx + 1;
update_kv_cache(draft_model, SEQ_LEN_AXIS, seq_len);
update_kv_cache(main_model, SEQ_LEN_AXIS, seq_len);

draft_tokens.clear();
first_token = out_token;
}