diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 59e8869ab..1d55ff12f 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -34,7 +34,7 @@ public static async Task Run()
         using var model = LLamaWeights.LoadFromFile(parameters);
 
         // Tokenize prompt
-        var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
+        var prompt_tokens = model.Tokenize(prompt, true, false, Encoding.UTF8);
         var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
 
         // Create a context
@@ -86,9 +86,9 @@ public static async Task Run()
         var n_cur = batch.TokenCount;
         var n_decode = 0;
 
-        var streams = new List<LLamaToken>[n_parallel];
+        var streams = new StreamingTokenDecoder[n_parallel];
         for (var i = 0; i < n_parallel; i++)
-            streams[i] = new();
+            streams[i] = new StreamingTokenDecoder(context);
 
         var eos = model.EndOfSentenceToken;
         var nl = model.NewlineToken;
@@ -159,7 +159,7 @@ public static async Task Run()
         var index = 0;
         foreach (var stream in streams)
         {
-            var text = context.DeTokenize(stream);
+            var text = stream.Read();
 
             Console.ForegroundColor = ConsoleColor.Green;
             Console.Write($"{index++}. {prompt}");
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index df0eadcad..0adb19875 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -1,5 +1,6 @@
 using System;
 using System.Collections.Generic;
+using System.Text;
 using LLama.Abstractions;
 using LLama.Extensions;
 using LLama.Native;
@@ -109,5 +110,18 @@ public LLamaContext CreateContext(IContextParams @params, ILogger? logger = null
         {
             return new LLamaContext(this, @params, logger);
         }
+
+        /// <summary>
+        /// Convert a string of text into tokens
+        /// </summary>
+        /// <param name="text"></param>
+        /// <param name="add_bos"></param>
+        /// <param name="encoding"></param>
+        /// <param name="special">Allow tokenizing special and/or control tokens which otherwise are not exposed and treated as plaintext.</param>
+        /// <returns></returns>
+        public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding encoding)
+        {
+            return NativeHandle.Tokenize(text, add_bos, special, encoding);
+        }
     }
 }
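
For context, a minimal sketch (not part of the diff) of how the changed surface fits together: the Tokenize wrapper on LLamaWeights replaces direct calls through NativeHandle, and StreamingTokenDecoder replaces the earlier List-of-tokens plus context.DeTokenize(...) pattern. The model path, prompt, and parameter values below are placeholders, not taken from this PR.

using System;
using System.Text;
using LLama;
using LLama.Common;

// Placeholder model path; ModelParams also carries the context parameters.
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);

// New convenience wrapper: tokenize directly on the weights object
// instead of reaching into model.NativeHandle.
var tokens = model.Tokenize("Hello, world!", add_bos: true, special: false, encoding: Encoding.UTF8);

// StreamingTokenDecoder accumulates tokens and materialises the text on Read(),
// which is what the updated BatchedDecoding example now relies on.
var decoder = new StreamingTokenDecoder(context);
foreach (var token in tokens)
    decoder.Add(token);
Console.WriteLine(decoder.Read());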