Merge pull request #761 from martindevans/batched_executor_queueing
Batch Queueing
martindevans authored May 29, 2024
2 parents 22342c6 + 2cc7c2b commit be2c4fe
Showing 3 changed files with 126 additions and 60 deletions.
101 changes: 80 additions & 21 deletions LLama/Batched/BatchedExecutor.cs
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
using LLama.Abstractions;
@@ -13,13 +15,15 @@ public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;
private readonly List<LLamaBatch> _batchQueue = [ ];

private LLamaBatch _promptingBatch = new();
private LLamaBatch _nextBatch = new();
internal LLamaBatch Batch => _promptingBatch;
/// <summary>
/// Held while inference is running
/// </summary>
private readonly object _inferenceLock = new();

/// <summary>
/// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
/// Epoch is incremented twice every time Infer is called. Conversations can use this to keep track of
/// whether they're waiting for inference, or can be sampled.
/// </summary>
internal ulong Epoch { get; private set; }
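The queueing change turns the epoch into a two-ticks-per-inference counter: it is even while the executor is idle, odd while a decode is in flight, and a conversation whose tokens landed in the i-th queued batch becomes sampleable once the epoch has advanced by 2 * (i + 1). A minimal standalone sketch of that bookkeeping (hypothetical type and member names, not library code):

```csharp
// Standalone model of the epoch convention (hypothetical names, not the
// library's types): even = idle, odd = inference in flight.
class EpochSketch
{
    public ulong Epoch { get; private set; }

    // Epoch a conversation must wait for when its tokens go into the batch at
    // the given queue index (each completed inference advances the epoch by 2).
    public ulong RequiredEpochFor(int queueIndex) => Epoch + (ulong)(queueIndex + 1) * 2;

    // Start of an inference run; skipped on a retry, when the epoch is already odd.
    public void BeginInfer()
    {
        if ((Epoch & 1) == 0)
            Epoch++;
    }

    // Successful end of an inference run.
    public void EndInfer() => Epoch++;
}
```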
@@ -33,11 +37,11 @@ public sealed class BatchedExecutor
/// The <see cref="LLamaWeights"/> this executor is using
/// </summary>
public LLamaWeights Model { get; }

/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => Batch.TokenCount;
public int BatchedTokenCount => _batchQueue.Sum(a => a.TokenCount);

/// <summary>
/// Check if this executor has been disposed.
@@ -112,26 +116,53 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

// Swap over batches. This means the next batch can be filled with
// tokens while inference is still running for the previous one.
var batch = _promptingBatch;
(_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);

var status = await Context.DecodeAsync(batch, cancellation);
// If there's no work to do then all available work has already been completed; exit immediately.
var next = GetNextBatch();
if (next == null)
return DecodeResult.Ok;

// If there was an error swap the previous batch back into place. This allows infer to be called again
// after the issue has been fixed (e.g. some KV cache space has been freed) to "retry" this operation.
if (status != DecodeResult.Ok)
// Take the inference lock; if this fails it's because inference is already running.
if (!Monitor.TryEnter(_inferenceLock))
throw new InvalidOperationException("Cannot start inference while it is already running");
try
{
(_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);
// Advance epoch by one. This ensures that _nothing_ can be sampled while inference is running.
// Only do this if the epoch is even. If it's odd that means it was previously advanced by another
// inference run, and this run is a retry.
if ((Epoch & 1) == 0)
Epoch++;

// Run the actual inference. This is the slow bit!
var status = await Context.DecodeAsync(next, cancellation);

// If there was an error then early exit without incrementing the epoch. This allows infer to be called
// again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
if (status != DecodeResult.Ok)
{
_batchQueue.Insert(0, next);
return status;
}

// Everything was ok, advance the epoch and clear the batch we just ran inference for.
Epoch++;
next.Clear();

return status;
}
finally
{
Monitor.Exit(_inferenceLock);
}

// Everything was ok, advance the epoch and clear the batch we just ran inference for.
Epoch++;
batch.Clear();

return status;
LLamaBatch? GetNextBatch()
{
if (_batchQueue.Count == 0)
return null;

var nextBatch = _batchQueue[0];
_batchQueue.RemoveAt(0);
return nextBatch;
}
}
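Because Infer now decodes only the batch at the head of the queue, callers that queued more work than fits in one batch need to call it repeatedly. A hedged usage sketch (the DrainAsync helper and the using directives are assumptions made for illustration; BatchedExecutor, Infer, BatchedTokenCount and DecodeResult are the types and members shown in this diff):

```csharp
using System.Threading.Tasks;
using LLama.Batched;
using LLama.Native;

static class BatchedExecutorQueueExample
{
    // Decode every queued batch. A non-Ok result leaves the failed batch at
    // the front of the queue, so after freeing some KV cache space the caller
    // can simply call this (or Infer) again to retry the same batch.
    public static async Task<DecodeResult> DrainAsync(BatchedExecutor executor)
    {
        while (executor.BatchedTokenCount > 0)
        {
            var result = await executor.Infer();
            if (result != DecodeResult.Ok)
                return result;
        }

        return DecodeResult.Ok;
    }
}
```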

/// <inheritdoc />
@@ -148,4 +179,32 @@ internal LLamaSeqId GetNextSequenceId()
{
return checked((LLamaSeqId)_nextSequenceId++);
}

/// <summary>
/// Get a reference to a batch that tokens can be added to.
/// </summary>
/// <param name="minCapacity">Minimum number of free token slots the returned batch must have</param>
/// <returns>A batch with space for at least <paramref name="minCapacity"/> tokens, and the epoch at which that batch will have been decoded</returns>
/// <exception cref="ArgumentOutOfRangeException">Thrown if <paramref name="minCapacity"/> is greater than the context batch size</exception>
internal (LLamaBatch batch, ulong epoch) GetTokenBatch(int minCapacity = 1)
{
if (minCapacity > Context.BatchSize)
throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Requested batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

// Find a batch with space for at least minCapacity tokens
for (var i = 0; i < _batchQueue.Count; i++)
{
var capacity = Context.BatchSize - _batchQueue[i].TokenCount;
if (capacity < minCapacity)
continue;

if (_batchQueue[i].TokenCount < Context.BatchSize)
return (_batchQueue[i], Epoch + (uint)(i + 1) * 2);
}

// Add a new batch to the end of the queue
var end = new LLamaBatch();
_batchQueue.Add(end);
return (end, Epoch + (uint)_batchQueue.Count * 2);
}
}
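GetTokenBatch does a first-fit scan over the queue and reports the epoch at which the chosen batch will have been decoded: queue position i needs i + 1 more Infer calls, each worth two epoch ticks. A standalone sketch of that placement logic, written against plain lists rather than the library types (hypothetical helper, for illustration only):

```csharp
using System;
using System.Collections.Generic;

static class BatchPlacementSketch
{
    // First-fit placement over a queue of batches, tracked here only by their
    // token counts. Returns the queue index used and the epoch the caller must
    // wait for before sampling (mirrors the scan in GetTokenBatch above).
    public static (int queueIndex, ulong requiredEpoch) Place(
        List<int> queuedTokenCounts, int batchSize, int minCapacity, ulong currentEpoch)
    {
        if (minCapacity > batchSize)
            throw new ArgumentOutOfRangeException(nameof(minCapacity));

        for (var i = 0; i < queuedTokenCounts.Count; i++)
        {
            if (batchSize - queuedTokenCounts[i] >= minCapacity)
                return (i, currentEpoch + (ulong)(i + 1) * 2);
        }

        // No existing batch has room: append a new one to the end of the queue.
        queuedTokenCounts.Add(0);
        return (queuedTokenCounts.Count - 1, currentEpoch + (ulong)queuedTokenCounts.Count * 2);
    }
}
```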
38 changes: 23 additions & 15 deletions LLama/Batched/Conversation.cs
@@ -232,22 +232,33 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)

_batchSampleCount = tokens.Length;

// We need to add all tokens to a single batch, so they can all be sampled at once.
// Request a batch with sufficient space.
(var batch, _requiredEpoch) = Executor.GetTokenBatch(tokens.Length);

// Add everything to that batch
for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
_batchSampleIndices[i] = batch.Add(tokens[i], _end++, ConversationId, true);
}
else
{
_batchSampleCount = 1;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

while (tokens.Length > 0)
{
// Get a batch with capacity for at least 1 token
(var batch, _requiredEpoch) = Executor.GetTokenBatch();

// Add as many tokens as possible
var count = Math.Min(tokens.Length, checked((int)Executor.Context.BatchSize) - batch.TokenCount);
for (var i = 0; i < count; i++)
_batchSampleIndices[0] = batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

// Slice the array to remove tokens we've already added to a batch
tokens = tokens.Slice(count);
}
}



// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;

// Unset the forked flag. Since this conversation has just been prompted it's no longer
// sharing anything with any other conversations.
_forked = false;
Expand All @@ -263,12 +274,9 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
public void Prompt(LLamaToken token)
{
AssertCanBePrompted();

unsafe
{
Span<LLamaToken> span = stackalloc LLamaToken[1] { token };
Prompt(span);
}

Span<LLamaToken> span = [ token ];
Prompt(span);
}
#endregion

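On the Conversation side, a prompt that does not fit in the current batch is now spread across as many queued batches as needed (when per-token logits are not requested). A hedged sketch of how that looks from calling code (the helper and using directives are assumptions; Prompt, Infer and BatchedTokenCount are from this diff):

```csharp
using System;
using System.Threading.Tasks;
using LLama.Batched;
using LLama.Native;

static class LongPromptExample
{
    // Prompt a conversation with an arbitrarily long span of tokens, then run
    // inference until every queued batch has been decoded. Roughly
    // ceil(tokens.Length / BatchSize) Infer calls are needed before sampling.
    public static async Task PromptAndDecodeAsync(
        BatchedExecutor executor, Conversation conversation, LLamaToken[] tokens)
    {
        conversation.Prompt(tokens);

        while (executor.BatchedTokenCount > 0)
        {
            var result = await executor.Infer();
            if (result != DecodeResult.Ok)
                throw new InvalidOperationException($"Decode failed: {result}");
        }
    }
}
```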
47 changes: 23 additions & 24 deletions LLama/Native/LLamaBatch.cs
@@ -1,4 +1,4 @@
using System;
using System;
using System.Collections.Generic;
using System.Runtime.InteropServices;

@@ -54,51 +54,50 @@ public class LLamaBatch
public LLamaBatch()
{
// These can both be grown later, start off with reasonable numbers.
const int n_tokens = 128;
const int n_seq_max = 1;
const int tokensCapacity = 128;
const int seqCapacity = 1;

SequenceCapacity = n_seq_max;
TokenCapacity = n_tokens;
SequenceCapacity = seqCapacity;
TokenCapacity = tokensCapacity;

_logits = new byte[n_tokens];
_tokens = new LLamaToken[n_tokens];
_positions = new LLamaPos[n_tokens];
_logits = new byte[tokensCapacity];
_tokens = new LLamaToken[tokensCapacity];
_positions = new LLamaPos[tokensCapacity];

_sequenceIdCount = new int[n_tokens];
_sequenceIdCount = new int[tokensCapacity];
_sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

_sequenceIds = new LLamaSeqId[n_tokens][];
_sequenceIds = new LLamaSeqId[tokensCapacity][];
for (var i = 0; i < _sequenceIds.Length; i++)
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
}

#region grow
private void GrowTokenCapacity()
{
var n_tokens = TokenCount * 2;
TokenCapacity = n_tokens;
var tokenCapacity = TokenCount * 2;
TokenCapacity = tokenCapacity;

Array.Resize(ref _logits, n_tokens);
Array.Resize(ref _tokens, n_tokens);
Array.Resize(ref _positions, n_tokens);
Array.Resize(ref _logits, tokenCapacity);
Array.Resize(ref _tokens, tokenCapacity);
Array.Resize(ref _positions, tokenCapacity);

Array.Resize(ref _sequenceIdCount, n_tokens);
Array.Resize(ref _sequenceIdsPtrs, n_tokens);
Array.Resize(ref _sequenceIdCount, tokenCapacity);
Array.Resize(ref _sequenceIdsPtrs, tokenCapacity);

Array.Resize(ref _sequenceIds, n_tokens);
for (int i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds, tokenCapacity);
for (var i = 0; i < _sequenceIds.Length; i++)
{
// Growing the array filled elements with null, temporarily violating the nullability contract!
// ReSharper disable once ConditionIsAlwaysTrueOrFalseAccordingToNullableAPIContract
if (_sequenceIds[i] == null)
_sequenceIds[i] = new LLamaSeqId[SequenceCapacity];
// ReSharper disable once NullCoalescingConditionIsAlwaysNotNullAccordingToAPIContract
_sequenceIds[i] ??= new LLamaSeqId[SequenceCapacity];
}
}

private void GrowMaxSequences(int atLeast)
{
var n_seq = Math.Max(SequenceCapacity * 2, atLeast);
SequenceCapacity = n_seq;
var seqCapacity = Math.Max(SequenceCapacity * 2, atLeast);
SequenceCapacity = seqCapacity;

for (var i = 0; i < _sequenceIds.Length; i++)
Array.Resize(ref _sequenceIds[i], SequenceCapacity);
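The LLamaBatch changes are mostly renames, but the growth strategy they touch is worth spelling out: all of the per-token parallel arrays are resized together, doubling from the current token count so repeated Add calls stay amortized O(1). A small standalone sketch of that pattern (hypothetical names, not the library's fields):

```csharp
using System;

class ParallelArraysSketch
{
    // Parallel per-item arrays that must always share the same capacity,
    // mirroring how LLamaBatch keeps _tokens, _positions, _logits etc. in step.
    private int[] _values = new int[128];
    private byte[] _flags = new byte[128];
    private int _count;

    public void Add(int value, bool flag)
    {
        if (_count == _values.Length)
        {
            // Double from the current count so growth cost stays amortized O(1).
            var newCapacity = _count * 2;
            Array.Resize(ref _values, newCapacity);
            Array.Resize(ref _flags, newCapacity);
        }

        _values[_count] = value;
        _flags[_count] = flag ? (byte)1 : (byte)0;
        _count++;
    }
}
```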
