diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 1d55ff12f..c51c49be4 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,7 +52,7 @@ public static async Task Run()
             return;
         }
 
-        var batch = new LLamaBatch(1);
+        var batch = new LLamaBatch();
 
         // evaluate the initial prompt
         for (var i = 0; i < prompt_tokens.Length; i++)
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 9abb15ae9..20e145306 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,5 +1,4 @@
 using System;
-using System.Collections.Generic;
 
 namespace LLama.Native;
 
@@ -35,11 +34,11 @@ public class LLamaBatch
     /// <summary>
     /// Create a new batch for submitting inputs to llama.cpp
     /// </summary>
-    /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
-    public LLamaBatch(int n_seq_max)
+    public LLamaBatch()
     {
-        // The number of tokens can be grown later, start off with a reasonable guess.
-        const int n_tokens = 64;
+        // These can both be grown later, start off with reasonable numbers.
+        const int n_tokens = 128;
+        const int n_seq_max = 1;
 
         MaxSequences = n_seq_max;
         TokenCapacity = n_tokens;
@@ -56,7 +55,7 @@ public LLamaBatch(int n_seq_max)
             _sequenceIds[i] = new LLamaSeqId[MaxSequences];
     }
 
-    private void Grow()
+    private void GrowTokenCapacity()
     {
         var n_tokens = TokenCount * 2;
         TokenCapacity = n_tokens;
@@ -78,6 +77,15 @@ private void Grow()
         }
     }
 
+    private void GrowMaxSequences(int atLeast)
+    {
+        var n_seq = Math.Max(MaxSequences * 2, atLeast);
+        MaxSequences = n_seq;
+
+        for (var i = 0; i < _sequenceIds.Length; i++)
+            Array.Resize(ref _sequenceIds[i], MaxSequences);
+    }
+
     internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     {
         // This group holds all of the memory pins
@@ -120,7 +128,9 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
     public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
     {
         if (TokenCount == TokenCapacity)
-            Grow();
+            GrowTokenCapacity();
+        if (sequences.Length > MaxSequences)
+            GrowMaxSequences(sequences.Length);
 
         _tokens[TokenCount] = token;
         _positions[TokenCount] = pos;
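
Usage sketch (not part of the diff): a minimal illustration of how a caller drives the reworked batch, assuming the explicit int conversions on LLamaToken, LLamaPos and LLamaSeqId and a placeholder token value; in real code the tokens would come from the model's tokenizer, as in BatchedDecoding.cs above.

    using System;
    using LLama.Native;

    // The batch no longer needs sizing up front; per this diff it starts with room
    // for 128 tokens and 1 sequence per token, and grows both limits on demand.
    var batch = new LLamaBatch();

    // Adding a token shared by two sequences exceeds MaxSequences (1), so Add calls
    // GrowMaxSequences(2), resizing every per-token sequence-id array to length 2.
    var sequences = new[] { (LLamaSeqId)0, (LLamaSeqId)1 };
    batch.Add((LLamaToken)42, (LLamaPos)0, sequences, logits: true);

    // Adding more than the current token capacity would similarly trigger GrowTokenCapacity().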