diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs
index 1d55ff12f..c51c49be4 100644
--- a/LLama.Examples/Examples/BatchedDecoding.cs
+++ b/LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,7 +52,7 @@ public static async Task Run()
return;
}
- var batch = new LLamaBatch(1);
+ var batch = new LLamaBatch();
// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
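For illustration only, the updated call site amounts to roughly the following; the explicit int-to-LLamaSeqId conversion and the int-to-LLamaPos conversion in the Add() call are assumptions, not something this diff shows:

    // Sketch: no up-front n_seq_max; the batch sizes itself as tokens are added.
    var batch = new LLamaBatch();
    var seq = new[] { (LLamaSeqId)0 };              // int conversion assumed

    for (var i = 0; i < prompt_tokens.Length; i++)
    {
        // Request logits only for the final prompt token.
        batch.Add(prompt_tokens[i], i, seq, i == prompt_tokens.Length - 1);
    }
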
diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs
index 9abb15ae9..20e145306 100644
--- a/LLama/Native/LLamaBatch.cs
+++ b/LLama/Native/LLamaBatch.cs
@@ -1,5 +1,4 @@
using System;
-using System.Collections.Generic;
namespace LLama.Native;
@@ -35,11 +34,11 @@ public class LLamaBatch
/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
- /// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
- public LLamaBatch(int n_seq_max)
+ public LLamaBatch()
{
- // The number of tokens can be grown later, start off with a reasonable guess.
- const int n_tokens = 64;
+ // Both of these can be grown later; start off with reasonable defaults.
+ const int n_tokens = 128;
+ const int n_seq_max = 1;
MaxSequences = n_seq_max;
TokenCapacity = n_tokens;
@@ -56,7 +55,7 @@ public LLamaBatch(int n_seq_max)
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
}
- private void Grow()
+ private void GrowTokenCapacity()
{
var n_tokens = TokenCount * 2;
TokenCapacity = n_tokens;
@@ -78,6 +77,15 @@ private void Grow()
}
}
+ private void GrowMaxSequences(int atLeast)
+ {
+ var n_seq = Math.Max(MaxSequences * 2, atLeast);
+ MaxSequences = n_seq;
+
+ for (var i = 0; i < _sequenceIds.Length; i++)
+ Array.Resize(ref _sequenceIds[i], MaxSequences);
+ }
+
internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
// This group holds all of the memory pins
@@ -120,7 +128,9 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
if (TokenCount == TokenCapacity)
- Grow();
+ GrowTokenCapacity();
+ if (sequences.Length > MaxSequences)
+ GrowMaxSequences(sequences.Length);
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
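
As a closing illustration, a minimal sketch of the growth behaviour this patch introduces. BuildSharedPrompt is a hypothetical helper (not part of the change), and the int-to-LLamaSeqId / int-to-LLamaPos conversions are assumed:

    using LLama.Native;

    // Sketch: share one prompt between several sequences; the batch grows itself.
    static LLamaBatch BuildSharedPrompt(LLamaToken[] prompt, int sequenceCount)
    {
        var batch = new LLamaBatch();               // starts at MaxSequences == 1, 128 tokens

        var sequences = new LLamaSeqId[sequenceCount];
        for (var s = 0; s < sequenceCount; s++)
            sequences[s] = (LLamaSeqId)s;           // int conversion assumed

        for (var i = 0; i < prompt.Length; i++)
        {
            // The first Add() with sequenceCount > MaxSequences triggers GrowMaxSequences(sequenceCount);
            // adding more than 128 tokens triggers GrowTokenCapacity(), which doubles the backing arrays.
            batch.Add(prompt[i], i, sequences, i == prompt.Length - 1);
        }

        return batch;
    }

Because GrowTokenCapacity doubles TokenCapacity each time it runs, the number of reallocations stays logarithmic in the prompt length, so callers never need to size the batch up front.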