From d894f352bf433157232dc8dc54eacd50014e898e Mon Sep 17 00:00:00 2001 From: slaren Date: Sat, 9 Mar 2024 19:55:54 +0100 Subject: [PATCH] perplexity : support using multiple sequences to allow larger batch sizes (#5946) * perplexity : support using multiple sequences to allow larger batch sizes ggml-ci * set cparams.n_parallel to the number of sequences * print tested n_ctx, add assert --- examples/perplexity/perplexity.cpp | 139 +++++++++++++++++++---------- llama.cpp | 22 +++-- 2 files changed, 108 insertions(+), 53 deletions(-) diff --git a/examples/perplexity/perplexity.cpp b/examples/perplexity/perplexity.cpp index 52789ee631234..293eb52c33653 100644 --- a/examples/perplexity/perplexity.cpp +++ b/examples/perplexity/perplexity.cpp @@ -442,7 +442,7 @@ static results_perplexity perplexity_v2(llama_context * ctx, const gpt_params & return {tokens, std::exp(nll / count), logit_history, prob_history}; } -static results_perplexity perplexity(llama_context * ctx, const gpt_params & params) { +static results_perplexity perplexity(llama_context * ctx, const gpt_params & params, const int32_t n_ctx) { if (params.ppl_stride > 0) { return perplexity_v2(ctx, params); } @@ -453,7 +453,6 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par // BOS tokens will be added for each chunk before eval const bool add_bos = llama_should_add_bos_token(llama_get_model(ctx)); - const int n_ctx = llama_n_ctx(ctx); std::ofstream logits_stream; if (!params.logits_file.empty()) { @@ -499,13 +498,19 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par double nll2 = 0.0; const int num_batches = (n_ctx + n_batch - 1) / n_batch; + const int n_seq = std::max(1, n_batch / n_ctx); + + GGML_ASSERT(n_batch < n_ctx || n_batch % n_ctx == 0); + GGML_ASSERT(params.n_ctx == n_seq * n_ctx); + + llama_batch batch = llama_batch_init(std::min(n_batch, n_ctx*n_seq), 0, 1); std::vector logits; if (num_batches > 1) { logits.reserve((size_t)n_ctx * n_vocab); } - fprintf(stderr, "%s: calculating perplexity over %d chunks, batch_size=%d\n", __func__, n_chunk, n_batch); + fprintf(stderr, "%s: calculating perplexity over %d chunks, n_ctx=%d, batch_size=%d, n_seq=%d\n", __func__, n_chunk, n_ctx, n_batch, n_seq); std::vector workers(std::thread::hardware_concurrency() - 1); @@ -518,10 +523,26 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par log_probs.resize(n_ctx * nv); } - for (int i = 0; i < n_chunk; ++i) { + // We get the logits for all the tokens in the context window (params.n_ctx) + // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, + // calculate the perplexity over the last half of the window (so the model always has + // some context to predict the token). + // + // We rely on the fact that attention in the forward pass only looks at previous + // tokens here, so the logits returned for each token are an accurate representation + // of what the model would have predicted at that point. + // + // Example, we have a context window of 512, we will compute perplexity for each of the + // last 256 tokens. Then, we split the input up into context window size chunks to + // process the entire prompt. + const int first = n_ctx/2; + + for (int i = 0; i < n_chunk; i += n_seq) { const int start = i * n_ctx; const int end = start + n_ctx; + const int n_seq_batch = std::min(n_seq, n_chunk - i); + const auto t_start = std::chrono::high_resolution_clock::now(); // clear the KV cache @@ -531,22 +552,37 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par const int batch_start = start + j * n_batch; const int batch_size = std::min(end - batch_start, n_batch); - // save original token and restore it after eval - const auto token_org = tokens[batch_start]; + batch.n_tokens = 0; + for (int seq = 0; seq < n_seq_batch; seq++) { + int seq_start = batch_start + seq*n_ctx; - // add BOS token for the first batch of each chunk - if (add_bos && j == 0) { - tokens[batch_start] = llama_token_bos(llama_get_model(ctx)); + // save original token and restore it after eval + const auto token_org = tokens[seq_start]; + + // add BOS token for the first batch of each chunk + if (add_bos && j == 0) { + tokens[seq_start] = llama_token_bos(llama_get_model(ctx)); + } + + for (int k = 0; k < batch_size; ++k) { + const int idx = seq*n_ctx + k; + batch.token[idx] = tokens[seq_start + k]; + batch.pos[idx] = j*n_batch + k; + batch.n_seq_id[idx] = 1; + batch.seq_id[idx][0] = seq; + batch.logits[idx] = batch.pos[idx] >= first ? 1 : 0; + } + batch.n_tokens += batch_size; + + // restore the original token in case it was set to BOS + tokens[seq_start] = token_org; } - if (llama_decode(ctx, llama_batch_get_one(tokens.data() + batch_start, batch_size, j * n_batch, 0))) { + if (llama_decode(ctx, batch)) { fprintf(stderr, "%s : failed to eval\n", __func__); return {tokens, -1, logit_history, prob_history}; } - // restore the original token in case it was set to BOS - tokens[batch_start] = token_org; - if (num_batches > 1) { const auto * batch_logits = llama_get_logits(ctx); logits.insert(logits.end(), batch_logits, batch_logits + batch_size * n_vocab); @@ -558,7 +594,7 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par if (i == 0) { const float t_total = std::chrono::duration(t_end - t_start).count(); fprintf(stderr, "%s: %.2f seconds per pass - ETA ", __func__, t_total); - int total_seconds = (int)(t_total * n_chunk); + int total_seconds = (int)(t_total*n_chunk/n_seq); if (total_seconds >= 60*60) { fprintf(stderr, "%d hours ", total_seconds / (60*60)); total_seconds = total_seconds % (60*60); @@ -566,37 +602,31 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par fprintf(stderr, "%.2f minutes\n", total_seconds / 60.0); } - // We get the logits for all the tokens in the context window (params.n_ctx) - // from llama_eval above. Now, based on https://huggingface.co/docs/transformers/perplexity, - // calculate the perplexity over the last half of the window (so the model always has - // some context to predict the token). - // - // We rely on the fact that attention in the forward pass only looks at previous - // tokens here, so the logits returned for each token are an accurate representation - // of what the model would have predicted at that point. - // - // Example, we have a context window of 512, we will compute perplexity for each of the - // last 256 tokens. Then, we split the input up into context window size chunks to - // process the entire prompt. - const int first = n_ctx/2; - const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits(ctx); - if (!params.logits_file.empty()) { - process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, log_probs, nll, nll2); - } else { - process_logits(n_vocab, all_logits + first*n_vocab, tokens.data() + start + first, n_ctx - 1 - first, - workers, nll, nll2, logit_history.data() + start + first, prob_history.data() + start + first); - } - count += n_ctx - first - 1; - - // perplexity is e^(average negative log-likelihood) - if (params.ppl_output_type == 0) { - printf("[%d]%.4lf,", i + 1, std::exp(nll / count)); - } else { - double av = nll/count; - double av2 = nll2/count - av*av; - if (av2 > 0) av2 = sqrt(av2/(count-1)); - printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); + for (int seq = 0; seq < n_seq_batch; seq++) { + const float * all_logits = num_batches > 1 ? logits.data() : llama_get_logits_ith(ctx, seq*n_ctx); + llama_token * tokens_data = tokens.data() + start + seq*n_ctx + first; + if (!params.logits_file.empty()) { + process_logits(logits_stream, n_vocab, all_logits + first*n_vocab, + tokens_data, n_ctx - 1 - first, + workers, log_probs, nll, nll2); + } else { + process_logits(n_vocab, all_logits + first*n_vocab, + tokens_data, n_ctx - 1 - first, + workers, nll, nll2, + logit_history.data() + start + seq*n_ctx + first, + prob_history.data() + start + seq*n_ctx + first); + } + count += n_ctx - first - 1; + + // perplexity is e^(average negative log-likelihood) + if (params.ppl_output_type == 0) { + printf("[%d]%.4lf,", i + seq + 1, std::exp(nll / count)); + } else { + double av = nll/count; + double av2 = nll2/count - av*av; + if (av2 > 0) av2 = sqrt(av2/(count-1)); + printf("%8d %.4lf %4lf %4lf\n", i*n_ctx, std::exp(nll / count), av, av2); + } } fflush(stdout); @@ -615,6 +645,8 @@ static results_perplexity perplexity(llama_context * ctx, const gpt_params & par printf("Unexpected negative standard deviation of log(prob)\n"); } + llama_batch_free(batch); + return {tokens, ppl, logit_history, prob_history}; } @@ -1782,13 +1814,24 @@ static void kl_divergence(llama_context * ctx, const gpt_params & params) { int main(int argc, char ** argv) { gpt_params params; - params.n_batch = 512; if (!gpt_params_parse(argc, argv, params)) { return 1; } params.logits_all = true; - params.n_batch = std::min(params.n_batch, params.n_ctx); + + const int32_t n_ctx = params.n_ctx; + + const bool ppl = !params.hellaswag && !params.winogrande && !params.multiple_choice && !params.kl_divergence; + if (ppl) { + int n_seq = std::max(1, params.n_batch / n_ctx); + int32_t n_kv = n_seq * n_ctx; + params.n_parallel = n_seq; + params.n_ctx = n_kv; + params.n_batch = std::min(params.n_batch, n_kv); + } else { + params.n_batch = std::min(params.n_batch, params.n_ctx); + } if (params.ppl_stride > 0) { fprintf(stderr, "Will perform strided perplexity calculation -> adjusting context size from %d to %d\n", @@ -1847,7 +1890,7 @@ int main(int argc, char ** argv) { } else if (params.kl_divergence) { kl_divergence(ctx, params); } else { - results = perplexity(ctx, params); + results = perplexity(ctx, params, n_ctx); } llama_print_timings(ctx); diff --git a/llama.cpp b/llama.cpp index c58a029f74faf..b19616e8f9a5f 100644 --- a/llama.cpp +++ b/llama.cpp @@ -8925,17 +8925,29 @@ static int llama_decode_internal( if (batch.logits) { logits_out.resize(n_vocab * n_tokens); + int32_t i_first = -1; for (uint32_t i = 0; i < n_tokens; i++) { - if (batch.logits[i] == 0) { - continue; + if (batch.logits[i] && i_first == -1) { + i_first = (int32_t) i; + } + if (batch.logits[i] == 0 || i == n_tokens - 1) { + if (i_first != -1) { + int i_last = batch.logits[i] == 0 ? i : i + 1; + // extract logits for the range [i_first, i_last) + // group the requests to minimize the number of calls to the backend + ggml_backend_tensor_get_async(backend_res, res, + logits_out.data() + (n_vocab*i_first), + (n_vocab*i_first)*sizeof(float), + (i_last - i_first)*n_vocab*sizeof(float)); + i_first = -1; + } } - ggml_backend_tensor_get_async(backend_res, res, logits_out.data() + (n_vocab*i), (n_vocab*i)*sizeof(float), n_vocab*sizeof(float)); #ifndef NDEBUG - logits_valid[i] = true; + logits_valid[i] = batch.logits[i] != 0; #endif } } else if (lctx.logits_all) { - logits_out.resize(n_vocab * n_tokens); + logits_out.resize(n_vocab*n_tokens); ggml_backend_tensor_get_async(backend_res, res, logits_out.data(), 0, n_vocab*n_tokens*sizeof(float)); #ifndef NDEBUG std::fill(logits_valid.begin(), logits_valid.end(), true);