From 1be98e214d3eefdb243f6ad382cfde47d530614b Mon Sep 17 00:00:00 2001 From: Vulcan <93451215+trholding@users.noreply.github.com> Date: Tue, 9 Jul 2024 12:48:29 +0530 Subject: [PATCH] Llama3 Support (WIP) use -l 3 option --- README.md | 4 +- run.c | 140 +++++++++++++++++++++++++++++++++++++++++------------- 2 files changed, 109 insertions(+), 35 deletions(-) diff --git a/README.md b/README.md index 0de40994..4be76f66 100644 --- a/README.md +++ b/README.md @@ -33,9 +33,9 @@ Learn more about the Llama2 models & architecture at Meta: [Llama 2 @ Meta](http # Features & Milestones -#### Llama 3 Support +#### Llama 3 Support WIP -Almost done - Coming Soonish (TM)... +Should support inference, WIP, use -l 3 option... #### L2E OS (Linux Kernel) diff --git a/run.c b/run.c index d77fa3df..ad2240f9 100644 --- a/run.c +++ b/run.c @@ -9,6 +9,12 @@ int buffertokens = 1; // output token buffer size int stats = 1; // extended status info +int llamaver = 2; // llama version (default is 2, valid 2 & 3) +float rope_sf = 10000.0; // Rope scaling factor, 10000.0 => llama2, 500000.0 > llama3 +int BOS = 1; // Beginning of Sentence token value, llama2 = 1 , llama3 = 128000 +int EOS = 2; // End of Sentence token value, llama2 = 2 , llama3 = 128009 (end of text) +char system_template[1024]=""; +char user_template[1024]=""; // ---------------------------------------------------------------------------- // L2E Humanoid : Linux Kernel Support Directives @@ -550,7 +556,9 @@ float* forward(Transformer* transformer, int token, int pos) { // RoPE relative positional encoding: complex-valued rotate q and k in each head for (int i = 0; i < dim; i+=2) { int head_dim = i % head_size; - float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size); +// L2E Addition + float freq = 1.0f / powf(rope_sf, head_dim / (float)head_size); +// END L2E Addition float val = pos * freq; float fcr = cosf(val); float fci = sinf(val); @@ -738,8 +746,10 @@ void free_tokenizer(Tokenizer* t) { char* decode(Tokenizer* t, int prev_token, int token) { char *piece = t->vocab[token]; - // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89) - if (prev_token == 1 && piece[0] == ' ') { piece++; } +// L2E Addition + // following BOS (1) or (2) token, sentencepiece decoder strips any leading whitespace (see PR #89) + if (prev_token == BOS && piece[0] == ' ') { piece++; } +// END L2E Addition // careful, some tokens designate raw bytes, and look like e.g. '<0x01>' // parse this and convert and return the actual byte unsigned char byte_val; @@ -772,7 +782,7 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) { void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) { // encode the string text (input) into an upper-bound preallocated tokens[] array - // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2) + // bos != 0 means prepend the BOS token, eos != 0 means append the EOS token if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); } if (t->sorted_vocab == NULL) { @@ -793,17 +803,25 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int * // start at 0 tokens *n_tokens = 0; - // add optional BOS (=1) token, if desired - if (bos) tokens[(*n_tokens)++] = 1; +// L2E Addition + // add optional BOS token, if desired + if (bos) tokens[(*n_tokens)++] = BOS; +// END L2E Addition + // add_dummy_prefix is true by default // so prepend a dummy prefix token to the input string, but only if text != "" // TODO: pretty sure this isn't correct in the general case but I don't have the // energy to read more of the sentencepiece code to figure out what it's doing - if (text[0] != '\0') { - int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); - tokens[(*n_tokens)++] = dummy_prefix; + +// L2E Addition + if (llamaver == 2) { + if (text[0] != '\0') { + int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size); + tokens[(*n_tokens)++] = dummy_prefix; + } } +// END L2E Addition // Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia: // Code point ↔ UTF-8 conversion @@ -854,13 +872,16 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int * str_len = 0; // protect against a sequence of stray UTF8 continuation bytes } - // merge the best consecutive pair each iteration, according the scores in vocab_scores +// L2E Addition +// merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores while (1) { float best_score = -1e10; int best_id = -1; int best_idx = -1; + int best_merge = 0; // length of the best merge sequence (2 for pair, 3 for triple) - for (int i=0; i < (*n_tokens-1); i++) { + // try to find the best pair or triple to merge + for (int i = 0; i < (*n_tokens - 1); i++) { // check if we can merge the pair (tokens[i], tokens[i+1]) sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]); int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); @@ -869,28 +890,45 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int * best_score = t->vocab_scores[id]; best_id = id; best_idx = i; + best_merge = 2; + } + + // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2]) + if (i < (*n_tokens - 2)) { + sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]); + id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size); + if (id != -1 && t->vocab_scores[id] > best_score) { + // this merge triple exists in vocab! record its score and position + best_score = t->vocab_scores[id]; + best_id = id; + best_idx = i; + best_merge = 3; + } } } if (best_idx == -1) { - break; // we couldn't find any more pairs to merge, so we're done + break; // we couldn't find any more pairs or triples to merge, so we're done } - // merge the consecutive pair (best_idx, best_idx+1) into new token best_id + // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id tokens[best_idx] = best_id; - // delete token at position best_idx+1, shift the entire sequence back 1 - for (int i = best_idx+1; i < (*n_tokens-1); i++) { - tokens[i] = tokens[i+1]; + // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back + for (int i = best_idx + 1; i < (*n_tokens - best_merge + 1); i++) { + tokens[i] = tokens[i + best_merge - 1]; } - (*n_tokens)--; // token length decreased + (*n_tokens) -= (best_merge - 1); // token length decreased by the number of merged tokens minus one } - // add optional EOS (=2) token, if desired - if (eos) tokens[(*n_tokens)++] = 2; + // add optional EOS token, if desired + if (eos) tokens[(*n_tokens)++] = EOS; free(str_buffer); + } +// END L2E Addition + // ---------------------------------------------------------------------------- // The Sampler, which takes logits and returns a sampled token // sampling can be done in a few ways: greedy argmax, sampling, top-p sampling @@ -1089,9 +1127,11 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, next = sample(sampler, logits); } pos++; - - // data-dependent terminating condition: the BOS (=1) token delimits sequences - if (next == 1) { break; } + +// L2E Addition + // data-dependent terminating condition: the BOS token delimits sequences + if (next == BOS) { break; } +// END L2E Addition // print the token as string, decode it with the Tokenizer object char* piece = decode(tokenizer, token, next); @@ -1141,18 +1181,46 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, // buffers for reading the system prompt and user prompt from stdin // you'll notice they are soomewhat haphazardly and unsafely set atm +// L2E Addition char system_prompt[512]; char user_prompt[512]; - char rendered_prompt[1152]; + char rendered_prompt[2048]; int num_prompt_tokens = 0; - int* prompt_tokens = (int*)malloc(1152 * sizeof(int)); + int* prompt_tokens = (int*)malloc(2048 * sizeof(int)); +// END L2E Addition int user_idx; // start the main loop int8_t user_turn = 1; // user starts int next; // will store the next token in the sequence int token; // stores the current token to feed into the transformer - int prev_token; +// L2E Addition + /* System and user prompt templates for llama 2 and llama 3 + Llama 2: + System: + [INST] <>\n%s\n<>\n\n%s [/INST] + User: + [INST] %s [/INST] + + Llama 3: + System: + <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|> + User: + <|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n + Assistant: (Starts Generating) + <|start_header_id|>assistant<|end_header_id|>\n\n + */ + if (llamaver == 3) { + BOS = 128000; // 128000 = <|begin_of_text|> + EOS = 128009; // 128009 = <|eot_id|> , 128001 = <|end_of_text|> + strcpy(system_template, "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>"); + strcpy(user_template, "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n"); + } else { + int prev_token; + strcpy(system_template,"[INST] <>\n%s\n<>\n\n%s [/INST]"); + strcpy(user_template, "[INST] %s [/INST]"); + } +// END L2E Addition int pos = 0; // position in the sequence while (pos < steps) { @@ -1177,14 +1245,14 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, // otherwise get user prompt from stdin read_stdin("User: ", user_prompt, sizeof(user_prompt)); } +// L2E Addition // render user/system prompts into the Llama 2 Chat schema if (pos == 0 && system_prompt[0] != '\0') { - char system_template[] = "[INST] <>\n%s\n<>\n\n%s [/INST]"; sprintf(rendered_prompt, system_template, system_prompt, user_prompt); } else { - char user_template[] = "[INST] %s [/INST]"; sprintf(rendered_prompt, user_template, user_prompt); } +// END L2E Addition // encode the rendered prompt into tokens encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens); user_idx = 0; // reset the user index @@ -1200,22 +1268,25 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler, // otherwise use the next token sampled from previous turn token = next; } - // EOS (=2) token ends the Assistant turn - if (token == 2) { user_turn = 1; } +// L2E Addition + // EOS token ends the Assistant turn + if (token == EOS) { user_turn = 1; } +// End L2E Addition // forward the transformer to get logits for the next token float* logits = forward(transformer, token, pos); next = sample(sampler, logits); pos++; - - if (user_idx >= num_prompt_tokens && next != 2) { +// L2E Addition + if (user_idx >= num_prompt_tokens && next != EOS) { // the Assistant is responding, so print its output char* piece = decode(tokenizer, token, next); safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes fflush(stdout); } - if (next == 2) { printf("\n"); } + if (next == EOS) { printf("\n"); } } +// End L2E Addition printf("\n"); free(prompt_tokens); } @@ -1254,7 +1325,8 @@ void error_usage() { fprintf(stderr, " -y (optional) system prompt in chat mode\n"); // L2E Addition fprintf(stderr, " -b number of tokens to buffer, default 1. 0 = max_seq_len\n"); - fprintf(stderr, " -x extended info / stats, default 1 = on. 0 = off\n"); + fprintf(stderr, " -x extended info / stats, default 1 = on. 0 = off\n"); + fprintf(stderr, " -l llama version / default 2 = llama2. 3 = llama3\n"); // END L2E Addition exit(EXIT_FAILURE); } @@ -1323,9 +1395,11 @@ int main(int argc, char *argv[]) { // L2E Addition else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); } else if (argv[i][1] == 'x') { stats = atoi(argv[i + 1]); } + else if (argv[i][1] == 'l') { llamaver = atoi(argv[i + 1]); } // END L2E Addition else { error_usage(); } } + if (llamaver == 3){ rope_sf = 500000.0; } // L2E Addition #endif // END L2E Addition