diff --git a/common/common.cpp b/common/common.cpp index 7500e08ff1be46..0ad072e572c103 100644 --- a/common/common.cpp +++ b/common/common.cpp @@ -528,6 +528,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; } else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; } else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; } + else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; } + else { invalid_param = true; } + return true; + } + if (arg == "--attention") { + if (++i >= argc) { + invalid_param = true; + return true; + } + std::string value(argv[i]); + /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; } + else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; } else { invalid_param = true; } return true; } @@ -1443,8 +1455,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n"); printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow); printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast); - printf(" --pooling {none,mean,cls}\n"); + printf(" --pooling {none,mean,cls,last}\n"); printf(" pooling type for embeddings, use model default if unspecified\n"); + printf(" --attn-type {causal,non-causal}\n"); + printf(" attention type for generation, use model default if unspecified\n"); printf(" -dt N, --defrag-thold N\n"); printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold); printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n"); @@ -2042,6 +2056,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param cparams.yarn_beta_slow = params.yarn_beta_slow; cparams.yarn_orig_ctx = params.yarn_orig_ctx; cparams.pooling_type = params.pooling_type; + cparams.attention_type = params.attention_type; cparams.defrag_thold = params.defrag_thold; cparams.cb_eval = params.cb_eval; cparams.cb_eval_user_data = params.cb_eval_user_data; diff --git a/common/common.h b/common/common.h index f68f3c2979b94b..13b9b32709fcf2 100644 --- a/common/common.h +++ b/common/common.h @@ -95,6 +95,7 @@ struct gpt_params { enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED; enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings + enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type // // sampling parameters struct llama_sampling_params sparams; diff --git a/examples/embedding/embedding.cpp b/examples/embedding/embedding.cpp index 004399b5f7eb80..64d96972048b9e 100644 --- a/examples/embedding/embedding.cpp +++ b/examples/embedding/embedding.cpp @@ -17,9 +17,25 @@ static std::vector split_lines(const std::string & s) { return lines; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { - for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); +static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_NONE: + return true; + case LLAMA_POOLING_TYPE_CLS: + return pos == 0; + case LLAMA_POOLING_TYPE_LAST: + return pos == n_tokens - 1; + default: + GGML_ASSERT(false && "unsupported pooling type"); + } +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id, enum llama_pooling_type pooling_type) { + int n_tokens = tokens.size(); + for (size_t i = 0; i < n_tokens; i++) { + bool logit = needs_logit(pooling_type, i, n_tokens); + llama_batch_add(batch, tokens[i], i, { seq_id }, logit); } } @@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu // try to get sequence embeddings - supported only when pooling_type is not NONE const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]); - if (embd == NULL) { - embd = llama_get_embeddings_ith(ctx, i); - if (embd == NULL) { - fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i); - continue; - } - } + GGML_ASSERT(embd != NULL && "failed to get sequence embeddings"); float * out = output + batch.seq_id[i][0] * n_embd; //TODO: I would also add a parameter here to enable normalization or not. @@ -99,6 +109,12 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); + if (pooling_type == LLAMA_POOLING_TYPE_NONE) { + fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__); + return 1; + } + if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", __func__, n_ctx_train, n_ctx); @@ -178,7 +194,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s); + batch_add_seq(batch, inp, s, pooling_type); s += 1; } diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp index 52fd719b38ee56..04a9a8082a1bc7 100644 --- a/examples/gritlm/gritlm.cpp +++ b/examples/gritlm/gritlm.cpp @@ -164,9 +164,13 @@ int main(int argc, char * argv[]) { llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams); - // create new context - set to embedding mode + // create generation context + llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams); + + // create embedding context cparams.embeddings = true; - llama_context * ctx = llama_new_context_with_model(mdl, cparams); + cparams.pooling_type = LLAMA_POOLING_TYPE_NONE; + llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams); // ### Embedding/Representation ### // samples taken from: https://github.com/ContextualAI/gritlm#basic @@ -184,8 +188,8 @@ int main(int argc, char * argv[]) { }; // No need to add instruction for retrieval documents - const std::vector> d_rep = encode(ctx, documents, gritlm_instruction("")); - const std::vector> q_rep = encode(ctx, queries, gritlm_instruction(instruction)); + const std::vector> d_rep = encode(ctx_emb, documents, gritlm_instruction("")); + const std::vector> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction)); const int n_embd = llama_n_embd(mdl); @@ -204,10 +208,11 @@ int main(int argc, char * argv[]) { // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction { const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n"; - std::string response = generate(ctx, prompt, true); + std::string response = generate(ctx_gen, prompt, true); } - llama_free(ctx); + llama_free(ctx_gen); + llama_free(ctx_emb); llama_free_model(mdl); llama_backend_free(); diff --git a/examples/retrieval/retrieval.cpp b/examples/retrieval/retrieval.cpp index 4e7530706d4a92..eae1822ebaf378 100644 --- a/examples/retrieval/retrieval.cpp +++ b/examples/retrieval/retrieval.cpp @@ -133,9 +133,25 @@ static std::vector chunk_file(const std::string & filename, int chunk_siz return chunks; } -static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id) { +static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) { + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_NONE: + return true; + case LLAMA_POOLING_TYPE_CLS: + return pos == 0; + case LLAMA_POOLING_TYPE_LAST: + return pos == n_tokens - 1; + default: + GGML_ASSERT(false && "unsupported pooling type"); + } +} + +static void batch_add_seq(llama_batch & batch, const std::vector & tokens, int seq_id, enum llama_pooling_type pooling_type) { + int n_tokens = tokens.size(); for (size_t i = 0; i < tokens.size(); i++) { - llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1); + bool logit = needs_logit(pooling_type, i, n_tokens); + llama_batch_add(batch, tokens[i], i, { seq_id }, logit); } } @@ -217,6 +233,7 @@ int main(int argc, char ** argv) { const int n_ctx_train = llama_n_ctx_train(model); const int n_ctx = llama_n_ctx(ctx); + const enum llama_pooling_type pooling_type = llama_pooling_type(ctx); if (n_ctx > n_ctx_train) { fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n", @@ -288,7 +305,7 @@ int main(int argc, char ** argv) { } // add to batch - batch_add_seq(batch, inp, s); + batch_add_seq(batch, inp, s, pooling_type); s += 1; } @@ -311,7 +328,7 @@ int main(int argc, char ** argv) { std::vector query_tokens = llama_tokenize(ctx, query, true); struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1); - batch_add_seq(query_batch, query_tokens, 0); + batch_add_seq(query_batch, query_tokens, 0, pooling_type); std::vector query_emb(n_embd, 0); batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd); diff --git a/llama.cpp b/llama.cpp index 34137c7ade6b26..d0254422ce8fa7 100644 --- a/llama.cpp +++ b/llama.cpp @@ -7041,6 +7041,44 @@ struct llm_build_context { return lctx.inp_s_seq; } + struct ggml_cgraph * append_pooling(struct ggml_cgraph * gf) { + struct ggml_tensor * inp = gf->nodes[gf->n_nodes - 1]; + if (strcmp(inp->name, "result_embd") != 0) { + inp = gf->nodes[gf->n_nodes - 2]; + GGML_ASSERT(strcmp(inp->name, "result_norm") == 0 && "embeddings tensor not found"); + } + + struct ggml_tensor * cur; + + switch (pooling_type) { + case LLAMA_POOLING_TYPE_MEAN: + { + struct ggml_tensor * inp_mean = build_inp_mean(); + cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, inp)), inp_mean); + } break; + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: + { + struct ggml_tensor * inp_cls = build_inp_cls(); + cur = ggml_get_rows(ctx0, inp, inp_cls); + } break; + case LLAMA_POOLING_TYPE_NONE: + { + cur = inp; + } break; + default: + { + GGML_ASSERT(false && "unknown pooling type"); + } break; + } + + cb(cur, "result_embd_pooled", -1); + + ggml_build_forward_expand(gf, cur); + + return gf; + } + struct ggml_cgraph * build_llama() { struct ggml_cgraph * gf = ggml_new_graph_custom(ctx0, LLAMA_MAX_NODES, false); @@ -8018,8 +8056,6 @@ struct llm_build_context { if (model.arch != LLM_ARCH_JINA_BERT_V2) { inp_pos = build_inp_pos(); } - struct ggml_tensor * inp_mean = build_inp_mean(); - struct ggml_tensor * inp_cls = build_inp_cls(); // construct input embeddings (token, type, position) inpL = llm_build_inp_embd(ctx0, lctx, hparams, batch, model.tok_embd, cb); @@ -8189,28 +8225,6 @@ struct llm_build_context { cur = inpL; cb(cur, "result_embd", -1); - // pooling layer - switch (pooling_type) { - case LLAMA_POOLING_TYPE_NONE: - { - // nop - } break; - case LLAMA_POOLING_TYPE_MEAN: - { - cur = ggml_mul_mat(ctx0, ggml_cont(ctx0, ggml_transpose(ctx0, cur)), inp_mean); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_CLS: - { - cur = ggml_get_rows(ctx0, cur, inp_cls); - cb(cur, "result_embd_pooled", -1); - } break; - case LLAMA_POOLING_TYPE_UNSPECIFIED: - { - GGML_ASSERT(false && "Invalid pooling type"); - } break; - } - ggml_build_forward_expand(gf, cur); return gf; @@ -10779,6 +10793,11 @@ static struct ggml_cgraph * llama_build_graph( GGML_ASSERT(false); } + // add on pooling layer + if (lctx.cparams.embeddings) { + result = llm.append_pooling(result); + } + llm.free(); return result; @@ -11000,6 +11019,37 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } + if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + const int64_t n_tokens = batch.n_tokens; + + GGML_ASSERT(lctx.inp_cls); + GGML_ASSERT(ggml_backend_buffer_is_host(lctx.inp_cls->buffer)); + + uint32_t * data = (uint32_t *) lctx.inp_cls->data; + memset(lctx.inp_cls->data, 0, n_tokens * ggml_element_size(lctx.inp_cls)); + + std::vector last_pos(n_tokens, -1); + std::vector last_row(n_tokens, -1); + + for (int i = 0; i < n_tokens; ++i) { + const llama_seq_id seq_id = batch.seq_id[i][0]; + const llama_pos pos = batch.pos[i]; + + GGML_ASSERT(seq_id < n_tokens && "seq_id cannot be larger than n_tokens with pooling_type == LAST"); + + if (pos >= last_pos[seq_id]) { + last_pos[seq_id] = pos; + last_row[seq_id] = i; + } + } + + for (int i = 0; i < n_tokens; ++i) { + if (last_row[i] >= 0) { + data[i] = last_row[i]; + } + } + } + if (kv_self.recurrent) { const int64_t n_kv = kv_self.n; @@ -11322,30 +11372,13 @@ static int llama_decode_internal( // no output res = nullptr; embd = nullptr; - } else if (!hparams.causal_attn) { - res = nullptr; // do not extract logits for embedding models such as BERT - - // token or sequence embeddings - embd = gf->nodes[gf->n_nodes - 1]; - - GGML_ASSERT(strcmp(embd->name, "result_embd") == 0 || strcmp(embd->name, "result_embd_pooled") == 0); } else if (cparams.embeddings) { - // the embeddings could be in the second to last tensor, or any of the previous tensors - int i_embd = gf->n_nodes - 2; - for (int i = 3; strcmp(embd->name, "result_norm") != 0; ++i) { - i_embd = gf->n_nodes - i; - if (i_embd < 0) { break; } - embd = gf->nodes[i_embd]; - } - GGML_ASSERT(i_embd >= 0 && "missing result_norm tensor"); - - // TODO: use a per-batch flag to know when to skip logits while keeping embeddings - if (!cparams.causal_attn) { - res = nullptr; // do not extract logits when not needed - // skip computing logits - // TODO: is this safe? - gf->n_nodes = i_embd + 1; + res = nullptr; // do not extract logits for embedding case + embd = gf->nodes[gf->n_nodes - 1]; + if (strcmp(embd->name, "result_embd_pooled") != 0) { + embd = gf->nodes[gf->n_nodes - 2]; } + GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0 && "missing embeddings tensor"); } else { embd = nullptr; // do not extract embeddings when not needed GGML_ASSERT(strcmp(res->name, "result_output") == 0 && "missing result_output tensor"); @@ -11425,11 +11458,10 @@ static int llama_decode_internal( ggml_backend_tensor_get_async(backend_embd, embd, embd_out, 0, n_outputs_new*n_embd*sizeof(float)); } } break; - case LLAMA_POOLING_TYPE_CLS: case LLAMA_POOLING_TYPE_MEAN: + case LLAMA_POOLING_TYPE_CLS: + case LLAMA_POOLING_TYPE_LAST: { - GGML_ASSERT(strcmp(embd->name, "result_embd_pooled") == 0); - // extract sequence embeddings auto & embd_seq_out = lctx.embd_seq; embd_seq_out.clear(); @@ -15239,6 +15271,7 @@ struct llama_context_params llama_context_default_params() { /*.n_threads_batch =*/ GGML_DEFAULT_N_THREADS, /*.rope_scaling_type =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED, /*.pooling_type =*/ LLAMA_POOLING_TYPE_UNSPECIFIED, + /*.attention_type =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED, /*.rope_freq_base =*/ 0.0f, /*.rope_freq_scale =*/ 0.0f, /*.yarn_ext_factor =*/ -1.0f, @@ -15475,7 +15508,12 @@ struct llama_context * llama_new_context_with_model( } cparams.yarn_attn_factor *= hparams.rope_attn_factor; - cparams.causal_attn = hparams.causal_attn; + + if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) { + cparams.causal_attn = hparams.causal_attn; + } else { + cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL; + } if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) { @@ -15820,6 +15858,10 @@ enum llama_pooling_type llama_pooling_type(const struct llama_context * ctx) { return ctx->cparams.pooling_type; } +bool llama_causal_attn(const struct llama_context * ctx) { + return ctx->cparams.causal_attn; +} + int32_t llama_n_vocab(const struct llama_model * model) { return model->hparams.n_vocab; } diff --git a/llama.h b/llama.h index b7bf2afcb403e0..109023a73e454c 100644 --- a/llama.h +++ b/llama.h @@ -158,6 +158,13 @@ extern "C" { LLAMA_POOLING_TYPE_NONE = 0, LLAMA_POOLING_TYPE_MEAN = 1, LLAMA_POOLING_TYPE_CLS = 2, + LLAMA_POOLING_TYPE_LAST = 3, + }; + + enum llama_attention_type { + LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1, + LLAMA_ATTENTION_TYPE_CAUSAL = 0, + LLAMA_ATTENTION_TYPE_NONCAUSAL = 1, }; enum llama_split_mode { @@ -275,7 +282,7 @@ extern "C" { enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type` enum llama_pooling_type pooling_type; // whether to pool (sum) embedding results by sequence id - // (ignored if no pooling layer) + enum llama_attention_type attention_type; // causal, non-causal, or unspecified // ref: https://github.com/ggerganov/llama.cpp/pull/2054 float rope_freq_base; // RoPE base frequency, 0 = from model