Skip to content

Commit

Permalink
Merge pull request #442 from martindevans/managed_llama_batch
Browse files Browse the repository at this point in the history
Managed `LLamaBatch`
  • Loading branch information
martindevans authored Jan 20, 2024
2 parents 4b11fed + 36a9335 commit a0be27d
Show file tree
Hide file tree
Showing 6 changed files with 137 additions and 176 deletions.
21 changes: 7 additions & 14 deletions LLama.Examples/Examples/BatchedDecoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,18 +52,11 @@ public static async Task Run()
return;
}

using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

// llama_decode will output logits only for the last token of the prompt
unsafe
{
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
}
batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

if (context.NativeHandle.Decode(batch) != 0)
{
Expand All @@ -75,7 +68,7 @@ public static async Task Run()
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.TokenCount);
}

if (n_parallel > 1)
Expand All @@ -88,9 +81,9 @@ public static async Task Run()
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.NativeBatch.n_tokens - 1);
i_batch.Add(batch.TokenCount - 1);

var n_cur = batch.NativeBatch.n_tokens;
var n_cur = batch.TokenCount;
var n_decode = 0;

var streams = new List<LLamaToken>[n_parallel];
Expand Down Expand Up @@ -133,7 +126,7 @@ public static async Task Run()

streams[i].Add(new_token_id);

i_batch[i] = batch.NativeBatch.n_tokens;
i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
Expand All @@ -142,7 +135,7 @@ public static async Task Run()
}

// all streams are finished
if (batch.NativeBatch.n_tokens == 0)
if (batch.TokenCount == 0)
{
break;
}
Expand Down
121 changes: 121 additions & 0 deletions LLama/Native/LLamaBatch.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
using System;

namespace LLama.Native;

/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
private readonly byte[] _logits;

private readonly LLamaToken[] _tokens;
private readonly LLamaPos[] _positions;

private readonly int[] _sequenceIdCount;
private readonly LLamaSeqId[][] _sequenceIds;
private readonly IntPtr[] _sequenceIdsPtrs;

/// <summary>
/// The number of tokens in this batch
/// </summary>
public int TokenCount { get; private set; }

/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="n_seq_max"></param>
public LLamaBatch(int n_tokens, int n_seq_max)
{
_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];

_sequenceIdCount = new int[n_tokens];
_sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[n_seq_max];
}

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
// This group holds all of the memory pins
var group = new GroupDisposable();

unsafe
{
batch = new LLamaNativeBatch
{
n_tokens = TokenCount,
logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,

n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,

// embd is not currently supported, so this is always null!
embd = null,

// Note that if embd is **not null** then this will be null!
tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
};

// Create pointers to each of the arrays in turns
for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
_sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
}

return group;
}

/// <summary>
/// Add a single token to the batch at the same position in several sequences
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
_tokens[TokenCount] = token;
_positions[TokenCount] = pos;

_sequenceIdCount[TokenCount] = sequences.Length;
for (var i = 0; i < sequences.Length; i++)
_sequenceIds[TokenCount][i] = sequences[i];

_logits[TokenCount] = Convert.ToByte(logits);

TokenCount++;
}

/// <summary>
/// Add a single token to the batch at a certain position for a single sequences
/// </summary>
/// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
/// <param name="token">The token to add</param>
/// <param name="pos">The position to add it att</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;

// Add it
LLamaBatchAdd(token, pos, sequences, logits);
}

/// <summary>
/// Set TokenCount to zero for this batch
/// </summary>
public void LLamaBatchClear()
{
TokenCount = 0;
}
}
158 changes: 0 additions & 158 deletions LLama/Native/LLamaBatchSafeHandle.cs

This file was deleted.

2 changes: 1 addition & 1 deletion LLama/Native/LLamaNativeBatch.cs
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public unsafe struct LLamaNativeBatch
/// <summary>
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
/// </summary>
public LLamaToken* token;
public LLamaToken* tokens;

/// <summary>
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
Expand Down
5 changes: 5 additions & 0 deletions LLama/Native/LLamaSeqId.cs
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,11 @@ namespace LLama.Native;
[StructLayout(LayoutKind.Sequential)]
public record struct LLamaSeqId
{
/// <summary>
/// LLamaSeqId with value 0
/// </summary>
public static readonly LLamaSeqId Zero = new LLamaSeqId(0);

/// <summary>
/// The raw value
/// </summary>
Expand Down
6 changes: 3 additions & 3 deletions LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Buffers;
using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
Expand Down Expand Up @@ -198,9 +197,10 @@ public bool Eval(ReadOnlySpan<LLamaToken> tokens, int n_past)
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatchSafeHandle batch)
public int Decode(LLamaBatch batch)
{
return NativeApi.llama_decode(this, batch.NativeBatch);
using (batch.ToNativeBatch(out var nb))
return NativeApi.llama_decode(this, nb);
}

#region state
Expand Down

0 comments on commit a0be27d

Please sign in to comment.