diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index ee2936d05..59e8869ab 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,13 +52,13 @@ public static async Task Run()
             return;
         }
 
-        var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+        var batch = new LLamaBatch(1);
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)
-            batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+            batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
 
-        if (context.NativeHandle.Decode(batch) != 0)
+        if (await context.DecodeAsync(batch) != 0)
        {
             await Console.Error.WriteLineAsync("llama_decode failed");
             return;
@@ -97,7 +97,7 @@ public static async Task Run()
         timer.Start();
         while (n_cur <= n_len)
         {
-            batch.LLamaBatchClear();
+            batch.Clear();
 
             for (var i = 0; i < n_parallel; i++)
             {
@@ -129,7 +129,7 @@ public static async Task Run()
                 i_batch[i] = batch.TokenCount;
 
                 // push this new token for next evaluation
-                batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+                batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
 
                 n_decode++;
             }
@@ -143,7 +143,7 @@ public static async Task Run()
             n_cur++;
 
             // evaluate the current batch with the transformer model
-            if (context.NativeHandle.Decode(batch) != 0)
+            if (await context.DecodeAsync(batch) != 0)
             {
                 await Console.Error.WriteLineAsync("failed to eval");
                 return;
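Taken in isolation, the renamed batch API reads as below. This is a minimal sketch rather than part of the diff, assuming an existing `LLamaContext` named `context` and an already-tokenized prompt `tokens`:

```csharp
// Sketch of the reworked LLamaBatch surface: token capacity is managed
// internally, so the constructor only takes the max sequences per token.
var batch = new LLamaBatch(1);

// Queue the whole prompt, requesting logits only for the final token.
for (var i = 0; i < tokens.Length; i++)
    batch.Add(tokens[i], i, LLamaSeqId.Zero, logits: i == tokens.Length - 1);

// DecodeAsync pushes the native decode call onto a background task.
if (await context.DecodeAsync(batch) != 0)
    throw new InvalidOperationException("llama_decode failed");

// Reset and reuse the same batch for the next decoding step.
batch.Clear();
```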
diff --git a/LLama.Unittest/BeamTests.cs b/LLama.Unittest/BeamTests.cs
index 3014894ec..994860886 100644
--- a/LLama.Unittest/BeamTests.cs
+++ b/LLama.Unittest/BeamTests.cs
@@ -40,7 +40,7 @@ public void BasicBeam()
         var initial_tokens = context.Tokenize(prompt);
         result.Append(prompt);
-        context.Eval(initial_tokens, 0);
+        context.Eval(initial_tokens.AsSpan(), 0);
 
         NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
         {
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 72e9acf87..8d4be20cb 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -36,7 +36,7 @@ public async Task Stateless()
         var executor = new StatelessExecutor(_weights, _params);
 
-        const string question = "Question. what is a cat?\nAnswer: ";
+        const string question = "Question. what is a cat?\nAnswer:";
         var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };
 
         var timer = new Stopwatch();
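The test-prompt tweak above drops the trailing space after `Answer:`. A quick way to see the effect, sketched under the assumption of a llama-style (SentencePiece) tokenizer where generated words usually carry their own leading space:

```csharp
// Hypothetical check: "Answer:" usually tokenizes more naturally than
// "Answer: ", because the model's next word (e.g. " A") already includes
// a leading space, which would otherwise be doubled.
var withSpace = context.Tokenize("Question. what is a cat?\nAnswer: ");
var withoutSpace = context.Tokenize("Question. what is a cat?\nAnswer:");
Console.WriteLine($"{withSpace.Length} vs {withoutSpace.Length} tokens");
```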
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index dd3d081a4..ea745d029 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -8,10 +8,12 @@
 using System.IO.MemoryMappedFiles;
 using LLama.Common;
 using System.Runtime.InteropServices;
+using System.Threading.Tasks;
 using LLama.Extensions;
 using LLama.Abstractions;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
+using System.Threading;
 
 namespace LLama
 {
@@ -344,16 +346,30 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict
         #region eval overloads
         /// <summary>
         /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="pastTokensCount"></param>
-        /// <returns>The updated `pastTokensCount`.</returns>
-        /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
-        public int Eval(LLamaToken[] tokens, int pastTokensCount)
+        /// <param name="batch"></param>
+        /// <returns>A positive return value does not mean a fatal error, but rather a warning:<br/>
+        ///  - 0: success<br/>
+        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)<br/>
+        ///  - &lt; 0: error<br/>
+        /// </returns>
+        public int Decode(LLamaBatch batch)
+        {
+            return NativeHandle.Decode(batch);
+        }
+
+        /// <summary>
+        /// </summary>
+        /// <param name="batch"></param>
+        /// <param name="cancellationToken"></param>
+        /// <returns>A positive return value does not mean a fatal error, but rather a warning:<br/>
+        ///  - 0: success<br/>
+        ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increasing the context)<br/>
+        ///  - &lt; 0: error<br/>
+        /// </returns>
+        public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
         {
-            return Eval(tokens.AsSpan(), pastTokensCount);
+            return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
         }
 
         /// <summary>
@@ -363,7 +379,7 @@ public int Eval(LLamaToken[] tokens, int pastTokensCount)
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
+        [Obsolete("use Decode() instead")]
         public int Eval(List<LLamaToken> tokens, int pastTokensCount)
         {
 #if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ public int Eval(List<LLamaToken> tokens, int pastTokensCount)
         /// <param name="pastTokensCount"></param>
         /// <returns>The updated `pastTokensCount`.</returns>
         /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
-        public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
-        {
-            return Eval(tokens.Span, pastTokensCount);
-        }
-
-        /// <summary>
-        ///
-        /// </summary>
-        /// <param name="tokens"></param>
-        /// <param name="pastTokensCount"></param>
-        /// <returns>The updated `pastTokensCount`.</returns>
-        /// <exception cref="RuntimeError"></exception>
-        [Obsolete("use llama_decode() instead")]
+        [Obsolete("use Decode() instead")]
         public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
         {
             var total = tokens.Length;
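`Decode` surfaces llama.cpp's tri-state result unchanged, so callers can branch on it directly. A hedged sketch of doing so (the recovery strategy is illustrative, not something this diff prescribes; `context` and `batch` are assumed to exist):

```csharp
// Illustrative handling of the documented Decode() return values.
var status = context.Decode(batch);
if (status == 0)
{
    // Success: logits are available for every token added with logits == true.
}
else if (status == 1)
{
    // Warning: no KV slot was found. Retry with a smaller batch,
    // or recreate the context with a larger context size.
}
else
{
    // Negative values are real errors.
    throw new InvalidOperationException($"llama_decode failed ({status})");
}
```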
diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs
index 0c6cc87ca..bccfd1416 100644
--- a/LLama/LLamaEmbedder.cs
+++ b/LLama/LLamaEmbedder.cs
@@ -75,7 +75,7 @@ public float[] GetEmbeddings(string text, bool addBos)
 
         // TODO(Rinne): deal with log of prompt
         if (embed_inp_array.Length > 0)
-            Context.Eval(embed_inp_array, 0);
+            Context.Eval(embed_inp_array.AsSpan(), 0);
 
         var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
         if (embeddings == null)
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index e4dc6af64..9abb15ae9 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,4 +1,5 @@
 using System;
+using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -7,27 +8,42 @@ namespace LLama.Native;
 /// </summary>
 public class LLamaBatch
 {
-    private readonly byte[] _logits;
+    private byte[] _logits;
 
-    private readonly LLamaToken[] _tokens;
-    private readonly LLamaPos[] _positions;
+    private LLamaToken[] _tokens;
+    private LLamaPos[] _positions;
 
-    private readonly int[] _sequenceIdCount;
-    private readonly LLamaSeqId[][] _sequenceIds;
-    private readonly IntPtr[] _sequenceIdsPtrs;
+    private int[] _sequenceIdCount;
+    private LLamaSeqId[][] _sequenceIds;
+    private IntPtr[] _sequenceIdsPtrs;
 
     /// <summary>
     /// The number of tokens in this batch
     /// </summary>
     public int TokenCount { get; private set; }
 
+    /// <summary>
+    /// Maximum number of tokens that can be added to this batch
+    /// </summary>
+    private int TokenCapacity { get; set; }
+
+    /// <summary>
+    /// Maximum number of sequences a token can be assigned to
+    /// </summary>
+    public int MaxSequences { get; private set; }
+
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
     /// </summary>
-    /// <param name="n_tokens"></param>
-    /// <param name="n_seq_max"></param>
-    public LLamaBatch(int n_tokens, int n_seq_max)
+    /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
+    public LLamaBatch(int n_seq_max)
     {
+        // The number of tokens can be grown later, start off with a reasonable guess.
+        const int n_tokens = 64;
+
+        MaxSequences = n_seq_max;
+        TokenCapacity = n_tokens;
+
         _logits = new byte[n_tokens];
         _tokens = new LLamaToken[n_tokens];
         _positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public LLamaBatch(int n_tokens, int n_seq_max)
         _sequenceIds = new LLamaSeqId[n_tokens][];
         for (var i = 0; i < _sequenceIds.Length; i++)
-            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+            _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+    }
+
+    private void Grow()
+    {
+        var n_tokens = TokenCount * 2;
+        TokenCapacity = n_tokens;
+
+        Array.Resize(ref _logits, n_tokens);
+        Array.Resize(ref _tokens, n_tokens);
+        Array.Resize(ref _positions, n_tokens);
+
+        Array.Resize(ref _sequenceIdCount, n_tokens);
+        Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+        Array.Resize(ref _sequenceIds, n_tokens);
+        for (int i = 0; i < _sequenceIds.Length; i++)
+        {
+            // Growing the array fills the new elements with null, temporarily violating the nullability contract!
+            // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+            if (_sequenceIds[i] == null)
+                _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+        }
     }
 
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     /// <param name="pos">The position to add it at</param>
     /// <param name="sequences">The set of sequences to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
+        if (TokenCount == TokenCapacity)
+            Grow();
+
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
@@ -101,20 +142,20 @@ public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     /// <param name="pos">The position to add it at</param>
     /// <param name="sequence">The sequence to add this token to</param>
     /// <param name="logits"></param>
-    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+    public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
     {
         // Create a temporary span to contain 1 item without allocating
         Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
         sequences[0] = sequence;
 
         // Add it
-        LLamaBatchAdd(token, pos, sequences, logits);
+        Add(token, pos, sequences, logits);
     }
 
     /// <summary>
     /// Set TokenCount to zero for this batch
     /// </summary>
-    public void LLamaBatchClear()
+    public void Clear()
     {
         TokenCount = 0;
     }
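Since `Add` now grows the backing arrays on demand, batches no longer need to be sized up front. A small sketch of the growth behaviour (the token values are placeholders; `default(LLamaToken)` stands in for any real token):

```csharp
// The batch starts with room for 64 tokens and doubles whenever Add()
// reaches the current capacity, so this loop triggers Grow() twice
// (64 -> 128 -> 256) while adding 200 tokens to sequence 0.
var batch = new LLamaBatch(1);
for (var i = 0; i < 200; i++)
    batch.Add(default(LLamaToken), i, LLamaSeqId.Zero, logits: false);

Console.WriteLine(batch.TokenCount); // 200
```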