forked from ggerganov/llama.cpp
-
Notifications
You must be signed in to change notification settings - Fork 364
/
gpttype_adapter.cpp
3655 lines (3267 loc) · 140 KB
/
gpttype_adapter.cpp
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
//This is Concedo's shitty adapter for adding python bindings for llama
//Considerations:
//Don't want to use pybind11 due to dependencies on MSVCC
//ZERO or MINIMAL changes as possible to main.cpp - do not move their function declarations here!
//Leave main.cpp UNTOUCHED, We want to be able to update the repo and pull any changes automatically.
//No dynamic memory allocation! Setup structs with FIXED (known) shapes and sizes for ALL output fields
//Python will ALWAYS provide the memory, we just write to it.
#include <time.h>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <locale>
#include <map>
#include <mutex>
#include <string>
#include <unordered_map>
#include <vector>
#include "model_adapter.h"
#include "otherarch.h"
#include "llama.h"
//for easier compilation
//concat source files into one file for compilation purposes
#include "llama_v2.cpp"
#include "llama_v3.cpp"
#include "src/llama.cpp"
#include "utils.cpp"
#include "gptj_v1.cpp"
#include "gptj_v2.cpp"
#include "gptj_v3.cpp"
#include "gpt2_v1.cpp"
#include "gpt2_v2.cpp"
#include "gpt2_v3.cpp"
#include "rwkv_v2.cpp"
#include "rwkv_v3.cpp"
#include "neox_v2.cpp"
#include "neox_v3.cpp"
#include "mpt_v3.cpp"
#include "examples/llava/clip.h"
#include "examples/llava/llava.h"
//const
const int extra_context_handle_fragmentation = 120;
// Negative sentinel token ids. FileFormatTokenizeID() returns "" for any negative id,
// so these placeholders never map to vocabulary text.
const int LLAVA_TOKEN_IDENTIFIER_A = -998; //alternate between both, changing when image changes
const int LLAVA_TOKEN_IDENTIFIER_B = -999;
//shared
// NOTE(review): these are non-static, so presumably they are referenced from other
// translation units - confirm before changing linkage or names.
std::string executable_path = "";
std::string lora_filename = "";
std::string lora_base = "";
std::string mmproj_filename = "";
bool generation_finished;
float last_process_time = 0;
float last_eval_time = 0;
int last_token_count = 0;
int last_seed = -1;
int total_gens = 0;
stop_reason last_stop_reason = stop_reason::INVALID;
std::vector<std::string> generated_tokens;
llama_grammar * grammar = nullptr; //currently used grammar
llama_grammar_parser parsed_grammar;
static std::string current_grammar = "";
//return val: 0=fail, 1=(original ggml, alpaca), 2=(ggmf), 3=(ggjt)
// file_format selects which of the per-architecture contexts below the helpers in this file use.
static FileFormat file_format = FileFormat::BADFORMAT;
static FileFormatExtraMeta file_format_meta;
// vocab is used by the non-llama code paths (gpt_tokenize / id_to_token lookups).
static gpt_vocab vocab;
static int32_t n_vocab = 0;
// Per-architecture model state. Only the one matching file_format is expected to be in use.
static gptj_v1_model gptj_ctx_v1;
static gptj_v2_model gptj_ctx_v2;
static gptj_model gptj_ctx_v3;
static gpt2_v1_model gpt2_ctx_v1;
static gpt2_v2_model gpt2_ctx_v2;
static gpt2_model gpt2_ctx_v3;
static gpt_neox_v2_model neox_ctx_v2;
static gpt_neox_model neox_ctx_v3;
static mpt_model mpt_ctx_v3;
static rwkv_v2_context * rwkv_ctx_v2;
static rwkv_context * rwkv_ctx_v3;
// llama contexts: v2 serves GGML/GGHF/GGJT/GGJT_2, v3 serves GGJT_3, v4 serves GGUF_GENERIC
// (see FileFormatTokenizeID / TokenizeString).
static llama_v2_context * llama_ctx_v2;
static llama_v3_context * llama_ctx_v3;
static llama_context * llama_ctx_v4;
static clip_ctx * clp_ctx = nullptr; //for llava
static clip_image_u8 * clp_img_data = nullptr; //most recent image
static std::vector<llava_image> llava_images;
static std::string llava_composite_image_signature = ""; //for identifying when the llava images change, we need to invalidate the cache
static int current_llava_identifier = LLAVA_TOKEN_IDENTIFIER_A;
static kcpp_params * kcpp_data = nullptr;
static int max_context_limit_at_load = 0;
static int n_past = 0;
static int debugmode = 0; //-1 = hide all, 0 = normal, 1 = showall
static std::vector<gpt_vocab::id> last_n_tokens;
static std::vector<gpt_vocab::id> current_context_tokens;
static size_t mem_per_token = 0;
static std::vector<float> logits;
static std::vector<int> smartcontext;
static std::vector<std::string> stop_sequence;
static std::vector<int> special_stop_sequence; //for stop sequences that don't have a string representation
static std::vector<std::string> banned_tokens;
static std::vector<int> banned_token_ids;
static std::vector<std::string> banned_phrases;
static std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>> dry_sequence_breakers; // Multi-mapping from first token of sequence to tail of sequence (tail is empty for a single token)
// The next two are scratch buffers for sample_dry(), which resets both at the start of every call.
static std::vector<int> dry_repeat_count; // Indexed as last_n_tokens
static std::unordered_map<gpt_vocab::id, int> dry_max_token_repeat;
// One entry is appended per sample_token() call; ContextRewind() trims entries from the end.
static std::vector<TopPicksData> top_picks_history;
static int remaining_tokens = 0;
static int stopper_unused_tokens = 0;
static std::mutex concat_output_mtx;
static std::string concat_output = "";
static std::string concat_output_reader_copy_poll = ""; //for streaming
static std::string concat_output_reader_copy_res = ""; //for gen response
static std::vector<logit_bias> logit_biases;
static int delayed_generated_tokens_limit = 0;
std::deque<std::string> delayed_generated_tokens; //for use with antislop sampling
static std::map<int,std::vector<int>> antislop_banned_token_ids; //first is the npast position, second is the array of banned ids at that index
// Returns true when `f` holds any NaN bit pattern (signalling or quiet).
// The check is done on the raw IEEE-754 single-precision bits: a NaN has an
// all-ones exponent (0x7F800000) and a non-zero mantissa (0x007FFFFF).
// Infinity (all-ones exponent, zero mantissa) is NOT reported as NaN.
inline bool IsNanCheck(float f)
{
    static_assert(sizeof(std::uint32_t) == sizeof(float), "IsNanCheck assumes a 32-bit float");
    std::uint32_t bits = 0;
    // memcpy is the well-defined way to read the object representation;
    // dereferencing a casted pointer (the previous approach) violates strict aliasing.
    std::memcpy(&bits, &f, sizeof(bits));
    const bool exponent_all_ones = (bits & 0x7F800000u) == 0x7F800000u;
    const bool mantissa_nonzero = (bits & 0x007FFFFFu) != 0u;
    return exponent_all_ones && mantissa_nonzero;
}
// Cheap "are these the same logits?" probe: compares only the first five entries
// of the two arrays. Returns false (after printing an error) when either array has
// fewer than five entries or the two sizes differ.
inline bool LogitsDuplicated(std::vector<float> & arr1, std::vector<float> & arr2)
{
    const size_t probeCount = 5; // number of leading entries that get compared

    const bool sizesUsable = arr1.size() >= probeCount
                          && arr2.size() >= probeCount
                          && arr1.size() == arr2.size();
    if (!sizesUsable)
    {
        printf("\nError: Logit array sizes are bad!\n");
        return false;
    }

    // Walk forward while the entries agree; reaching probeCount means all matched.
    size_t idx = 0;
    while (idx < probeCount && arr1[idx] == arr2[idx])
    {
        ++idx;
    }
    return idx == probeCount;
}
// Converts a token id into its text piece using the tokenizer that matches `file_format`.
//   id             - token id; negative ids are placeholders and yield an empty string
//   file_format    - selects the backend (note: this parameter shadows the file-level global)
//   return_special - forwarded to common_token_to_piece; only used on the GGUF_GENERIC path
// Formats that are not llama-family fall back to a lookup in the shared `vocab`.
static std::string FileFormatTokenizeID(int id, FileFormat file_format, bool return_special = false)
{
    if (id < 0)
    {
        return std::string(); //placeholder IDs cannot be tokenized!
    }

    switch (file_format)
    {
        case FileFormat::GGML:
        case FileFormat::GGHF:
        case FileFormat::GGJT:
        case FileFormat::GGJT_2:
            return std::string(llama_v2_token_to_str(llama_ctx_v2, id));

        case FileFormat::GGJT_3:
            return std::string(llama_v3_token_to_str(llama_ctx_v3, id));

        case FileFormat::GGUF_GENERIC:
            return std::string(common_token_to_piece(llama_ctx_v4, id, return_special));

        default:
            return vocab.id_to_token[id];
    }
}
static void TokenizeString(const std::string & str_to_tokenize, std::vector<int> & output_tokens, FileFormat file_format, bool add_bos=true)
{
if (file_format == FileFormat::GGML || file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 || file_format == FileFormat::GGJT_3 || file_format == FileFormat::GGUF_GENERIC)
{
if(file_format == FileFormat::GGHF || file_format == FileFormat::GGJT || file_format == FileFormat::GGJT_2 )
{
output_tokens = ::llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
}
else if (file_format == FileFormat::GGML)
{
output_tokens = ::legacy_llama_v2_tokenize(llama_ctx_v2, str_to_tokenize, add_bos);
}
else if (file_format == FileFormat::GGJT_3)
{
output_tokens = ::llama_v3_tokenize(llama_ctx_v3, str_to_tokenize, add_bos);
}
else
{
output_tokens = ::common_tokenize(llama_ctx_v4, str_to_tokenize, add_bos, true);
if(add_bos)
{
llama_token bostoadd = llama_token_bos(&(llama_ctx_v4->model));
if(bostoadd != LLAMA_TOKEN_NULL) //if bos does not exist, do not add it
{
if(output_tokens.size()==0)
{
output_tokens.push_back(bostoadd);
}
else
{
if(output_tokens[0]!=bostoadd)
{
output_tokens.insert(output_tokens.begin(), 1, bostoadd);
}
}
}
}
}
}
else
{
// tokenize the prompt
output_tokens = ::gpt_tokenize(vocab, str_to_tokenize);
}
}
// Returns the end-of-sequence token id for the given model format.
//   GGUF_GENERIC          - asks the loaded model via llama_token_eos
//   older llama formats   - llama_v3_token_eos()
//   GPT-2 / GPT-J formats - 50256, or 0 when the vocabulary is too small to contain 50256
//   RWKV / NeoX / MPT     - 0
//   anything else         - 0
static int GetEosID(FileFormat file_format, int32_t n_vocab)
{
    unsigned int eosID = 0;

    switch (file_format)
    {
        case FileFormat::GGUF_GENERIC:
            eosID = llama_token_eos(&(llama_ctx_v4->model));
            break;

        case FileFormat::GGML:
        case FileFormat::GGHF:
        case FileFormat::GGJT:
        case FileFormat::GGJT_2:
        case FileFormat::GGJT_3:
            eosID = llama_v3_token_eos();
            break;

        case FileFormat::GPT2_1:
        case FileFormat::GPT2_2:
        case FileFormat::GPT2_3:
        case FileFormat::GPT2_4:
        case FileFormat::GPTJ_1:
        case FileFormat::GPTJ_2:
        case FileFormat::GPTJ_3:
        case FileFormat::GPTJ_4:
        case FileFormat::GPTJ_5:
            //special case, starcoder models use ID 0 for EOS
            // (comparison is done in unsigned arithmetic)
            eosID = (static_cast<unsigned int>(n_vocab) <= 50256u) ? 0u : 50256u;
            break;

        case FileFormat::RWKV_1:
        case FileFormat::RWKV_2:
        case FileFormat::NEOX_1:
        case FileFormat::NEOX_2:
        case FileFormat::NEOX_3:
        case FileFormat::NEOX_4:
        case FileFormat::NEOX_5:
        case FileFormat::NEOX_6:
        case FileFormat::NEOX_7:
        case FileFormat::MPT_1:
            eosID = 0;
            break;

        default:
            break;
    }

    return eosID;
}
// Returns the end-of-turn token id reported by the loaded model for GGUF_GENERIC,
// and -1 for every other format.
static int GetEotID(FileFormat file_format)
{
    const bool isGguf = (file_format == FileFormat::GGUF_GENERIC);
    return isGguf ? llama_token_eot(&(llama_ctx_v4->model)) : -1;
}
// Returns a value safely below every logit in `logits`: (minimum - 8) when the
// minimum is negative, otherwise 0.
// An empty vector yields 0, matching the pointer/size overload. Previously an empty
// vector indexed element 0 of an empty container, which is undefined behaviour.
static float LowestLogit(const std::vector<float> & logits)
{
    if (logits.empty())
    {
        return 0.0f;
    }
    const float v = *std::min_element(logits.begin(), logits.end());
    return (v < 0 ? (v-8) : 0);
}
// Pointer/size flavour of LowestLogit: returns (minimum - 8) when the smallest of the
// `size` values is negative, otherwise 0. A zero-length array yields 0.
static float LowestLogit(const float *logits, size_t size)
{
    if (size == 0) {
        return 0.0; // nothing to scan
    }

    const float * const smallest = std::min_element(logits, logits + size);
    const float v = *smallest;
    if (v < 0) {
        return v - 8;
    }
    return 0;
}
// Returns a copy of `input` with every bell character ('\a', 0x07) dropped.
static std::string RemoveBell(const std::string & input) //removes the bell character
{
    std::string cleaned;
    for (const char ch : input)
    {
        if (ch != '\a')
        {
            cleaned.push_back(ch);
        }
    }
    return cleaned;
}
// Renders a token list for debug output as a sequence of `'<text> (<id>)', ` entries,
// using the file-level `file_format` for detokenization (special tokens included).
// Newlines in the result are escaped as the two characters "\n".
static std::string get_tok_vec_str(std::vector<int> &embd)
{
    std::string rendered;
    for (const int tokenId : embd)
    {
        rendered += "'";
        rendered += FileFormatTokenizeID(tokenId, file_format, true);
        rendered += " (";
        rendered += std::to_string(tokenId);
        rendered += ")', ";
    }
    ::utreplace(rendered, "\n", "\\n");
    return rendered;
}
// Prints the debug rendering of `vec` (see get_tok_vec_str) to stdout, wrapped in
// square brackets on its own line.
static void print_tok_vec_str(std::vector<int> &vec)
{
    const std::string rendered = get_tok_vec_str(vec);
    printf("\n[%s]\n", rendered.c_str());
}
// Returns true only when `str` is non-empty and every byte is outside the 7-bit ASCII
// range (i.e. has a value above 127). An empty string returns false.
bool allExtendedUnicode(const std::string& str) {
    if (str.empty()) {
        return false;
    }
    for (const char ch : str) {
        const unsigned char byte = static_cast<unsigned char>(ch);
        if (byte < 0x80) {
            return false; // found a plain ASCII byte
        }
    }
    return true;
}
// Find tokens that completely contain `str`, either as a single token, or as a sequence of tokens.
// It's important to use a hash map for head tokens because some models have many of them.
// For example, the Llama 3 tokenizer has 6570 tokens containing the period ('.') character.
// Single tokens are allowed to extend past `str` at the front and back. This is to allow, for
// instance, the token '.\n' to be a head for both '.' and '\n'. However if a head token
// begins a multi-token sequence, the head can only extend past `str` at the beginning. The
// tail tokens are generated by tokenizing the remainder.
// If max_tail_len is >= 0, the maximum token length of a tail sequence is clamped to this value.
//
// Parameters:
//   str             - the text to look for; every vocabulary id in [0, n_vocab) is scanned.
//   token_sequences - in/out multimap from head token id to tail token ids. Entries are only
//                     added (never removed), and duplicates of an existing (head, tail) pair
//                     are skipped.
//   max_tail_len    - see above; a negative value disables clamping.
// Uses the file-level `n_vocab` and `file_format`.
static void GetOverlappingTokenSequences(const std::string& str, std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& token_sequences, int max_tail_len = -1) {
// When `str` consists solely of bytes above 127, the partial-overlap branch below is skipped.
// NOTE(review): presumably this avoids matching on fragments of multi-byte UTF-8 sequences - confirm.
bool isAllExtendedUnicode = allExtendedUnicode(str);
for(int v=0;v<n_vocab;++v)
{
std::string word = FileFormatTokenizeID(v, file_format, true);
if (word.find(str) != std::string::npos)
{
// The string is entirely contained within this single token.
// Ensure that token_sequences only contains one key-value-pair with an empty value.
auto its = token_sequences.equal_range(v);
bool empty = false;
for (auto it = its.first; it != its.second; ++it) {
if (it->second.empty()) {
empty = true;
break;
}
}
if (!empty) {
token_sequences.emplace(v, std::vector<gpt_vocab::id>());
}
} else {
// Check whether a prefix of the string overlaps with a suffix of the token.
// Just do a naive O(N^2) search, since the worst case is limited by the
// maximum character length of a token in the vocabulary.
size_t word_len = word.size(), str_len = str.size();
// pos starts at size_t(-1) so that `pos + 1` wraps to 0 and the first search begins at index 0.
size_t pos = -1;
while ((pos = word.find(str[0], pos + 1)) != std::string::npos) {
bool match = true;
size_t i;
// Compare str[1..] against the token starting at pos+1, stopping at whichever ends first.
// On a successful match `i` is the number of leading characters of `str` consumed by the token.
for (i = 1; i < str_len && i + pos < word_len; ++i) {
if (word[pos + i] != str[i]) {
match = false;
break;
}
}
if (match && !isAllExtendedUnicode) {
// We matched up to the end of the token (`word`). Since `str` is not contained in `word`,
// there must be trailing characters in `str` that still need to be covered by tail tokens.
std::vector<gpt_vocab::id> tokenization;
TokenizeString(str.substr(i), tokenization, file_format, false);
if (max_tail_len >= 0 && tokenization.size() > max_tail_len) {
tokenization.resize(max_tail_len);
}
// Ensure we don't already have a duplicate matching tokenization.
auto its = token_sequences.equal_range(v);
bool found = false;
for (auto it = its.first; it != its.second; ++it) {
if (tokenization == it->second) {
found = true;
break;
}
}
if (!found)
{
token_sequences.emplace(v, tokenization);
}
}
}
}
}
}
// Returns a lowercase copy of `str`.
// The conversion is applied one `char` at a time through std::tolower with a
// default-constructed std::locale, so it is a per-byte mapping: multi-byte UTF-8
// characters are not case-folded as whole code points.
static std::string toLowerCase(const std::string& str) {
    const std::locale loc;
    std::string lowered;
    for (std::string::size_type idx = 0; idx < str.size(); ++idx) {
        lowered += std::tolower(str[idx], loc); // Use locale-aware tolower
    }
    return lowered;
}
// Rolls the generation state back by `amount_rewind` tokens.
//   embd                   - pending token batch; on success it is replaced by the single
//                            last token that remains in the context (or left empty)
//   current_context_tokens - token history; truncated from the end
//   n_past                 - evaluated-token counter; reduced, never below 0
//   last_n_tokens          - recent-token window; truncated from the end
//   amount_rewind          - number of tokens to drop; values <= 0 are a no-op
// Also trims the file-level `top_picks_history` and, for GGUF models, removes the
// KV-cache entries from the new `n_past` onwards.
// The call is refused (with a warning) while a multi-token batch is pending or when the
// loaded model is an RNN-style architecture (RWKV / Mamba).
// Note: the three reference parameters shadow file-level globals of the same name.
void ContextRewind(std::vector<int> &embd, std::vector<int> &current_context_tokens, int &n_past, std::vector<int> &last_n_tokens, const int amount_rewind)
{
    if (amount_rewind <= 0 || current_context_tokens.empty())
    {
        return; //do nothing
    }
    if (embd.size() > 1)
    {
        printf("\nWARNING: Don't use context rewind when in batch processing phase!\n");
        return;
    }

    const bool is_gguf = (file_format == FileFormat::GGUF_GENERIC);
    const bool is_mamba = (is_gguf && file_format_meta.model_architecture==GGUFArch::ARCH_MAMBA);
    const bool is_rwkv_new = (is_gguf && file_format_meta.model_architecture==GGUFArch::ARCH_RWKV);
    if (file_format == FileFormat::RWKV_1 || file_format == FileFormat::RWKV_2 || is_mamba || is_rwkv_new)
    {
        printf("\nWARNING: RNN models do not support context rewind!\n");
        return;
    }

    // amount_rewind is known to be positive here, so the conversion is lossless and
    // the size comparisons below are done without mixing signed and unsigned values.
    const size_t rewind = static_cast<size_t>(amount_rewind);
    const auto remaining = [rewind](size_t have) -> size_t {
        return (rewind >= have) ? 0 : (have - rewind);
    };

    last_n_tokens.resize(remaining(last_n_tokens.size()));
    top_picks_history.resize(remaining(top_picks_history.size()));
    current_context_tokens.resize(remaining(current_context_tokens.size()));

    n_past = (amount_rewind >= n_past) ? 0 : (n_past - amount_rewind);

    if (is_gguf)
    {
        // drop everything from position n_past to the end of sequence 0
        llama_kv_cache_seq_rm(llama_ctx_v4, 0, n_past, -1);
    }

    // Re-seed the pending batch with the last surviving token, if any.
    embd.clear();
    if (!current_context_tokens.empty())
    {
        embd.push_back(current_context_tokens.back());
    }
}
// KCPP SAMPLING FUNCTIONS
// Converts the candidates' logits into probabilities in place.
// If the array is not flagged as sorted it is first sorted by descending logit and the
// flag is set. Each `p` is exp(logit - max_logit) divided by the sum of those terms.
// Requires at least one candidate (asserted).
void sample_softmax(llama_token_data_array * cur_p) {
    GGML_ASSERT(cur_p->size > 0);

    llama_token_data * const first = cur_p->data;
    llama_token_data * const last = first + cur_p->size;

    // Sort the logits in descending order
    if (!cur_p->sorted) {
        std::sort(first, last, [](const llama_token_data & lhs, const llama_token_data & rhs) {
            return lhs.logit > rhs.logit;
        });
        cur_p->sorted = true;
    }

    // After sorting, the first entry carries the largest logit.
    const float top_logit = first->logit;
    float total = 0.0f;
    for (llama_token_data * it = first; it != last; ++it) {
        const float weight = expf(it->logit - top_logit);
        it->p = weight;
        total += weight;
    }
    for (llama_token_data * it = first; it != last; ++it) {
        it->p /= total;
    }
}
// Keeps only the `k` highest-logit candidates in `cur_p`.
//   k <= 0 means "keep everything"; otherwise k is clamped to at most cur_p->size (and at least 1
//   before that clamp).
// If the array is not yet flagged as sorted, the first k entries are brought into descending
// logit order (partial sort for small k, a bucket pre-pass for large k) and the flag is set.
// Finally cur_p->size is set to k; entries beyond k are left in the buffer but are no longer
// part of the array.
void sample_top_k(llama_token_data_array * cur_p, int32_t k) {
// TODO: move bucket sort to separate function so that top_p/tail_free/typical/softmax first is equally fast
// if (k >= (int32_t)cur_p->size) {
// return;
// }
if (k <= 0) {
k = cur_p->size;
}
k = std::max(k, (int) 1); //min keep of 1
k = std::min(k, (int) cur_p->size);
// Sort scores in descending order
if (!cur_p->sorted) {
auto comp = [](const llama_token_data & a, const llama_token_data & b) {
return a.logit > b.logit;
};
if (k <= 128) {
// Small k: a plain partial sort of the whole array.
std::partial_sort(cur_p->data, cur_p->data + k, cur_p->data + cur_p->size, comp);
} else {
// Large k: distribute candidates into 128 buckets by logit value so that only the
// buckets that can contain top-k entries need to be sorted.
// Logits are mapped linearly from [-10, 10]; values outside are clamped to the end buckets.
constexpr int nbuckets = 128;
constexpr float bucket_low = -10.0f;
constexpr float bucket_high = 10.0f;
constexpr float bucket_scale = nbuckets/(bucket_high - bucket_low);
constexpr float bucket_inter = -bucket_low * bucket_scale;
std::vector<int> bucket_idx(cur_p->size);
std::vector<int> histo(nbuckets, 0);
// Pass 1: record each candidate's bucket and build the histogram.
for (int i = 0; i < (int)cur_p->size; ++i) {
const float val = cur_p->data[i].logit;
int ib = int(bucket_scale * val + bucket_inter); //nbuckets * (val - bucket_low) / (bucket_high - bucket_low);
ib = std::max(0, std::min(nbuckets-1, ib));
bucket_idx[i] = ib;
++histo[ib];
}
// Walk down from the highest bucket until at least k candidates are covered.
// `ib` ends up as the lowest bucket still needed; `nhave` is the candidate count in [ib, nbuckets).
int nhave = 0;
int ib = nbuckets - 1;
for ( ; ib >= 0; --ib) {
nhave += histo[ib];
if (nhave >= k) {
break;
}
}
// Pass 2: copy the needed candidates into tmp_tokens, grouped by bucket from highest to lowest.
// bucket_ptrs[n] is the write cursor for bucket (nbuckets-1-n).
std::vector<llama_token_data> tmp_tokens(nhave);
auto * ptr = tmp_tokens.data();
std::vector<llama_token_data*> bucket_ptrs;
bucket_ptrs.reserve(nbuckets - ib);
for (int j = nbuckets - 1; j >= ib; --j) {
bucket_ptrs.push_back(ptr);
ptr += histo[j];
}
for (int i = 0; i < (int)cur_p->size; ++i) {
int j = bucket_idx[i];
if (j >= ib) {
*bucket_ptrs[nbuckets-1-j]++ = cur_p->data[i];
}
}
// Pass 3: fully sort every bucket above the boundary bucket...
ptr = tmp_tokens.data();
int ndone = 0;
for (int j = nbuckets-1; j > ib; --j) {
std::sort(ptr, ptr + histo[j], comp);
ptr += histo[j];
ndone += histo[j];
}
// ...and only partially sort the boundary bucket, enough to fill the remaining k - ndone slots.
std::partial_sort(ptr, ptr + k - ndone, ptr + histo[ib], comp);
// Copy the k winners back to the front of the caller's buffer.
std::memcpy(cur_p->data, tmp_tokens.data(), k*sizeof(llama_token_data));
}
cur_p->sorted = true;
}
cur_p->size = k;
}
// Draws one token from `candidates` according to their softmax probabilities and records
// the draw in the file-level `top_picks_history`.
//   candidates - candidate array; softmax is (re)applied in place before sampling
//   rng        - random source used for the single draw
// The recorded entry holds the chosen token (text, id, probability, log-probability) plus
// the leading candidates, up to `logprobs_max` of them. Reported log-probabilities use a
// probability floored at 0.0001 and are clamped to [-999, 999].
llama_token sample_token(llama_token_data_array * candidates, std::mt19937 & rng)
{
    sample_softmax(candidates);

    // log-probability as reported to callers (floored probability, clamped result)
    const auto reported_logprob = [](float probability) -> float {
        const float floored = (probability<=0.0001?0.0001f:probability);
        float lp = logf(floored);
        lp = (lp > 999.0f?999.0f:lp);
        lp = (lp < -999.0f?-999.0f:lp);
        return lp;
    };

    // Build the weight table for the weighted draw.
    std::vector<float> weights;
    weights.reserve(candidates->size);
    for (size_t i = 0; i < candidates->size; ++i) {
        weights.push_back(candidates->data[i].p);
    }
    std::discrete_distribution<> dist(weights.begin(), weights.end());
    const int chosen = dist(rng);
    const llama_token_data & winner = candidates->data[chosen];

    TopPicksData newpick;
    newpick.selected_token = FileFormatTokenizeID(winner.id, file_format, true);
    newpick.selected_logprob = reported_logprob(winner.p);
    newpick.selected_probability = winner.p;
    newpick.selected_tokenid = winner.id;

    // Record the leading candidates alongside the selection.
    for (size_t i = 0; (i < candidates->size && i<logprobs_max); ++i)
    {
        const llama_token_data & entry = candidates->data[i];
        newpick.tokens.push_back(FileFormatTokenizeID(entry.id, file_format, true));
        newpick.logprobs.push_back(reported_logprob(entry.p));
        newpick.p.push_back(entry.p);
        newpick.tokenid.push_back(entry.id);
    }
    top_picks_history.push_back(newpick);

    return winner.id;
}
// Mirostat (v1) sampling step.
//   n_vocab    - vocabulary size used in the k estimate
//   candidates - candidate array; modified in place (softmax, then top-k truncation)
//   rng        - random source for the draw
//   tau        - target surprise
//   eta        - learning rate for the mu update
//   m          - number of leading candidates used to estimate s_hat
//   mu         - in/out running state; updated as mu -= eta * (observed_surprise - tau)
// Returns the sampled token id.
// NOTE(review): when fewer than two candidates (or m < 2) take part in the estimate, the
// s_hat division is 0/0; behaviour in that case is inherited unchanged - confirm it is acceptable.
llama_token sample_token_mirostat(int n_vocab, llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, int m, float * mu)
{
    const float vocab_size = float(n_vocab);
    sample_softmax(candidates);

    // Estimate s_hat using the most probable m tokens
    float weighted_sum = 0.0;
    float square_sum = 0.0;
    for (size_t i = 0; i < size_t(m - 1) && i < candidates->size - 1; ++i) {
        const float t_i = logf(float(i + 2) / float(i + 1));
        const float b_i = logf(candidates->data[i].p / candidates->data[i + 1].p);
        weighted_sum += t_i * b_i;
        square_sum += t_i * t_i;
    }
    const float s_hat = weighted_sum / square_sum;

    // Compute k from the estimated s_hat and target surprise value
    const float epsilon_hat = s_hat - 1;
    const float k = powf((epsilon_hat * powf(2, *mu)) / (1 - powf(vocab_size, -epsilon_hat)), 1 / s_hat);

    // Sample the next word X using top-k sampling
    sample_top_k(candidates, int(k));
    const llama_token X = sample_token(candidates, rng);

    // Locate the sampled token to read back its probability.
    size_t X_idx = 0;
    while (X_idx < candidates->size && candidates->data[X_idx].id != X) {
        ++X_idx;
    }

    // Compute error as the difference between observed surprise and target surprise value
    const float observed_surprise = -log2f(candidates->data[X_idx].p);
    const float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    *mu = *mu - eta * e;
    return X;
}
// Mirostat v2 sampling step.
//   candidates - candidate array; modified in place (softmax, truncation, re-normalisation)
//   rng        - random source for the draw
//   tau        - target surprise
//   eta        - learning rate for the mu update
//   mu         - in/out running state; candidates whose surprise (-log2 p) exceeds mu are
//                cut off, then mu -= eta * (observed_surprise - tau)
// At least one candidate is always kept. Returns the sampled token id.
llama_token sample_token_mirostat_v2(llama_token_data_array * candidates, std::mt19937 & rng, float tau, float eta, float * mu)
{
    sample_softmax(candidates);

    // Truncate the words with surprise values greater than mu:
    // keep the leading run of candidates that are not above the limit.
    size_t keep = 0;
    while (keep < candidates->size && !(-log2f(candidates->data[keep].p) > *mu)) {
        ++keep;
    }
    candidates->size = (keep == 0) ? 1 : keep;

    // Normalize the probabilities of the remaining words
    sample_softmax(candidates);

    // Sample the next word X from the remaining words
    const llama_token X = sample_token(candidates,rng);

    // Locate the sampled token to read back its probability.
    size_t X_idx = 0;
    while (X_idx < candidates->size && candidates->data[X_idx].id != X) {
        ++X_idx;
    }

    // Compute error as the difference between observed surprise and target surprise value
    const float observed_surprise = -log2f(candidates->data[X_idx].p);
    const float e = observed_surprise - tau;

    // Update mu using the learning rate and error
    *mu = *mu - eta * e;
    return X;
}
// Top-a (remove all tokens that have softmax probability less than top_a*m^2 where m is the maximum softmax probability)
// top-a 0 is off (no effect)
//   candidates - candidate array; softmax is applied and the array is truncated in place
//   a          - the top-a coefficient; values <= 0 disable the filter
//   min_keep   - the first min_keep candidates are never removed
void sample_top_a(llama_token_data_array * candidates, float a, size_t min_keep) {
    const bool nothing_to_do = (a <= 0.0f || candidates->size<=1);
    if (nothing_to_do) {
        return;
    }

    sample_softmax(candidates);

    const float maxprob = candidates->data[0].p;
    const float threshold = a * maxprob * maxprob; //tokens with probs less than this are removed

    // Find the first candidate at or after min_keep that falls below the threshold;
    // everything from there on is cut. If none does, the array is left whole.
    size_t cut = candidates->size;
    for (size_t i = min_keep; i < candidates->size; ++i) {
        if (candidates->data[i].p < threshold) {
            cut = i;
            break;
        }
    }
    // printf("\n\nCandidates: %d, A:%f, MaxProb: %f, Threshold: %f, LastIdx: %d",candidates->size,a,maxprob,threshold,last_idx);
    // printf("\nCandidates: %f %f %f %f\n",candidates->data[0].p,candidates->data[1].p,candidates->data[2].p,candidates->data[3].p);

    // Shrink the array to keep only the selected tokens
    candidates->size = cut;
}
// XTC sampler: with probability `xtc_probability`, pushes down every candidate whose softmax
// probability is at least `xtc_threshold`, except the least likely of them.
//   candidates      - candidate array; logits are modified in place and the sorted flag is
//                     cleared when any penalty is applied
//   xtc_threshold   - probability threshold; values above 0.5 disable the sampler
//   xtc_probability - chance that the sampler triggers on this call; <= 0 disables it
//   rng             - random source; one draw is consumed whenever the sampler is enabled
// Nothing happens unless at least two candidates reach the threshold.
// When debugmode is 1 the penalised tokens are printed.
void sample_xtc(llama_token_data_array * candidates, float xtc_threshold, float xtc_probability, std::mt19937 & rng)
{
    const bool disabled = (xtc_threshold > 0.5f || xtc_probability <= 0.0f || candidates->size <= 1);
    if (disabled) {
        return;
    }

    std::uniform_real_distribution<float> dist(0.0f, 1.0f);
    const float roll = dist(rng);
    if (roll >= xtc_probability) //if dice roll fails, skip xtc
    {
        return;
    }

    sample_softmax(candidates);

    //calculate how many tokens cross the xtc threshold (length of the leading run at or above it)
    size_t above = 0;
    while (above < candidates->size && !(candidates->data[above].p < xtc_threshold)) {
        ++above;
    }

    if (above <= 1) {
        return; //fewer than 2 viable candidates: xtc does not do anything
    }

    const bool verbose = (debugmode==1);
    if (verbose) {
        printf("XTC penalties [");
    }

    // penalise all tokens above the threshold EXCEPT the least likely one
    const size_t penalised = above - 1;
    for (size_t i = 0; i < penalised; ++i) {
        llama_token_data & entry = candidates->data[i];
        if (verbose)
        {
            gpt_vocab::id token = entry.id;
            std::string tokenizedstr = FileFormatTokenizeID(token, file_format);
            ::utreplace(tokenizedstr, "\n", "\\n");
            printf("%s(%s %.02f%%)", i == 0 ? "" : " ", RemoveBell(tokenizedstr).c_str(), 100.f * entry.p);
        }
        entry.logit -= 999.0f; //infinity gets wonky results downstream, this hack works well enough
    }

    if (verbose) {
        printf("]\n");
    }
    candidates->sorted = false;
    // printf("\n\nCandidates: %d, Threshold: %f, LastIdx: %d",candidates->size,xtc_threshold,last_idx);
    // printf("\nCandidates: %f %f %f %f\n",candidates->data[0].p,candidates->data[1].p,candidates->data[2].p,candidates->data[3].p);
}
void sample_dry(int n_ctx, int penalty_range, float penalty_multiplier, float penalty_base, int allowed_length, const std::unordered_multimap<gpt_vocab::id, std::vector<gpt_vocab::id>>& restart_sequences, llama_token_data_array * candidates) {
if (penalty_multiplier <= 0.0f || penalty_base <= 0.0f) {
return;
}
if (penalty_range <= 0 || penalty_range>n_ctx) {
penalty_range = n_ctx;
}
auto last_n_repeat = std::min(std::min((int)current_context_tokens.size(), penalty_range), n_ctx);
if (last_n_repeat <= allowed_length) {
return;
}
const llama_token * last_tokens = current_context_tokens.data() + current_context_tokens.size() - last_n_repeat;
dry_repeat_count.assign(last_n_repeat, 0);
dry_max_token_repeat.clear();
// Step 1: Look for restart sequences to limit the maximum repetition length.
// Work backwards through the context looking for any token that begins a restart sequence.
//
// The collection `restart_sequences` is a mapping from a "head" token to all "tail"
// sequences that together comprise a restart sequence. This allows us to quickly check
// whether each token is the head of a complete sequence. Most restart sequences are actually
// a single token, and for these the "tail" is an empty vector.
//
// If the token is a "head", test all restart sequences that begin with this token
// (there will often only be one sequence for each token, but if sequences like 'aaaq1' and
// 'aaa1' are used as restart strings, both could start with 'aaa' when tokenized). The
// longest matching sequence (if any) is used to limit the maximum repetition length.
//
// Note that in the case of a short sequence contained in a longer one, this might fail to
// find the smallest value for `rep_limit`. For example, if 'amniotic' and 'ni' are both used as
// restart sequences, 'ni' will be found first, and since it's shorter it will fail to suppress
// 'otic'. This is a minor issue since fully contained restart sequences are likely to be rare.
//
// This is theoretically worst-case O(N^2) for arbitrary restart sequences, which is why we
// have already clamped the maximum tail sequence length when generating `restart_sequences`.
// With clamping, this scan is O(N) in the context length.
int rep_limit = last_n_repeat;
for (size_t i = 0; i < last_n_repeat; ++i) {
size_t ix = last_n_repeat - 1 - i;
auto its = restart_sequences.equal_range(last_tokens[ix]);
if (its.first == restart_sequences.end()) {
continue;
}
int longest_match = -1;
for (auto it = its.first; it != its.second; ++it) {
// Note that (*it) does not contain the head character, so seq_len will be
// the restart sequence length minus 1.
// In the common case of a single-token restart sequence, (*it) will be empty
// and we will trivially match.
int seq_len = (int)it->second.size();
if (seq_len > longest_match && seq_len <= i) {
bool match = true;
for (size_t offset = 0; offset < seq_len; ++offset) {
// The +1 when indexing `last_tokens` is because we already matched the head.
if (it->second[offset] != last_tokens[ix + 1 + offset]) {
match = false;
break;
}
}
if (match) {
longest_match = seq_len;
}
}
}
if (longest_match >= 0) {
// We found a restart sequence starting `i` tokens from the end and continuing for
// `longest_match` tokens.
rep_limit = (int)i - longest_match;
break;
}
}
if (rep_limit <= allowed_length) {
return;
}
// Step 2: Iterate in reverse over the last N tokens of the context, using the "Z-algorithm" (in
// the reverse direction) to efficiently compute the positions and lengths of suffixes appearing
// elsewhere in the context. We limit the suffix length to `rep_limit` to respect restart sequences.
//
// This algorithm is not currently documented on Wikipedia, but there is a clear description here:
// https://ivanyu.me/blog/2014/10/15/z-algorithm/
//
// The code below is adapted from the public domain implementation by the same author here:
// https://github.com/ivanyu/string-algorithms/blob/master/z_algorithm.py
//
// Example:
// Last N tokens: a b c c b c y a b c
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
// ^
// This `3` means that the last three tokens of the context (a b c) also appear here.
//
// This step is worst case O(N) since the Z-algorithm is linear, despite the appearance of nested
// for/while loops. This can be seen by observing that the `lt` and `rt` bounds are set after each
// repeated suffix is detected (i.e. after each while loop when n > 0). These bound variables
// ensure that the inner while loops only examine each token in the context once as the outer
// for loop iterates over the context.
{
const int last = last_n_repeat - 1;
int rt = 0, lt = 0;
for (int k = 1; k < last_n_repeat; ++k) {
if (k > rt) {
// If k is outside the current Z-box, do naive computation.
int n = 0;
while (n + k < last_n_repeat && last_tokens[last - n] == last_tokens[last - (n+k)]) {
++n;
}
dry_repeat_count[last - k] = std::min(n, rep_limit);
if (n > 0) {
lt = k;
rt = k+n-1;
}
} else {
// If k is inside the current Z-box, consider two cases.
int p = k - lt; // Pair index.
int right_part_len = rt - k + 1;
if (dry_repeat_count[last - p] < right_part_len) {
int n = std::min(dry_repeat_count[last - p], rep_limit);
dry_repeat_count[last - k] = n;
} else {
int i = rt + 1;
while (i < last_n_repeat && last_tokens[last - i] == last_tokens[last - (i - k)]) {
i += 1;
}
int n = std::min(i - k, rep_limit);
dry_repeat_count[last - k] = n;
lt = k;
rt = i - 1;
}
}
}
}
// Step 3: Iterate over dry_repeat_count and last_tokens, examining the maximum repeat length
// that would be generated by emitting each new token that would extend a sequence.
//
// Following the same example as above:
// Last N tokens: a b c c b c y a b c
// Repeat counts: 0 0 3 1 0 2 0 0 0 0
//
// For each non-zero, look ahead one token. This token, if emitted, would extend the repetition.
// c: 3 -> 4 (from `a b c` to `a b c c`)
// b: 1 -> 2 (from `c` to `c b`)
// y: 2 -> 3 (from `b c` to `b c y`)
for (size_t i = 0; i < last_n_repeat - 1; ++i) {
int repeat_len = dry_repeat_count[i];
if (repeat_len >= allowed_length) {
// This token ends a repeat, so the next token would continue one.
// By convention, the value of `repeat_len` only includes the tokens currently
// in the context, not the new token that would be added.
gpt_vocab::id token = last_tokens[i + 1];
// Track the maximum sequence ending in this token.
const auto& it = dry_max_token_repeat.find(token);
if (it == dry_max_token_repeat.end() || it->second < repeat_len) {
dry_max_token_repeat[token] = repeat_len;
}
}
}
// Step 4: Apply logit penalties based on the maximum repeat length for relevant tokens.
// Prevent floating point overflow in `pow(penalty_base, exponent)` by clamping to `max_exponent`.
// Compute it from `penalty_base` and the approximate log of `std::numeric_limits<float>::max()`
const float FLOAT_MAX_LOG = 88.7228391f;
int max_exponent = 0;
if (penalty_base > 1.000001f) {
max_exponent = FLOAT_MAX_LOG / std::log(penalty_base);
}
if (debugmode==1 && !dry_max_token_repeat.empty()) {
printf("DRY penalties [");
}
size_t count = 0;
for (const auto& kvp: dry_max_token_repeat) {
gpt_vocab::id token = kvp.first;
int repeat_exp = kvp.second - allowed_length;
if (max_exponent > 0 && repeat_exp > max_exponent) {
repeat_exp = max_exponent;
}
float penalty = penalty_multiplier * pow(penalty_base, repeat_exp);
if (debugmode==1)
{
std::string tokenizedstr = FileFormatTokenizeID(token, file_format);
::utreplace(tokenizedstr, "\n", "\\n");
printf("%s(%s %.02f)", count == 0 ? "" : " ", RemoveBell(tokenizedstr).c_str(), penalty);
}
candidates->data[token].logit -= penalty;
++count;
}
if(count>0)
{
candidates->sorted = false;
}
if (debugmode==1 && !dry_max_token_repeat.empty()) {
printf("]\n");
}
}
void sample_rep_pen(int n_ctx, int rep_pen_range, float rep_pen, float rep_pen_slope, float presence_penalty, llama_token_data_array * candidates_p)
{
auto last_n_repeat = std::min(std::min((int)last_n_tokens.size(), rep_pen_range), n_ctx);
const llama_token * last_tokens = last_n_tokens.data() + last_n_tokens.size() - last_n_repeat;
size_t last_tokens_size = last_n_repeat;
llama_token_data_array * candidates = candidates_p;
if (last_tokens_size == 0 || (rep_pen == 1.0f && presence_penalty==0)) {
return;
}
const int64_t t_start_sample_us = ggml_time_us();
// Create a frequency map to count occurrences of each token in last_tokens
std::unordered_map<llama_token, int> token_count_near;
std::unordered_map<llama_token, int> token_count_far;
for (size_t i = 0; i < last_n_repeat; ++i) {