diff --git a/llama.cpp b/llama.cpp index 08c130643db04f..68ef80d7689d48 100644 --- a/llama.cpp +++ b/llama.cpp @@ -12343,7 +12343,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_MEAN) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(lctx.inp_mean); @@ -12375,7 +12375,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_CLS) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(lctx.inp_cls); @@ -12396,7 +12396,7 @@ static void llama_set_inputs(llama_context & lctx, const llama_batch & batch) { } } - if (cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { + if (cparams.embeddings && cparams.pooling_type == LLAMA_POOLING_TYPE_LAST) { const int64_t n_tokens = batch.n_tokens; GGML_ASSERT(lctx.inp_cls);