
Commit

create append_pooling operation; allow to specify attention_type; add last token pooling; update examples
iamlemec committed May 22, 2024
1 parent 1e37436 commit 209cc2b
Showing 7 changed files with 176 additions and 73 deletions.
17 changes: 16 additions & 1 deletion common/common.cpp
@@ -528,6 +528,18 @@ bool gpt_params_find_arg(int argc, char ** argv, const std::string & arg, gpt_pa
        /**/ if (value == "none") { params.pooling_type = LLAMA_POOLING_TYPE_NONE; }
        else if (value == "mean") { params.pooling_type = LLAMA_POOLING_TYPE_MEAN; }
        else if (value == "cls") { params.pooling_type = LLAMA_POOLING_TYPE_CLS; }
        else if (value == "last") { params.pooling_type = LLAMA_POOLING_TYPE_LAST; }
        else { invalid_param = true; }
        return true;
    }
    if (arg == "--attention") {
        if (++i >= argc) {
            invalid_param = true;
            return true;
        }
        std::string value(argv[i]);
        /**/ if (value == "causal") { params.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; }
        else if (value == "non-causal") { params.attention_type = LLAMA_ATTENTION_TYPE_NONCAUSAL; }
        else { invalid_param = true; }
        return true;
    }
@@ -1443,8 +1455,10 @@ void gpt_params_print_usage(int /*argc*/, char ** argv, const gpt_params & param
    printf(" --yarn-attn-factor N YaRN: scale sqrt(t) or attention magnitude (default: 1.0)\n");
    printf(" --yarn-beta-slow N YaRN: high correction dim or alpha (default: %.1f)\n", params.yarn_beta_slow);
    printf(" --yarn-beta-fast N YaRN: low correction dim or beta (default: %.1f)\n", params.yarn_beta_fast);
    printf(" --pooling {none,mean,cls}\n");
    printf(" --pooling {none,mean,cls,last}\n");
    printf(" pooling type for embeddings, use model default if unspecified\n");
    printf(" --attention {causal,non-causal}\n");
    printf(" attention type for generation, use model default if unspecified\n");
    printf(" -dt N, --defrag-thold N\n");
    printf(" KV cache defragmentation threshold (default: %.1f, < 0 - disabled)\n", params.defrag_thold);
    printf(" --ignore-eos ignore end of stream token and continue generating (implies --logit-bias 2-inf)\n");
@@ -2042,6 +2056,7 @@ struct llama_context_params llama_context_params_from_gpt_params(const gpt_param
    cparams.yarn_beta_slow = params.yarn_beta_slow;
    cparams.yarn_orig_ctx = params.yarn_orig_ctx;
    cparams.pooling_type = params.pooling_type;
    cparams.attention_type = params.attention_type;
    cparams.defrag_thold = params.defrag_thold;
    cparams.cb_eval = params.cb_eval;
    cparams.cb_eval_user_data = params.cb_eval_user_data;
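For context, a minimal sketch (not part of this commit's diff) of how an application might set the two new context parameters directly, mirroring what the --pooling last and --attention causal options above feed through llama_context_params_from_gpt_params; the helper name make_embedding_ctx is hypothetical:

#include "llama.h"

// hypothetical helper, shown only to illustrate the new context parameters
static llama_context * make_embedding_ctx(llama_model * model) {
    llama_context_params cparams = llama_context_default_params();
    cparams.embeddings     = true;                        // return embeddings
    cparams.pooling_type   = LLAMA_POOLING_TYPE_LAST;     // new in this commit: last-token pooling
    cparams.attention_type = LLAMA_ATTENTION_TYPE_CAUSAL; // new in this commit: explicit attention type
    return llama_new_context_with_model(model, cparams);
}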
1 change: 1 addition & 0 deletions common/common.h
@@ -95,6 +95,7 @@ struct gpt_params {

    enum llama_rope_scaling_type rope_scaling_type = LLAMA_ROPE_SCALING_TYPE_UNSPECIFIED;
    enum llama_pooling_type pooling_type = LLAMA_POOLING_TYPE_UNSPECIFIED; // pooling type for embeddings
    enum llama_attention_type attention_type = LLAMA_ATTENTION_TYPE_UNSPECIFIED; // attention type

    // // sampling parameters
    struct llama_sampling_params sparams;
38 changes: 27 additions & 11 deletions examples/embedding/embedding.cpp
@@ -17,9 +17,25 @@ static std::vector<std::string> split_lines(const std::string & s) {
    return lines;
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
    switch (pooling_type) {
        case LLAMA_POOLING_TYPE_MEAN:
        case LLAMA_POOLING_TYPE_NONE:
            return true;
        case LLAMA_POOLING_TYPE_CLS:
            return pos == 0;
        case LLAMA_POOLING_TYPE_LAST:
            return pos == n_tokens - 1;
        default:
            GGML_ASSERT(false && "unsupported pooling type");
    }
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
    int n_tokens = tokens.size();
    for (size_t i = 0; i < n_tokens; i++) {
        bool logit = needs_logit(pooling_type, i, n_tokens);
        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
    }
}
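To make the effect of needs_logit concrete, a small illustrative snippet (not part of the diff; it assumes llama.h is included and that it is compiled in the same translation unit as the helper above) that prints which positions of a hypothetical 4-token sequence are flagged for output under each pooling type:

#include <cstdio>

// expected output: MEAN -> 0 1 2 3, CLS -> 0, LAST -> 3
static void print_logit_positions() {
    const enum llama_pooling_type types[] = {
        LLAMA_POOLING_TYPE_MEAN, LLAMA_POOLING_TYPE_CLS, LLAMA_POOLING_TYPE_LAST,
    };
    const int n_tokens = 4;
    for (enum llama_pooling_type t : types) {
        printf("pooling %d:", (int) t);
        for (int pos = 0; pos < n_tokens; pos++) {
            if (needs_logit(t, pos, n_tokens)) {
                printf(" %d", pos);
            }
        }
        printf("\n");
    }
}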

@@ -40,13 +56,7 @@ static void batch_decode(llama_context * ctx, llama_batch & batch, float * outpu

        // try to get sequence embeddings - supported only when pooling_type is not NONE
        const float * embd = llama_get_embeddings_seq(ctx, batch.seq_id[i][0]);
        if (embd == NULL) {
            embd = llama_get_embeddings_ith(ctx, i);
            if (embd == NULL) {
                fprintf(stderr, "%s: failed to get embeddings for token %d\n", __func__, i);
                continue;
            }
        }
        GGML_ASSERT(embd != NULL && "failed to get sequence embeddings");

        float * out = output + batch.seq_id[i][0] * n_embd;
        //TODO: I would also add a parameter here to enable normalization or not.
@@ -99,6 +109,12 @@ int main(int argc, char ** argv) {
    const int n_ctx_train = llama_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);

    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);
    if (pooling_type == LLAMA_POOLING_TYPE_NONE) {
        fprintf(stderr, "%s: error: pooling type NONE not supported\n", __func__);
        return 1;
    }

    if (n_ctx > n_ctx_train) {
        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
                __func__, n_ctx_train, n_ctx);
@@ -178,7 +194,7 @@ int main(int argc, char ** argv) {
        }

        // add to batch
        batch_add_seq(batch, inp, s);
        batch_add_seq(batch, inp, s, pooling_type);
        s += 1;
}

17 changes: 11 additions & 6 deletions examples/gritlm/gritlm.cpp
@@ -164,9 +164,13 @@ int main(int argc, char * argv[]) {

    llama_model * mdl = llama_load_model_from_file(params.model.c_str(), mparams);

    // create new context - set to embedding mode
    // create generation context
    llama_context * ctx_gen = llama_new_context_with_model(mdl, cparams);

    // create embedding context
    cparams.embeddings = true;
    llama_context * ctx = llama_new_context_with_model(mdl, cparams);
    cparams.pooling_type = LLAMA_POOLING_TYPE_NONE;
    llama_context * ctx_emb = llama_new_context_with_model(mdl, cparams);

    // ### Embedding/Representation ###
    // samples taken from: https://github.com/ContextualAI/gritlm#basic
@@ -184,8 +188,8 @@ int main(int argc, char * argv[]) {
    };

    // No need to add instruction for retrieval documents
    const std::vector<std::vector<float>> d_rep = encode(ctx, documents, gritlm_instruction(""));
    const std::vector<std::vector<float>> q_rep = encode(ctx, queries, gritlm_instruction(instruction));
    const std::vector<std::vector<float>> d_rep = encode(ctx_emb, documents, gritlm_instruction(""));
    const std::vector<std::vector<float>> q_rep = encode(ctx_emb, queries, gritlm_instruction(instruction));

    const int n_embd = llama_n_embd(mdl);

@@ -204,10 +208,11 @@ int main(int argc, char * argv[]) {
    // GritLM models are not finetuned with system prompts, as you can just include system-like instructions together with your user instruction
    {
        const std::string prompt = "<|user|>\nPlease write me a poem about my recent hike of Mt. Fuji at midnight in the style of Shakespeare.\n<|assistant|>\n";
        std::string response = generate(ctx, prompt, true);
        std::string response = generate(ctx_gen, prompt, true);
    }

    llama_free(ctx);
    llama_free(ctx_gen);
    llama_free(ctx_emb);
    llama_free_model(mdl);
    llama_backend_free();

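A brief note on why the embedding context in this example now sets pooling type NONE: GritLM pools token embeddings itself so it can skip the instruction prefix when building a document or query representation. A hedged sketch of that idea (helper name and details are illustrative, not copied from gritlm.cpp):

#include <vector>
#include "llama.h"

// illustrative: mean-pool per-token embeddings, skipping the first n_inst instruction
// tokens; requires a context with pooling type NONE so llama_get_embeddings_ith()
// yields one embedding per token
static std::vector<float> pool_mean_skip_instruction(llama_context * ctx, int n_inst, int n_toks, int n_embd) {
    std::vector<float> emb(n_embd, 0.0f);
    for (int i = n_inst; i < n_toks; i++) {
        const float * e = llama_get_embeddings_ith(ctx, i);
        for (int j = 0; j < n_embd; j++) {
            emb[j] += e[j] / float(n_toks - n_inst);
        }
    }
    return emb;
}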
25 changes: 21 additions & 4 deletions examples/retrieval/retrieval.cpp
@@ -133,9 +133,25 @@ static std::vector<chunk> chunk_file(const std::string & filename, int chunk_siz
    return chunks;
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id) {
static bool needs_logit(enum llama_pooling_type pooling_type, int pos, int n_tokens) {
    switch (pooling_type) {
        case LLAMA_POOLING_TYPE_MEAN:
        case LLAMA_POOLING_TYPE_NONE:
            return true;
        case LLAMA_POOLING_TYPE_CLS:
            return pos == 0;
        case LLAMA_POOLING_TYPE_LAST:
            return pos == n_tokens - 1;
        default:
            GGML_ASSERT(false && "unsupported pooling type");
    }
}

static void batch_add_seq(llama_batch & batch, const std::vector<int32_t> & tokens, int seq_id, enum llama_pooling_type pooling_type) {
    int n_tokens = tokens.size();
    for (size_t i = 0; i < tokens.size(); i++) {
        llama_batch_add(batch, tokens[i], i, { seq_id }, i == tokens.size() - 1);
        bool logit = needs_logit(pooling_type, i, n_tokens);
        llama_batch_add(batch, tokens[i], i, { seq_id }, logit);
    }
}

@@ -217,6 +233,7 @@ int main(int argc, char ** argv) {

    const int n_ctx_train = llama_n_ctx_train(model);
    const int n_ctx = llama_n_ctx(ctx);
    const enum llama_pooling_type pooling_type = llama_pooling_type(ctx);

    if (n_ctx > n_ctx_train) {
        fprintf(stderr, "%s: warning: model was trained on only %d context tokens (%d specified)\n",
@@ -288,7 +305,7 @@ int main(int argc, char ** argv) {
        }

        // add to batch
        batch_add_seq(batch, inp, s);
        batch_add_seq(batch, inp, s, pooling_type);
        s += 1;
}

@@ -311,7 +328,7 @@ int main(int argc, char ** argv) {
        std::vector<int32_t> query_tokens = llama_tokenize(ctx, query, true);

        struct llama_batch query_batch = llama_batch_init(n_batch, 0, 1);
        batch_add_seq(query_batch, query_tokens, 0);
        batch_add_seq(query_batch, query_tokens, 0, pooling_type);

        std::vector<float> query_emb(n_embd, 0);
        batch_decode(ctx, query_batch, query_emb.data(), 1, n_embd);
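Downstream of this hunk (not shown in the diff), the query embedding is scored against each chunk embedding; a hedged sketch of the usual cosine-similarity computation (helper name is hypothetical, not necessarily what retrieval.cpp uses):

#include <cmath>

// hypothetical helper: cosine similarity between two n-dimensional embeddings
static float embd_similarity_cos(const float * a, const float * b, int n) {
    float dot = 0.0f, norm_a = 0.0f, norm_b = 0.0f;
    for (int i = 0; i < n; i++) {
        dot    += a[i] * b[i];
        norm_a += a[i] * a[i];
        norm_b += b[i] * b[i];
    }
    return dot / (std::sqrt(norm_a) * std::sqrt(norm_b) + 1e-6f);
}

The scores against query_emb.data() would then be sorted to surface the top-k chunks.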
