diff --git a/common/common.cpp b/common/common.cpp
index 0ad072e572c103..513d354b0f60e4 100644
--- a/common/common.cpp
+++ b/common/common.cpp
@@ -532,17 +532,6 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
         else { invalid_param = true; }
         return true;
     }
-    if (arg == "--attention") {
-        if (++i >= argc) {
-            invalid_param = true;
-            return true;
-        }
-        std::string value(argv[i]);
-        /**/ if (value == "causal")     { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
-        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
-        else { invalid_param = true; }
-        return true;
-    }
     if (arg == "--defrag-thold" || arg == "-dt") {
         if (++i >= argc) {
             invalid_param = true;
@@ -1457,8 +1446,6 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
     printf("  --yarn-beta-fast N        YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
     printf("  --pooling {none,mean,cls,last}\n");
     printf("                            pooling type for embeddings, use model default if unspecified\n");
-    printf("  --attn-type {causal,non-causal}\n");
-    printf("                            attention type for generation, use model default if unspecified\n");
     printf("  -dt N, --defrag-thold N\n");
     printf("                            KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
     printf("  --ignore-eos              ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -2056,7 +2043,6 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
     cparams.yarn_beta_slow    = params.yarn_beta_slow;
     cparams.yarn_orig_ctx     = params.yarn_orig_ctx;
     cparams.pooling_type      = params.pooling_type;
-    cparams.attention_type    = params.attention_type;
     cparams.defrag_thold      = params.defrag_thold;
     cparams.cb_eval           = params.cb_eval;
     cparams.cb_eval_user_data = params.cb_eval_user_data;
diff --git a/common/common.h b/common/common.h
index 13b9b32709fcf2..f68f3c2979b94b 100644
--- a/common/common.h
+++ b/common/common.h
@@ -95,7 +95,6 @@ struct gpt_params {
 
     enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
     enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
-    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type
 
     // // sampling parameters
     struct llama_sampling_params sparams;
diff --git a/examples/gritlm/gritlm.cpp b/examples/gritlm/gritlm.cpp
index 392e9311f444ca..9ce9071c195f09 100644
--- a/examples/gritlm/gritlm.cpp
+++ b/examples/gritlm/gritlm.cpp
@@ -44,6 +44,8 @@ static std::vector<std::vector<float>> encode(llama_context * ctx, const std::ve
 
         // clear previous kv_cache values (irrelevant for embeddings)
         llama_kv_cache_clear(ctx);
+        llama_set_embeddings(ctx, true);
+        llama_set_causal_attn(ctx, false);
 
         // run model
         llama_decode(ctx, batch);
@@ -97,6 +99,9 @@ static std::string generate(llama_context * ctx, const std::string & prompt, boo
     llama_token eos_token = llama_token_eos(mdl);
 
     llama_kv_cache_clear(ctx);
+    llama_set_embeddings(ctx, false);
+    llama_set_causal_attn(ctx, true);
+
     llama_batch bat = llama_batch_init(llama_n_batch(ctx), 0, 1);
 
     std::vector<llama_token> inputs = llama_tokenize(mdl, prompt, false, true);
@@ -163,13 +168,7 @@ int main(int argc, char * argv[]) {
     llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);
 
     // create generation context
-    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);
-
-    // create embedding context
-    cparams.embeddings = true;
-    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
-    cparams.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL;
-    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);
+    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
 
     // ### Embedding/Representation ###
     // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -187,8 +186,8 @@ int main(int argc, char * argv[]) {
     };
 
     // No need to add instruction for retrieval documents
-    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
-    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));
+    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
+    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
 
     const int n_embd = llama_n_embd(mdl);
 
@@ -207,11 +206,10 @@ int main(int argc, char * argv[]) {
     // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
     {
         const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
-        std::string response = generate(ctx_gen, prompt, true);
+        std::string response = generate(ctx, prompt, true);
     }
 
-    llama_free(ctx_gen);
-    llama_free(ctx_emb);
+    llama_free(ctx);
     llama_free_model(mdl);
 
     llama_backend_free();
diff --git a/llama.cpp b/llama.cpp
index cd7f66a5dba4d5..a3736e4cf0c48c 100644
--- a/llama.cpp
+++ b/llama.cpp
@@ -15277,7 +15277,6 @@ struct llama_context_params llama_context_default_params() {
         /*.n_threads_batch    =*/ GGML_DEFAULT_N_THREADS,
         /*.rope_scaling_type  =*/ LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED,
         /*.pooling_type       =*/ LLAMA_POOLING_TYPE_UNSPECIFIED,
-        /*.attention_type     =*/ LLAMA_ATTENTION_TYPE_UNSPECIFIED,
         /*.rope_freq_base     =*/ 0.0f,
         /*.rope_freq_scale    =*/ 0.0f,
         /*.yarn_ext_factor    =*/ -1.0f,
@@ -15514,12 +15513,7 @@ struct llama_context * llama_new_context_with_model(
     }
 
     cparams.yarn_attn_factor *= hparams.rope_attn_factor;
-
-    if (params.attention_type == LLAMA_ATTENTION_TYPE_UNSPECIFIED) {
-        cparams.causal_attn = hparams.causal_attn;
-    } else {
-        cparams.causal_attn = params.attention_type == LLAMA_ATTENTION_TYPE_CAUSAL;
-    }
+    cparams.causal_attn = hparams.causal_attn;
 
     if (cparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
         if (hparams.pooling_type == LLAMA_POOLING_TYPE_UNSPECIFIED) {
@@ -17232,6 +17226,10 @@ void llama_set_abort_callback(struct llama_context * ctx, bool (*abort_callback)
     ctx->abort_callback_data = abort_callback_data;
 }
 
+void llama_set_embeddings(struct llama_context * ctx, bool embeddings) {
+    ctx->cparams.embeddings = embeddings;
+}
+
 void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn) {
     ctx->cparams.causal_attn = causal_attn;
 }
diff --git a/llama.h b/llama.h
index 109023a73e454c..9a549f5e75067f 100644
--- a/llama.h
+++ b/llama.h
@@ -161,12 +161,6 @@ extern "C" {
         LLAMA_POOLING_TYPE_LAST = 3,
     };
 
-    enum llama_attention_type {
-        LLAMA_ATTENTION_TYPE_UNSPECIFIED = -1,
-        LLAMA_ATTENTION_TYPE_CAUSAL      = 0,
-        LLAMA_ATTENTION_TYPE_NONCAUSAL   = 1,
-    };
-
     enum llama_split_mode {
         LLAMA_SPLIT_MODE_NONE  = 0, // single GPU
         LLAMA_SPLIT_MODE_LAYER = 1, // split layers and KV across GPUs
@@ -282,7 +276,6 @@ extern "C" {
 
         enum llama_rope_scaling_type rope_scaling_type; // RoPE scaling type, from `enum llama_rope_scaling_type`
        enum llama_pooling_type      pooling_type;      // whether to pool (sum) embedding results by sequence id
-        enum llama_attention_type    attention_type;    // causal, non-causal, or unspecified
 
         // ref: https://github.com/ggerganov/llama.cpp/pull/2054
         float rope_freq_base;  // RoPE base frequency, 0 = from model
@@ -766,6 +759,10 @@ extern "C" {
     // n_threads_batch is the number of threads used for prompt and batch processing (multiple tokens)
     LLAMA_API void llama_set_n_threads(struct llama_context * ctx, uint32_t n_threads, uint32_t n_threads_batch);
 
+    // Set whether the model is in embeddings mode or not
+    // If true, embeddings will be returned but logits will not
+    LLAMA_API void llama_set_embeddings(struct llama_context * ctx, bool embeddings);
+
     // Set whether to use causal attention or not
     // If set to true, the model will only attend to the past tokens
     LLAMA_API void llama_set_causal_attn(struct llama_context * ctx, bool causal_attn);
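
For reference, a minimal sketch (not part of this patch) of how the two runtime setters let one `llama_context` serve both embedding and generation, in the spirit of the updated gritlm example. The helper names `embed_batch` and `generate_logits` are hypothetical, and `llama_get_embeddings_seq` assumes a pooling type other than `LLAMA_POOLING_TYPE_NONE`; batch construction and error handling are left to the caller.

```cpp
#include "llama.h"

// Hypothetical helper: run an already-built batch in embedding mode.
// Returns the pooled embedding of sequence 0, or nullptr on failure.
static const float * embed_batch(llama_context * ctx, llama_batch batch) {
    llama_kv_cache_clear(ctx);          // previous KV state is irrelevant for embeddings
    llama_set_embeddings(ctx, true);    // request embeddings instead of logits
    llama_set_causal_attn(ctx, false);  // bidirectional attention for representation

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }
    return llama_get_embeddings_seq(ctx, 0); // assumes pooling != LLAMA_POOLING_TYPE_NONE
}

// Hypothetical helper: switch the same context back to causal generation.
// Returns the logits of the last token (assumed to have its logits flag set).
static float * generate_logits(llama_context * ctx, llama_batch batch) {
    llama_kv_cache_clear(ctx);
    llama_set_embeddings(ctx, false);   // back to logits
    llama_set_causal_attn(ctx, true);   // standard left-to-right attention

    if (llama_decode(ctx, batch) != 0) {
        return nullptr;
    }
    return llama_get_logits_ith(ctx, batch.n_tokens - 1);
}
```

This mirrors what the gritlm change does: `encode()` and `generate()` share a single context and simply flip the two flags before each `llama_decode()` call, instead of allocating separate embedding and generation contexts.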