Merge pull request #443 from martindevans/llama_batch_self_grow
LLamaBatch Automatically Grow Capacity
martindevans authored Jan 20, 2024
2 parents a0be27d + 99969e5 commit 250c20b
Showing 6 changed files with 90 additions and 46 deletions.
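
The change in one picture: LLamaBatch no longer takes a token capacity up front and instead grows its buffers on demand, and the C-style member names become idiomatic C#. A hedged before/after sketch (the call shapes are taken from the diff below; `token` and `pos` are illustrative placeholders):

    // Before this PR: capacity had to be guessed up front, C-style names
    var oldBatch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
    oldBatch.LLamaBatchAdd(token, pos, LLamaSeqId.Zero, true);
    oldBatch.LLamaBatchClear();

    // After this PR: only max sequences per token is fixed; buffers self-grow
    var newBatch = new LLamaBatch(1);
    newBatch.Add(token, pos, LLamaSeqId.Zero, true);
    newBatch.Clear();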
12 changes: 6 additions & 6 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -52,13 +52,13 @@ public static async Task Run()
return;
}

- var batch = new LLamaBatch(Math.Max(prompt_tokens.Length, n_parallel), 1);
+ var batch = new LLamaBatch(1);

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
-     batch.LLamaBatchAdd(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);
+     batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
@@ -97,7 +97,7 @@ public static async Task Run()
timer.Start();
while (n_cur <= n_len)
{
- batch.LLamaBatchClear();
+ batch.Clear();

for (var i = 0; i < n_parallel; i++)
{
@@ -129,7 +129,7 @@ public static async Task Run()
i_batch[i] = batch.TokenCount;

// push this new token for next evaluation
- batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+ batch.Add(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

n_decode++;
}
@@ -143,7 +143,7 @@ public static async Task Run()
n_cur++;

// evaluate the current batch with the transformer model
- if (context.NativeHandle.Decode(batch) != 0)
+ if (await context.DecodeAsync(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
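Taken together, the updated example now follows this pattern (a condensed sketch, assuming the same `context` and `prompt_tokens` as in the diff above; the exception is illustrative error handling):

    // No capacity argument: the batch starts at 64 tokens and grows as needed
    var batch = new LLamaBatch(1);

    // Queue the whole prompt on sequence 0; request logits only for the last token
    for (var i = 0; i < prompt_tokens.Length; i++)
        batch.Add(prompt_tokens[i], i, LLamaSeqId.Zero, i == prompt_tokens.Length - 1);

    // DecodeAsync wraps NativeHandle.Decode in Task.Run (see LLamaContext.cs below)
    if (await context.DecodeAsync(batch) != 0)
        throw new InvalidOperationException("llama_decode failed");

    // Between generation steps the same batch (and its grown buffers) is reused
    batch.Clear();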
2 changes: 1 addition & 1 deletion LLama.Unittest/BeamTests.cs
@@ -40,7 +40,7 @@ public void BasicBeam()

var initial_tokens = context.Tokenize(prompt);
result.Append(prompt);
- context.Eval(initial_tokens, 0);
+ context.Eval(initial_tokens.AsSpan(), 0);

NativeApi.llama_beam_search(context.NativeHandle, (data, state) =>
{
2 changes: 1 addition & 1 deletion LLama.Unittest/StatelessExecutorTest.cs
@@ -36,7 +36,7 @@ public async Task Stateless()

var executor = new StatelessExecutor(_weights, _params);

- const string question = "Question. what is a cat?\nAnswer: ";
+ const string question = "Question. what is a cat?\nAnswer:";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." }, SamplingPipeline = pipeline };

var timer = new Stopwatch();
49 changes: 26 additions & 23 deletions LLama/LLamaContext.cs
@@ -8,10 +8,12 @@
using System.IO.MemoryMappedFiles;
using LLama.Common;
using System.Runtime.InteropServices;
+ using System.Threading.Tasks;
using LLama.Extensions;
using LLama.Abstractions;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
+ using System.Threading;

namespace LLama
{
@@ -344,16 +346,30 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict

#region eval overloads
/// <summary>
- ///
/// </summary>
- /// <param name="tokens"></param>
- /// <param name="pastTokensCount"></param>
- /// <returns>The updated `pastTokensCount`.</returns>
- /// <exception cref="RuntimeError"></exception>
- [Obsolete("use llama_decode() instead")]
- public int Eval(LLamaToken[] tokens, int pastTokensCount)
+ /// <param name="batch"></param>
+ /// <returns>Positive return values do not mean a fatal error, but rather a warning:<br />
+ ///  - 0: success<br />
+ ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+ ///  - &lt; 0: error<br />
+ /// </returns>
+ public int Decode(LLamaBatch batch)
+ {
+     return NativeHandle.Decode(batch);
+ }
+
+ /// <summary>
+ /// </summary>
+ /// <param name="batch"></param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>Positive return values do not mean a fatal error, but rather a warning:<br />
+ ///  - 0: success<br />
+ ///  - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
+ ///  - &lt; 0: error<br />
+ /// </returns>
+ public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
-     return Eval(tokens.AsSpan(), pastTokensCount);
+     return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
}

/// <summary>
@@ -363,7 +379,7 @@ public int Eval(LLamaToken[] tokens, int pastTokensCount)
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
[Obsolete("use Decode() instead")]
public int Eval(List<LLamaToken> tokens, int pastTokensCount)
{
#if NET5_0_OR_GREATER
@@ -394,20 +410,7 @@ public int Eval(List<LLamaToken> tokens, int pastTokensCount)
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
public int Eval(ReadOnlyMemory<LLamaToken> tokens, int pastTokensCount)
{
return Eval(tokens.Span, pastTokensCount);
}

/// <summary>
///
/// </summary>
/// <param name="tokens"></param>
/// <param name="pastTokensCount"></param>
/// <returns>The updated `pastTokensCount`.</returns>
/// <exception cref="RuntimeError"></exception>
[Obsolete("use llama_decode() instead")]
[Obsolete("use Decode() instead")]
public int Eval(ReadOnlySpan<LLamaToken> tokens, int pastTokensCount)
{
var total = tokens.Length;
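The documented status codes make defensive handling straightforward. A sketch of calling the new wrappers (the recovery strategy shown is illustrative, not part of this PR):

    var status = await context.DecodeAsync(batch, cancellationToken);
    if (status == 1)
    {
        // Warning, not fatal: no KV slot was found for this batch.
        // Retry with a smaller batch, or create the context with a larger n_ctx.
    }
    else if (status < 0)
    {
        // Negative values are real errors
        throw new InvalidOperationException($"llama_decode failed: {status}");
    }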
2 changes: 1 addition & 1 deletion LLama/LLamaEmbedder.cs
@@ -75,7 +75,7 @@ public float[] GetEmbeddings(string text, bool addBos)
// TODO(Rinne): deal with log of prompt

if (embed_inp_array.Length > 0)
-     Context.Eval(embed_inp_array, 0);
+     Context.Eval(embed_inp_array.AsSpan(), 0);

[GitHub Actions annotation on line 78: 'LLamaContext.Eval(ReadOnlySpan<LLamaToken>, int)' is obsolete: 'use Decode() instead' — reported by Test (linux-release) and Test (windows-release)]
var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle);
if (embeddings == null)
69 changes: 55 additions & 14 deletions LLama/Native/LLamaBatch.cs
@@ -1,4 +1,5 @@
using System;
+ using System.Collections.Generic;

namespace LLama.Native;

@@ -7,27 +8,42 @@ namespace LLama.Native;
/// </summary>
public class LLamaBatch
{
- private readonly byte[] _logits;
+ private byte[] _logits;

- private readonly LLamaToken[] _tokens;
- private readonly LLamaPos[] _positions;
+ private LLamaToken[] _tokens;
+ private LLamaPos[] _positions;

- private readonly int[] _sequenceIdCount;
- private readonly LLamaSeqId[][] _sequenceIds;
- private readonly IntPtr[] _sequenceIdsPtrs;
+ private int[] _sequenceIdCount;
+ private LLamaSeqId[][] _sequenceIds;
+ private IntPtr[] _sequenceIdsPtrs;

/// <summary>
/// The number of tokens in this batch
/// </summary>
public int TokenCount { get; private set; }

+ /// <summary>
+ /// Maximum number of tokens that can be added to this batch
+ /// </summary>
+ private int TokenCapacity { get; set; }
+
+ /// <summary>
+ /// Maximum number of sequences a token can be assigned to
+ /// </summary>
+ public int MaxSequences { get; private set; }

/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
/// <param name="n_tokens"></param>
/// <param name="n_seq_max"></param>
public LLamaBatch(int n_tokens, int n_seq_max)
/// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
public LLamaBatch(int n_seq_max)
{
+     // The number of tokens can be grown later, start off with a reasonable guess.
+     const int n_tokens = 64;
+
+     MaxSequences = n_seq_max;
+     TokenCapacity = n_tokens;

_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
@@ -37,7 +53,29 @@ public LLamaBatch(int n_tokens, int n_seq_max)

_sequenceIds = new LLamaSeqId[n_tokens][];
for (var i = 0; i < _sequenceIds.Length; i++)
-     _sequenceIds[i] = new LLamaSeqId[n_seq_max];
+     _sequenceIds[i] = new LLamaSeqId[MaxSequences];
}

+ private void Grow()
+ {
+     var n_tokens = TokenCount * 2;
+     TokenCapacity = n_tokens;
+
+     Array.Resize(ref _logits, n_tokens);
+     Array.Resize(ref _tokens, n_tokens);
+     Array.Resize(ref _positions, n_tokens);
+
+     Array.Resize(ref _sequenceIdCount, n_tokens);
+     Array.Resize(ref _sequenceIdsPtrs, n_tokens);
+
+     Array.Resize(ref _sequenceIds, n_tokens);
+     for (int i = 0; i < _sequenceIds.Length; i++)
+     {
+         // Growing the array filled elements with null, temporarily violating the nullability contract!
+         // ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
+         if (_sequenceIds[i] == null)
+             _sequenceIds[i] = new LLamaSeqId[MaxSequences];
+     }
+ }

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
@@ -79,8 +117,11 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
/// <param name="pos">The position to add it att</param>
/// <param name="sequences">The set of sequences to add this token to</param>
/// <param name="logits"></param>
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
+     if (TokenCount == TokenCapacity)
+         Grow();

_tokens[TokenCount] = token;
_positions[TokenCount] = pos;

@@ -101,20 +142,20 @@ public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqI
/// <param name="pos">The position to add it att</param>
/// <param name="sequence">The sequence to add this token to</param>
/// <param name="logits"></param>
- public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
+ public void Add(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
{
// Create a temporary span to contain 1 item without allocating
Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
sequences[0] = sequence;

// Add it
-     LLamaBatchAdd(token, pos, sequences, logits);
+     Add(token, pos, sequences, logits);
}

/// <summary>
/// Set TokenCount to zero for this batch
/// </summary>
- public void LLamaBatchClear()
+ public void Clear()
{
TokenCount = 0;
}
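
Design note: Grow() is only called from Add() once TokenCount has reached TokenCapacity, so `TokenCount * 2` doubles the capacity each time, keeping repeated Add() calls amortized O(1); Array.Resize pads the new _sequenceIds slots with null, which the loop then backfills. A sketch of the behavior callers can now rely on (the loop bound and default token are illustrative only):

    var batch = new LLamaBatch(1);
    LLamaToken token = default; // placeholder token for illustration

    // Exceeding the initial 64-token capacity no longer overruns the buffers:
    // Add() grows them 64 -> 128 -> 256 -> ... as needed.
    for (var i = 0; i < 1000; i++)
        batch.Add(token, i, LLamaSeqId.Zero, false);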