From 1be98e214d3eefdb243f6ad382cfde47d530614b Mon Sep 17 00:00:00 2001
From: Vulcan <93451215+trholding@users.noreply.github.com>
Date: Tue, 9 Jul 2024 12:48:29 +0530
Subject: [PATCH] Llama3 Support (WIP)
use -l 3 option
---
README.md | 4 +-
run.c | 140 +++++++++++++++++++++++++++++++++++++++++-------------
2 files changed, 109 insertions(+), 35 deletions(-)
diff --git a/README.md b/README.md
index 0de40994..4be76f66 100644
--- a/README.md
+++ b/README.md
@@ -33,9 +33,9 @@ Learn more about the Llama2 models & architecture at Meta: [Llama 2 @ Meta](http
# Features & Milestones
-#### Llama 3 Support
+#### Llama 3 Support WIP
-Almost done - Coming Soonish (TM)...
+Should support inference, WIP, use -l 3 option...
#### L2E OS (Linux Kernel)
diff --git a/run.c b/run.c
index d77fa3df..ad2240f9 100644
--- a/run.c
+++ b/run.c
@@ -9,6 +9,12 @@
int buffertokens = 1; // output token buffer size
int stats = 1; // extended status info
+int llamaver = 2; // llama version (default is 2, valid 2 & 3)
+float rope_sf = 10000.0; // Rope scaling factor, 10000.0 => llama2, 500000.0 > llama3
+int BOS = 1; // Beginning of Sentence token value, llama2 = 1 , llama3 = 128000
+int EOS = 2; // End of Sentence token value, llama2 = 2 , llama3 = 128009 (end of text)
+char system_template[1024]="";
+char user_template[1024]="";
// ----------------------------------------------------------------------------
// L2E Humanoid : Linux Kernel Support Directives
@@ -550,7 +556,9 @@ float* forward(Transformer* transformer, int token, int pos) {
// RoPE relative positional encoding: complex-valued rotate q and k in each head
for (int i = 0; i < dim; i+=2) {
int head_dim = i % head_size;
- float freq = 1.0f / powf(10000.0f, head_dim / (float)head_size);
+// L2E Addition
+ float freq = 1.0f / powf(rope_sf, head_dim / (float)head_size);
+// END L2E Addition
float val = pos * freq;
float fcr = cosf(val);
float fci = sinf(val);
@@ -738,8 +746,10 @@ void free_tokenizer(Tokenizer* t) {
char* decode(Tokenizer* t, int prev_token, int token) {
char *piece = t->vocab[token];
- // following BOS (1) token, sentencepiece decoder strips any leading whitespace (see PR #89)
- if (prev_token == 1 && piece[0] == ' ') { piece++; }
+// L2E Addition
+ // following BOS (1) or (2) token, sentencepiece decoder strips any leading whitespace (see PR #89)
+ if (prev_token == BOS && piece[0] == ' ') { piece++; }
+// END L2E Addition
// careful, some tokens designate raw bytes, and look like e.g. '<0x01>'
// parse this and convert and return the actual byte
unsigned char byte_val;
@@ -772,7 +782,7 @@ int str_lookup(char *str, TokenIndex *sorted_vocab, int vocab_size) {
void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *n_tokens) {
// encode the string text (input) into an upper-bound preallocated tokens[] array
- // bos != 0 means prepend the BOS token (=1), eos != 0 means append the EOS token (=2)
+ // bos != 0 means prepend the BOS token, eos != 0 means append the EOS token
if (text == NULL) { fprintf(stderr, "cannot encode NULL text\n"); exit(EXIT_FAILURE); }
if (t->sorted_vocab == NULL) {
@@ -793,17 +803,25 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
// start at 0 tokens
*n_tokens = 0;
- // add optional BOS (=1) token, if desired
- if (bos) tokens[(*n_tokens)++] = 1;
+// L2E Addition
+ // add optional BOS token, if desired
+ if (bos) tokens[(*n_tokens)++] = BOS;
+// END L2E Addition
+
// add_dummy_prefix is true by default
// so prepend a dummy prefix token to the input string, but only if text != ""
// TODO: pretty sure this isn't correct in the general case but I don't have the
// energy to read more of the sentencepiece code to figure out what it's doing
- if (text[0] != '\0') {
- int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
- tokens[(*n_tokens)++] = dummy_prefix;
+
+// L2E Addition
+ if (llamaver == 2) {
+ if (text[0] != '\0') {
+ int dummy_prefix = str_lookup(" ", t->sorted_vocab, t->vocab_size);
+ tokens[(*n_tokens)++] = dummy_prefix;
+ }
}
+// END L2E Addition
// Okay UTF-8 time. This will get messy. Here is the reference from Wikipedia:
// Code point ↔ UTF-8 conversion
@@ -854,13 +872,16 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
str_len = 0; // protect against a sequence of stray UTF8 continuation bytes
}
- // merge the best consecutive pair each iteration, according the scores in vocab_scores
+// L2E Addition
+// merge the best consecutive pair or triple each iteration, according to the scores in vocab_scores
while (1) {
float best_score = -1e10;
int best_id = -1;
int best_idx = -1;
+ int best_merge = 0; // length of the best merge sequence (2 for pair, 3 for triple)
- for (int i=0; i < (*n_tokens-1); i++) {
+ // try to find the best pair or triple to merge
+ for (int i = 0; i < (*n_tokens - 1); i++) {
// check if we can merge the pair (tokens[i], tokens[i+1])
sprintf(str_buffer, "%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]]);
int id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
@@ -869,28 +890,45 @@ void encode(Tokenizer* t, char *text, int8_t bos, int8_t eos, int *tokens, int *
best_score = t->vocab_scores[id];
best_id = id;
best_idx = i;
+ best_merge = 2;
+ }
+
+ // check if we can merge the triple (tokens[i], tokens[i+1], tokens[i+2])
+ if (i < (*n_tokens - 2)) {
+ sprintf(str_buffer, "%s%s%s", t->vocab[tokens[i]], t->vocab[tokens[i+1]], t->vocab[tokens[i+2]]);
+ id = str_lookup(str_buffer, t->sorted_vocab, t->vocab_size);
+ if (id != -1 && t->vocab_scores[id] > best_score) {
+ // this merge triple exists in vocab! record its score and position
+ best_score = t->vocab_scores[id];
+ best_id = id;
+ best_idx = i;
+ best_merge = 3;
+ }
}
}
if (best_idx == -1) {
- break; // we couldn't find any more pairs to merge, so we're done
+ break; // we couldn't find any more pairs or triples to merge, so we're done
}
- // merge the consecutive pair (best_idx, best_idx+1) into new token best_id
+ // merge the consecutive pair or triple (best_idx, best_idx+1[, best_idx+2]) into new token best_id
tokens[best_idx] = best_id;
- // delete token at position best_idx+1, shift the entire sequence back 1
- for (int i = best_idx+1; i < (*n_tokens-1); i++) {
- tokens[i] = tokens[i+1];
+ // delete token(s) at position best_idx+1 (and optionally best_idx+2), shift the entire sequence back
+ for (int i = best_idx + 1; i < (*n_tokens - best_merge + 1); i++) {
+ tokens[i] = tokens[i + best_merge - 1];
}
- (*n_tokens)--; // token length decreased
+ (*n_tokens) -= (best_merge - 1); // token length decreased by the number of merged tokens minus one
}
- // add optional EOS (=2) token, if desired
- if (eos) tokens[(*n_tokens)++] = 2;
+ // add optional EOS token, if desired
+ if (eos) tokens[(*n_tokens)++] = EOS;
free(str_buffer);
+
}
+// END L2E Addition
+
// ----------------------------------------------------------------------------
// The Sampler, which takes logits and returns a sampled token
// sampling can be done in a few ways: greedy argmax, sampling, top-p sampling
@@ -1089,9 +1127,11 @@ void generate(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
next = sample(sampler, logits);
}
pos++;
-
- // data-dependent terminating condition: the BOS (=1) token delimits sequences
- if (next == 1) { break; }
+
+// L2E Addition
+ // data-dependent terminating condition: the BOS token delimits sequences
+ if (next == BOS) { break; }
+// END L2E Addition
// print the token as string, decode it with the Tokenizer object
char* piece = decode(tokenizer, token, next);
@@ -1141,18 +1181,46 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// buffers for reading the system prompt and user prompt from stdin
// you'll notice they are soomewhat haphazardly and unsafely set atm
+// L2E Addition
char system_prompt[512];
char user_prompt[512];
- char rendered_prompt[1152];
+ char rendered_prompt[2048];
int num_prompt_tokens = 0;
- int* prompt_tokens = (int*)malloc(1152 * sizeof(int));
+ int* prompt_tokens = (int*)malloc(2048 * sizeof(int));
+// END L2E Addition
int user_idx;
// start the main loop
int8_t user_turn = 1; // user starts
int next; // will store the next token in the sequence
int token; // stores the current token to feed into the transformer
- int prev_token;
+// L2E Addition
+ /* System and user prompt templates for llama 2 and llama 3
+ Llama 2:
+ System:
+ [INST] <>\n%s\n<>\n\n%s [/INST]
+ User:
+ [INST] %s [/INST]
+
+ Llama 3:
+ System:
+ <|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>
+ User:
+ <|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n
+ Assistant: (Starts Generating)
+ <|start_header_id|>assistant<|end_header_id|>\n\n
+ */
+ if (llamaver == 3) {
+ BOS = 128000; // 128000 = <|begin_of_text|>
+ EOS = 128009; // 128009 = <|eot_id|> , 128001 = <|end_of_text|>
+ strcpy(system_template, "<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\n%s<|eot_id|>");
+ strcpy(user_template, "<|start_header_id|>user<|end_header_id|>\n\n%s<|eot_id|>\n<|start_header_id|>assistant<|end_header_id|>\n\n");
+ } else {
+ int prev_token;
+ strcpy(system_template,"[INST] <>\n%s\n<>\n\n%s [/INST]");
+ strcpy(user_template, "[INST] %s [/INST]");
+ }
+// END L2E Addition
int pos = 0; // position in the sequence
while (pos < steps) {
@@ -1177,14 +1245,14 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// otherwise get user prompt from stdin
read_stdin("User: ", user_prompt, sizeof(user_prompt));
}
+// L2E Addition
// render user/system prompts into the Llama 2 Chat schema
if (pos == 0 && system_prompt[0] != '\0') {
- char system_template[] = "[INST] <>\n%s\n<>\n\n%s [/INST]";
sprintf(rendered_prompt, system_template, system_prompt, user_prompt);
} else {
- char user_template[] = "[INST] %s [/INST]";
sprintf(rendered_prompt, user_template, user_prompt);
}
+// END L2E Addition
// encode the rendered prompt into tokens
encode(tokenizer, rendered_prompt, 1, 0, prompt_tokens, &num_prompt_tokens);
user_idx = 0; // reset the user index
@@ -1200,22 +1268,25 @@ void chat(Transformer *transformer, Tokenizer *tokenizer, Sampler *sampler,
// otherwise use the next token sampled from previous turn
token = next;
}
- // EOS (=2) token ends the Assistant turn
- if (token == 2) { user_turn = 1; }
+// L2E Addition
+ // EOS token ends the Assistant turn
+ if (token == EOS) { user_turn = 1; }
+// End L2E Addition
// forward the transformer to get logits for the next token
float* logits = forward(transformer, token, pos);
next = sample(sampler, logits);
pos++;
-
- if (user_idx >= num_prompt_tokens && next != 2) {
+// L2E Addition
+ if (user_idx >= num_prompt_tokens && next != EOS) {
// the Assistant is responding, so print its output
char* piece = decode(tokenizer, token, next);
safe_printf(piece); // same as printf("%s", piece), but skips "unsafe" bytes
fflush(stdout);
}
- if (next == 2) { printf("\n"); }
+ if (next == EOS) { printf("\n"); }
}
+// End L2E Addition
printf("\n");
free(prompt_tokens);
}
@@ -1254,7 +1325,8 @@ void error_usage() {
fprintf(stderr, " -y (optional) system prompt in chat mode\n");
// L2E Addition
fprintf(stderr, " -b number of tokens to buffer, default 1. 0 = max_seq_len\n");
- fprintf(stderr, " -x extended info / stats, default 1 = on. 0 = off\n");
+ fprintf(stderr, " -x extended info / stats, default 1 = on. 0 = off\n");
+ fprintf(stderr, " -l llama version / default 2 = llama2. 3 = llama3\n");
// END L2E Addition
exit(EXIT_FAILURE);
}
@@ -1323,9 +1395,11 @@ int main(int argc, char *argv[]) {
// L2E Addition
else if (argv[i][1] == 'b') { buffertokens = atoi(argv[i + 1]); }
else if (argv[i][1] == 'x') { stats = atoi(argv[i + 1]); }
+ else if (argv[i][1] == 'l') { llamaver = atoi(argv[i + 1]); }
// END L2E Addition
else { error_usage(); }
}
+ if (llamaver == 3){ rope_sf = 500000.0; }
// L2E Addition
#endif
// END L2E Addition