
BabyLlama with CPP backend #2544

Conversation

@shrinath-suresh (Contributor) commented Aug 28, 2023

Description

Benchmarking Babyllama deployment with CPP Backend

Setup and Test

  1. Follow the instructions from README.md to set up the cpp backend environment.

  2. Download the stories model:

wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin

  3. Download the tokenizer.bin file.

  4. Create a config file named config.json with the paths of the downloaded model and tokenizer:

{
"checkpoint_path" : "/home/ubuntu/serve/cpp/stories15M.bin",
"tokenizer_path" : "/home/ubuntu/serve/cpp/src/examples/babyllama/tokenizer.bin"
}

  5. Run the build:

cd serve/cpp
./build.sh

Once the build is successful, the libllm_handler.so shared object file will be generated in the serve/cpp/test/resources/torchscript_model/babyllama/llm_handler folder.

  6. Copy the dummy.pt file to the llm_handler folder.

  7. Move to the handler folder and run the following command to generate the mar file:

cd serve/cpp/test/resources/torchscript_model/babyllama/babyllama_handler
torch-model-archiver --model-name llm --version 1.0 --serialized-file dummy.pt --handler libbabyllama_handler:BabyLlamaHandler --runtime LSP --extra-files config.json

  8. Move llm.mar to the model store:

mkdir model_store
mv llm.mar model_store/llm.mar

  9. Create a new config.properties file and paste the following content:
default_response_timeout=300000

The default timeout is 120000. When the context size is large, LLM generation can take longer than that to complete a request on a single-GPU machine.

  10. Start torchserve:

torchserve --start --ncs --ts-config config.properties --model-store model_store/

  11. Register the model using a curl command:

curl -v -X POST "http://localhost:8081/models?initial_workers=1&url=llm.mar"

  12. Update the input in prompt.txt if needed and run:

curl http://localhost:8080/predictions/llm -T prompt.txt

Sample response

Hello my name is Daisy. Daisy is three years old. She loves to play with her toys.
One day, Daisy's mommy said, "Daisy, it's time to go to the store." Daisy was so excited! She ran to the store with her mommy.
At the store, Daisy saw a big, red balloon. She wanted it so badly! She asked her mommy, "Can I have the balloon, please?"
Mommy said, "No, Daisy. We don't have enough money for that balloon."
Daisy was sad. She wanted the balloon so much. She started to cry.
Mommy said, "Daisy, don't cry. We can get the balloon. We can buy it and take it home."
Daisy smiled. She was so happy. She hugged her mommy and said, "Thank you, mommy!"
<s>

Benchmarking

Clone the llama2.c repo

git clone https://github.com/karpathy/llama2.c.git

Move to the folder and compile. An executable named run will be generated.

cd llama2.c
make -j

Download the model

wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin

Run inference using the following command

./run stories15M.bin -n 256 -i "Hello my name is" -t 1.0 -p 0.9

babyllama_gcc_standalone

babyllama_gcc_standalone.txt

The standalone version generates output at 55.29 tokens per second. The variation is due to the compiler options.

See PR karpathy/llama2.c#116 for cmake build support.

Clone the krrishnarraj/llama2.c fork from that pull request:

git clone https://github.com/krrishnarraj/llama2.c.git

Follow the build instructions from that repository or run the following commands:

mkdir build
cd build
cmake ..
cmake --build .
cp ../tokenizer.bin .
wget https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin

Once the build has succeeded, run the same inference command as above.

babyllama_cmake_standalone

The standalone cmake version generates 147.39 tokens per second.

torchserve curl request

curl http://localhost:8080/predictions/llm -T prompt.txt

babyllama_torchserve

ts_log.txt

BabyLlama with the cpp backend generates 172.3 tokens per second.

Type of change

Please delete options that are not relevant.

  • Bug fix (non-breaking change which fixes an issue)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • New feature (non-breaking change which adds functionality)
  • This change requires a documentation update

Feature/Issue validation/testing

Please describe the Unit or Integration tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration.

  • Test A
    Logs for Test A

  • Test B
    Logs for Test B

Checklist:

  • Did you have fun?
  • Have you added tests that prove your fix is effective or that this feature works?
  • Has code been commented, particularly in hard-to-understand areas?
  • Have you made corresponding changes to the documentation?

@chauhang added the c++ label Aug 29, 2023
build_transformer(&transformer, checkpoint_path);

char tokenizer_path[] =
"/home/ubuntu/serve/cpp/src/examples/image_classifier/babyllama/"
Contributor: Path is hard coded at present -- read from the config file.

Collaborator: @shrinath-suresh you can also add the tokenizer.bin as an additional file when creating the mar file and set the filename as load_model_request->model_dir + "tokenizer.bin".

Contributor (Author): Updated the code to read the tokenizer and model path from a config file.
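
Below is a minimal sketch of reading those two paths from config.json; the JSON library (nlohmann::json) and the struct/function names are assumptions for illustration, not the handler's actual code.

#include <fstream>
#include <string>

#include <nlohmann/json.hpp>

struct LlmConfig {
  std::string checkpoint_path;
  std::string tokenizer_path;
};

// Parses the config.json shown in the PR description and returns both paths.
LlmConfig LoadLlmConfig(const std::string& config_file) {
  std::ifstream in(config_file);
  nlohmann::json config;
  in >> config;
  return {config["checkpoint_path"].get<std::string>(),
          config["tokenizer_path"].get<std::string>()};
}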

Signed-off-by: Shrinath Suresh <[email protected]>
@mreso (Collaborator) left a comment: Hi, thanks for this great contribution! I've left some comments on the PR.

#ifndef LLM_HANDLER_HH_
#define LLM_HANDLER_HH_

#include "run.c"
Collaborator: Is there a reason why run.c gets included here? It's also listed in the cmake file as a source file, which probably did not work since there is no header file to declare the content. I would recommend removing it from the cmake file and including it in the .cc instead to localize visibility.

Contributor (Author): Moved the run.c include to the .cc file.

Collaborator: run.c was published under the MIT license, which requires that we include the copyright notice and license. My proposal is to create a subfolder and include run.c plus the original license file. @chauhang is that a viable approach?

Contributor (Author): Created a subdirectory named llama2.c and copied run.c there along with the license file.
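
As an illustration of the agreed layout, a minimal sketch of the .cc include (file and folder names are assumptions based on the discussion, not necessarily the exact paths in the PR):

// llm_handler.cc -- including run.c only in this translation unit keeps its
// symbols local to the handler and removes the need to list run.c as a CMake
// source file.
#include "llm_handler.hh"

#include "llama2.c/run.c"  // vendored copy of run.c, kept next to its MIT license file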


float topp = 0.9f; // top-p in nucleus sampling. 1.0 = off. 0.9 works well,
// but slower
int steps = 256; // number of steps to run for
unsigned long long rng_seed = 0;
Collaborator: Initializing an rng with 0 (all bits zero) can be problematic in some cases.

Contributor (Author): Removed the initialization.


std::string msg = torchserve::Converter::VectorToStr(data_it->second);

char* msgCStr = new char[msg.size() + 1]; // +1 for the null terminator
Collaborator: Please use smart pointers when allocating dynamic memory, and prefer new over malloc. Something like

std::unique_ptr<int[]> prompt_tokens(new int[(strlen(msgCStr) + 3) * sizeof(int)]);

should work as well.

Contributor (Author): Updated the code to use smart pointers where necessary.
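
A minimal sketch of the pattern, assuming the +3 sizing from the snippet above (the helper name is illustrative, not the final handler code):

#include <cstring>
#include <memory>
#include <string>

// Copies the request payload into an owned C string and reserves the token
// buffer with std::unique_ptr so both allocations are released automatically.
void TokenizePromptSketch(const std::string& msg) {
  auto msg_cstr = std::make_unique<char[]>(msg.size() + 1);  // +1 for the null terminator
  std::memcpy(msg_cstr.get(), msg.c_str(), msg.size() + 1);

  // The original allocation reserved strlen(prompt) + 3 ints for the token ids.
  auto prompt_tokens = std::make_unique<int[]>(std::strlen(msg_cstr.get()) + 3);

  // ... pass msg_cstr.get() and prompt_tokens.get() to llama2.c's encode() ...
  (void)prompt_tokens;
}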

long_vector.push_back(data_ptr[i]);
}

int* prompt_tokens = new int[num_elements];
Collaborator: See above.

Contributor (Author): Updated the code to use a smart pointer.


int* prompt_tokens = new int[num_elements];
for (int64_t i = 0; i < num_elements; ++i) {
prompt_tokens[i] = static_cast<int>(long_vector[i]);
Collaborator: Why can't we just copy the data from the tensor instead of going through long_vector?

Contributor (Author): Updated the logic to copy the data directly from the tensor.
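
A minimal sketch of the direct-copy approach, assuming the input tensor holds the token ids (the function name is illustrative):

#include <cstring>
#include <memory>

#include <torch/torch.h>

// Flattens the tensor, converts it to int32 and copies the ids with a single
// memcpy into a unique_ptr-owned buffer, skipping the intermediate long_vector.
std::unique_ptr<int[]> TensorToPromptTokens(const torch::Tensor& tokens) {
  auto flat = tokens.to(torch::kInt).contiguous().view(-1);
  const int64_t num_elements = flat.numel();
  auto prompt_tokens = std::make_unique<int[]>(num_elements);
  std::memcpy(prompt_tokens.get(), flat.data_ptr<int>(),
              num_elements * sizeof(int));
  return prompt_tokens;
}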

std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
std::vector<torch::Tensor> tensor_vector;
auto tokens_list_tensor = inputs[0].toTensor();
Collaborator: Can we extend this to batched processing or at least process all entries in the batch?

Contributor (Author): Working on the batch processing part. Will keep you posted once it is done.

@shrinath-suresh (Contributor, Author): @mreso Thanks for your comments. I will address them.

@shrinath-suresh (Contributor, Author): @mreso I have addressed most of your comments. Working on updating the code to enable batch processing. Once it is done, I will run a sanity test and update the steps/README with the details. Will keep you posted on this.

@shrinath-suresh changed the title from "[WIP] BabyLlama with CPP backend" to "BabyLlama with CPP backend" on Sep 5, 2023
return batch_ivalue;
}

torch::Tensor LlmHandler::Inference(
@lxning (Collaborator) commented Sep 5, 2023: Could you add with torch.inference_mode(): ?

Contributor (Author): Is it c10::InferenceMode guard; in cpp? See https://pytorch.org/cppdocs/notes/inference_mode.html

Collaborator: torch::InferenceMode is a high-level API, c10::InferenceMode is a low-level API. According to the libtorch docs, they are trying to use torch::xxx to unify the low-level APIs.
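
For reference, a minimal sketch of the guard in the C++ API (function and parameter names are illustrative):

#include <vector>

#include <torch/script.h>

// Runs the scripted model under an InferenceMode guard so no autograd state
// is recorded for the forward pass.
torch::Tensor RunInference(torch::jit::script::Module& model,
                           const std::vector<torch::jit::IValue>& inputs) {
  c10::InferenceMode guard;
  return model.forward(inputs).toTensor();
}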

@shrinath-suresh (Contributor, Author): Updated the code to process all the items in the batch. Please review when you find time and let me know if any other changes are needed.

@shrinath-suresh (Contributor, Author): @lxning Should we add the model and tokenizer download steps to the test script? Do we have any concept of setup and teardown in the cpp backend for each test case? These files are mandatory for the unit tests to pass.

Comment on lines +23 to +25
manifest_->GetModel().serialized_file),
*device));

Collaborator: The current cpp backend only supports one device id, which means there is no partitioning across GPU devices. I assume this example only works on a single GPU.
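
A small sketch of what the single-device assumption looks like with libtorch (names are illustrative; the backend's actual loading code may differ):

#include <string>

#include <torch/script.h>

// Loads the TorchScript module directly onto the one device id the backend
// provides; the model is not partitioned across multiple GPUs.
torch::jit::script::Module LoadOnSingleDevice(const std::string& model_path,
                                              const torch::Device& device) {
  return torch::jit::load(model_path, device);
}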

std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
c10::InferenceMode guard;
std::vector<torch::Tensor> batch_output_vector;
for (const torch::jit::IValue& input : inputs) {
Collaborator: This for loop predicts each inference request one by one. Can we optimize this section to leverage either C++ multithreading or GPU batching?
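
A hedged sketch of the batching idea, assuming the per-request token tensors have already been padded to a common length (names are illustrative):

#include <vector>

#include <torch/script.h>

// Stacks the per-request token tensors into one [batch, seq_len] tensor and
// runs a single forward pass instead of one forward call per request.
torch::Tensor BatchedForward(torch::jit::script::Module& model,
                             const std::vector<torch::Tensor>& per_request_tokens) {
  c10::InferenceMode guard;
  torch::Tensor batch = torch::stack(per_request_tokens);
  return model.forward({batch}).toTensor();
}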

@mreso (Collaborator) left a comment: When we remove the absolute paths (and the tests pass) this will be ready to go for now. We can add batched processing in a follow-up PR @lxning.

@@ -0,0 +1,5 @@
{
"checkpoint_path" : "/home/ubuntu/serve/cpp/stories15M.bin",
Collaborator: Would be good to make these relative instead of absolute.

@shrinath-suresh (Contributor, Author) commented Sep 15, 2023: @mreso

-rw-rw-r-- 1 ubuntu ubuntu 424K Sep  5 23:55 tokenizer.bin
-rw-rw-r-- 1 ubuntu ubuntu 58M Jul 27 04:09 stories15M.bin

These two files are needed at runtime for the unit test case. I can think of two approaches:

  1. Download these files in build.sh and remove them once the unit test case passes.
  2. Add the download logic in cpp - serve/cpp/test/backends/torch_scripted/torch_scripted_backend_test.cc - when executing the test case.

The second option seems more reasonable. What's your opinion?
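
A rough sketch of option 2, assuming a gtest fixture and using wget via std::system purely for illustration (the actual test file and download mechanism may differ; tokenizer.bin would need a similar step):

#include <cstdlib>
#include <fstream>

#include <gtest/gtest.h>

class BabyLlamaArtifactsTest : public ::testing::Test {
 protected:
  void SetUp() override {
    // -nc skips the download when the file is already present locally.
    std::system(
        "wget -nc -q "
        "https://huggingface.co/karpathy/tinyllamas/resolve/main/stories15M.bin");
  }
};

TEST_F(BabyLlamaArtifactsTest, ModelFileIsAvailable) {
  // Placeholder check; the real test would load the handler with this file.
  std::ifstream model("stories15M.bin");
  EXPECT_TRUE(model.good());
}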

@mreso mentioned this pull request Jan 24, 2024

@mreso (Collaborator) commented Jan 25, 2024: Closing this in favor of #2903, which picks up all the changes and adds adjustments.

@mreso closed this Jan 25, 2024