Processing all the batch inputs
Signed-off-by: Shrinath Suresh <[email protected]>
shrinath-suresh committed Sep 6, 2023
1 parent 9afce52 commit 0d12619
Showing 1 changed file with 55 additions and 48 deletions.
103 changes: 55 additions & 48 deletions cpp/src/examples/babyllama/baby_llama_handler.cc
@@ -160,67 +160,74 @@ torch::Tensor LlmHandler::Inference(
     std::shared_ptr<torch::Device>& device,
     std::pair<std::string&, std::map<uint8_t, std::string>&>& idx_to_req_id,
     std::shared_ptr<torchserve::InferenceResponseBatch>& response_batch) {
-  std::vector<torch::Tensor> tensor_vector;
-  tensor_vector.reserve(steps);
-  torch::Tensor tokens_list_tensor = inputs[0].toTensor();
+  std::vector<torch::Tensor> batch_output_vector;
+  for (const torch::jit::IValue& input : inputs) {
+    std::vector<torch::Tensor> tensor_vector;
+    tensor_vector.reserve(steps);
+    torch::Tensor tokens_list_tensor = input.toTensor();

-  int64_t num_elements = tokens_list_tensor.numel();
+    int64_t num_elements = tokens_list_tensor.numel();

-  int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();
+    int64_t* data_ptr = tokens_list_tensor.data_ptr<int64_t>();

-  std::unique_ptr<int[]> prompt_tokens(new int[num_elements]);
+    std::unique_ptr<int[]> prompt_tokens(new int[num_elements]);

-  for (int64_t i = 0; i < num_elements; ++i) {
-    prompt_tokens[i] = data_ptr[i];
-  }
+    for (int64_t i = 0; i < num_elements; ++i) {
+      prompt_tokens[i] = data_ptr[i];
+    }

-  // start the main loop
-  long start =
-      0;  // used to time our code, only initialized after first iteration
-  int next;  // will store the next token in the sequence
-  int token = prompt_tokens[0];  // kick off with the first token in the prompt
-  int pos = 0;  // position in the sequence
-  while (pos < steps) {
-    // forward the transformer to get logits for the next token
-    float* logits = forward(&transformer, token, pos);
+    // start the main loop
+    long start =
+        0;  // used to time our code, only initialized after first iteration
+    int next;  // will store the next token in the sequence
+    int token =
+        prompt_tokens[0];  // kick off with the first token in the prompt
+    int pos = 0;  // position in the sequence
+    while (pos < steps) {
+      // forward the transformer to get logits for the next token
+      float* logits = forward(&transformer, token, pos);

-    // advance the state state machine
-    if (pos < num_elements - 1) {
-      // if we are still processing the input prompt, force the next prompt
-      // token
-      next = prompt_tokens[pos + 1];
-    } else {
-      // otherwise sample the next token from the logits
-      next = sample(&sampler, logits);
-    }
-    pos++;
+      // advance the state state machine
+      if (pos < num_elements - 1) {
+        // if we are still processing the input prompt, force the next prompt
+        // token
+        next = prompt_tokens[pos + 1];
+      } else {
+        // otherwise sample the next token from the logits
+        next = sample(&sampler, logits);
+      }
+      pos++;

-    torch::Tensor tensor = torch::tensor(next, torch::kLong);
-    tensor_vector.push_back(tensor);
+      torch::Tensor tensor = torch::tensor(next, torch::kLong);
+      tensor_vector.push_back(tensor);

-    // data-dependent terminating condition: the BOS (=1) token delimits
-    // sequences
-    if (next == 1) {
-      break;
-    }
-    token = next;
+      // data-dependent terminating condition: the BOS (=1) token delimits
+      // sequences
+      if (next == 1) {
+        break;
+      }
+      token = next;

-    // init the timer here because the first iteration can be slower
-    if (start == 0) {
-      start = time_in_ms();
-    }
-  }
+      // init the timer here because the first iteration can be slower
+      if (start == 0) {
+        start = time_in_ms();
+      }
+    }

-  // report achieved tok/s (pos-1 because the timer starts after first
-  // iteration)
-  if (pos > 1) {
-    long end = time_in_ms();
-    double token_per_sec = (pos - 1) / (double)(end - start) * 1000;
-    std::cout << "Achieved tok per sec: " << token_per_sec << std::endl;
-  }
+    // report achieved tok/s (pos-1 because the timer starts after first
+    // iteration)
+    if (pos > 1) {
+      long end = time_in_ms();
+      double token_per_sec = (pos - 1) / (double)(end - start) * 1000;
+      std::cout << "Achieved tok per sec: " << token_per_sec << std::endl;
+    }

-  torch::Tensor stacked_tensor = torch::stack(tensor_vector);
-  return stacked_tensor;
+    torch::Tensor stacked_tensor = torch::stack(tensor_vector);
+
+    batch_output_vector.push_back(stacked_tensor);
+  }
+
+  return torch::stack(batch_output_vector);
 }
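
Note on the final stacking step, as a minimal sketch rather than anything from the commit itself: torch::stack adds a new leading dimension and requires every tensor in the list to have the same shape, so torch::stack(batch_output_vector) yields a [batch_size, tokens_per_prompt] tensor only when each prompt in the batch generated the same number of tokens. The token values below are made up for illustration.

#include <torch/torch.h>

#include <iostream>
#include <vector>

int main() {
  // Two per-prompt outputs of equal length, mimicking the per-input
  // tensor_vector results that the handler stacks inside its loop.
  std::vector<torch::Tensor> batch_output_vector;
  batch_output_vector.push_back(torch::tensor({2, 3, 5}, torch::kLong));
  batch_output_vector.push_back(torch::tensor({7, 11, 13}, torch::kLong));

  // stack() adds a leading batch dimension: shape becomes [2, 3].
  torch::Tensor batched = torch::stack(batch_output_vector);
  std::cout << batched.sizes() << std::endl;  // prints [2, 3]
  return 0;
}

If prompts can hit the BOS terminating token at different positions, the per-prompt tensors would need padding to a common length before the final stack; the sketch only covers the equal-length case.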

