From 36a9335588a45fd9a038dfc1b5c576dd002551cc Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 19 Jan 2024 23:26:36 +0000 Subject: [PATCH] Removed `LLamaBatchSafeHandle` (using unmanaged memory, created by llama.cpp) and replaced it with a fully managed `LLamaBatch`. Modified the `BatchedDecoding` example to use new managed batch. --- LLama.Examples/Examples/BatchedDecoding.cs | 21 +-- LLama/Native/LLamaBatch.cs | 121 ++++++++++++++++ LLama/Native/LLamaBatchSafeHandle.cs | 158 --------------------- LLama/Native/LLamaNativeBatch.cs | 2 +- LLama/Native/LLamaSeqId.cs | 5 + LLama/Native/SafeLLamaContextHandle.cs | 6 +- 6 files changed, 137 insertions(+), 176 deletions(-) create mode 100644 LLama/Native/LLamaBatch.cs delete mode 100644 LLama/Native/LLamaBatchSafeHandle.cs diff --git a/LLama.Examples/Examples/BatchedDecoding.cs b/LLama.Examples/Examples/BatchedDecoding.cs index 306e74c13..ee2936d05 100644 --- a/LLama.Examples/Examples/BatchedDecoding.cs +++ b/LLama.Examples/Examples/BatchedDecoding.cs @@ -52,18 +52,11 @@ public static async Task Run() return; } - using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1); + var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1); // evaluate the initial prompt for (var i = 0; i < prompt_tokens.Length; i++) - batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false); - Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length); - - // llama_decode will output logits only for the last token of the prompt - unsafe - { - batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1; - } + batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1); if (context.NativeHandle.Decode(batch) != 0) { @@ -75,7 +68,7 @@ public static async Task Run() // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them for (var i = 1; i < n_parallel; ++i) { - NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens); + NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount); } if (n_parallel > 1) @@ -88,9 +81,9 @@ public static async Task Run() // we need this to determine which logits to sample from List i_batch = new(); for (var i = 0; i < n_parallel; i++) - i_batch.Add(batch.NativeBatch.n_tokens - 1); + i_batch.Add(batch.TokenCount - 1); - var n_cur = batch.NativeBatch.n_tokens; + var n_cur = batch.TokenCount; var n_decode = 0; var streams = new List[n_parallel]; @@ -133,7 +126,7 @@ public static async Task Run() streams[i].Add(new_token_id); - i_batch[i] = batch.NativeBatch.n_tokens; + i_batch[i] = batch.TokenCount; // push this new token for next evaluation batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true); @@ -142,7 +135,7 @@ public static async Task Run() } // all streams are finished - if (batch.NativeBatch.n_tokens == 0) + if (batch.TokenCount == 0) { break; } diff --git a/LLama/Native/LLamaBatch.cs b/LLama/Native/LLamaBatch.cs new file mode 100644 index 000000000..e4dc6af64 --- /dev/null +++ b/LLama/Native/LLamaBatch.cs @@ -0,0 +1,121 @@ +using System; + +namespace LLama.Native; + +/// +/// A batch allows submitting multiple tokens to multiple sequences simultaneously +/// +public class LLamaBatch +{ + private readonly byte[] _logits; + + private readonly LLamaToken[] _tokens; + private readonly LLamaPos[] _positions; + + private readonly int[] _sequenceIdCount; + private readonly LLamaSeqId[][] _sequenceIds; + private readonly IntPtr[] _sequenceIdsPtrs; + + /// + /// The number of tokens in this batch + /// + public int TokenCount { get; private set; } + + /// + /// Create a new batch for submitting inputs to llama.cpp + /// + /// + /// + public LLamaBatch(int n_tokens, int n_seq_max) + { + _logits = new byte[n_tokens]; + _tokens = new LLamaToken[n_tokens]; + _positions = new LLamaPos[n_tokens]; + + _sequenceIdCount = new int[n_tokens]; + _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length]; + + _sequenceIds = new LLamaSeqId[n_tokens][]; + for (var i = 0; i < _sequenceIds.Length; i++) + _sequenceIds[i] = new LLamaSeqId[n_seq_max]; + } + + internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch) + { + // This group holds all of the memory pins + var group = new GroupDisposable(); + + unsafe + { + batch = new LLamaNativeBatch + { + n_tokens = TokenCount, + logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer, + + n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer, + pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer, + seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer, + + // embd is not currently supported, so this is always null! + embd = null, + + // Note that if embd is **not null** then this will be null! + tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer, + }; + + // Create pointers to each of the arrays in turns + for (var i = 0; i < _sequenceIdsPtrs.Length; i++) + _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer; + } + + return group; + } + + /// + /// Add a single token to the batch at the same position in several sequences + /// + /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 + /// The token to add + /// The position to add it att + /// The set of sequences to add this token to + /// + public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) + { + _tokens[TokenCount] = token; + _positions[TokenCount] = pos; + + _sequenceIdCount[TokenCount] = sequences.Length; + for (var i = 0; i < sequences.Length; i++) + _sequenceIds[TokenCount][i] = sequences[i]; + + _logits[TokenCount] = Convert.ToByte(logits); + + TokenCount++; + } + + /// + /// Add a single token to the batch at a certain position for a single sequences + /// + /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 + /// The token to add + /// The position to add it att + /// The sequence to add this token to + /// + public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits) + { + // Create a temporary span to contain 1 item without allocating + Span sequences = stackalloc LLamaSeqId[1]; + sequences[0] = sequence; + + // Add it + LLamaBatchAdd(token, pos, sequences, logits); + } + + /// + /// Set TokenCount to zero for this batch + /// + public void LLamaBatchClear() + { + TokenCount = 0; + } +} \ No newline at end of file diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs deleted file mode 100644 index 4198ad02a..000000000 --- a/LLama/Native/LLamaBatchSafeHandle.cs +++ /dev/null @@ -1,158 +0,0 @@ -using System; - -namespace LLama.Native; - -/// -/// Input data for llama_decode. A llama_batch object can contain input about one or many sequences. -/// -public sealed class LLamaBatchSafeHandle - : SafeLLamaHandleBase -{ - private readonly int _embd; - - /// - /// Get the native llama_batch struct - /// - public LLamaNativeBatch NativeBatch; - - /// - /// the token ids of the input (used when embd is NULL) - /// - public Span Token - { - get - { - unsafe - { - if (_embd != 0) - return new Span(null, 0); - else - return new Span(NativeBatch.token, NativeBatch.n_tokens); - } - } - } - - /// - /// token embeddings (i.e. float vector of size n_embd) (used when token is NULL) - /// - public Span Embed - { - get - { - unsafe - { - // If embd != 0, llama_batch.embd will be allocated with size of n_tokens *embd * sizeof(float) - // Otherwise, llama_batch.token will be allocated to store n_tokens llama_token - - if (_embd != 0) - return new Span(NativeBatch.embd, NativeBatch.n_tokens * _embd); - else - return new Span(null, 0); - } - } - } - - /// - /// the positions of the respective token in the sequence - /// - public Span Pos - { - get - { - unsafe - { - return new Span(NativeBatch.pos, NativeBatch.n_tokens); - } - } - } - - /// - /// the sequence to which the respective token belongs - /// - public Span Sequence_ID - { - get - { - unsafe - { - return new Span(NativeBatch.seq_id, NativeBatch.n_tokens); - } - } - } - - /// - /// if zero, the logits for the respective token will not be output - /// - public Span Logits - { - get - { - unsafe - { - return new Span(NativeBatch.logits, NativeBatch.n_tokens); - } - } - } - - /// - /// Create a safe handle owning a `LLamaNativeBatch` - /// - /// - /// - public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd) - : base((nint)1) - { - _embd = embd; - NativeBatch = batch; - } - - /// - /// Call `llama_batch_init` and create a new batch - /// - /// - /// - /// - /// - public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max) - { - var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max); - return new LLamaBatchSafeHandle(batch, embd); - } - - /// - protected override bool ReleaseHandle() - { - NativeApi.llama_batch_free(NativeBatch); - NativeBatch = default; - SetHandle(IntPtr.Zero); - return true; - } - - /// - /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 - /// - public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan sequences, bool logits) - { - unsafe - { - NativeBatch.token[NativeBatch.n_tokens] = token; - NativeBatch.pos[NativeBatch.n_tokens] = pos; - NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length; - - for (var i = 0; i < sequences.Length; i++) - NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i]; - - NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits); - - NativeBatch.n_tokens++; - } - } - - /// - /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825 - /// - public void LLamaBatchClear() - { - NativeBatch.n_tokens = 0; - } -} \ No newline at end of file diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs index 978e955c3..d58c61788 100644 --- a/LLama/Native/LLamaNativeBatch.cs +++ b/LLama/Native/LLamaNativeBatch.cs @@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch /// /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created /// - public LLamaToken* token; + public LLamaToken* tokens; /// /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs index bcee74f13..8a3dae5d8 100644 --- a/LLama/Native/LLamaSeqId.cs +++ b/LLama/Native/LLamaSeqId.cs @@ -8,6 +8,11 @@ namespace LLama.Native; [StructLayout(LayoutKind.Sequential)] public record struct LLamaSeqId { + /// + /// LLamaSeqId with value 0 + /// + public static readonly LLamaSeqId Zero = new LLamaSeqId(0); + /// /// The raw value /// diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 3f303123b..b10e083f0 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,5 +1,4 @@ using System; -using System.Buffers; using System.Runtime.InteropServices; using System.Text; using LLama.Exceptions; @@ -198,9 +197,10 @@ public bool Eval(ReadOnlySpan tokens, int n_past) /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
/// - < 0: error
/// - public int Decode(LLamaBatchSafeHandle batch) + public int Decode(LLamaBatch batch) { - return NativeApi.llama_decode(this, batch.NativeBatch); + using (batch.ToNativeBatch(out var nb)) + return NativeApi.llama_decode(this, nb); } #region state