Merge pull request #748 from martindevans/batched_executor_double_buffer
BatchedExecutor Double Buffering
martindevans authored May 23, 2024
2 parents 87e07a4 + 71be0ac commit ad6f22c
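
This change replaces the executor's single `LLamaBatch` with two batches that are swapped at the start of every `Infer` call: while one batch is being decoded, conversations can already queue tokens into the other. If decoding fails (e.g. `NoKvSlot`), the batches are swapped back so the same `Infer` call can simply be retried once the problem is resolved.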
Showing 1 changed file with 23 additions and 13 deletions.
36 changes: 23 additions & 13 deletions LLama/Batched/BatchedExecutor.cs
@@ -1,4 +1,4 @@
-using System;
+using System;
 using System.Threading;
 using System.Threading.Tasks;
 using LLama.Abstractions;
@@ -13,8 +13,10 @@ public sealed class BatchedExecutor
     : IDisposable
 {
     private int _nextSequenceId;
 
-    internal LLamaBatch Batch { get; }
+    private LLamaBatch _promptingBatch = new();
+    private LLamaBatch _nextBatch = new();
+    internal LLamaBatch Batch => _promptingBatch;
 
     /// <summary>
     /// Epoch is incremented every time Infer is called. Conversations can use this to keep track of
@@ -50,7 +52,6 @@ public sealed class BatchedExecutor
     public BatchedExecutor(LLamaWeights model, IContextParams contextParams)
     {
         Model = model;
-        Batch = new LLamaBatch();
         Context = model.CreateContext(contextParams);
         Epoch = 1;
     }
@@ -110,17 +111,26 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
     {
         if (IsDisposed)
             throw new ObjectDisposedException(nameof(BatchedExecutor));
 
-        var status = await Context.DecodeAsync(Batch, cancellation);
-
-        // Only clear the batch if the result was ok. leaving all this state in place means that "Infer" can
-        // be called again after a warning (e.g. NoKvSlot).
-        if (status == DecodeResult.Ok)
+        // Swap over batches. This means the next batch can be filled with
+        // tokens while inference is still running for the previous one.
+        var batch = _promptingBatch;
+        (_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);
+
+        var status = await Context.DecodeAsync(batch, cancellation);
+
+        // If there was an error swap the previous batch back into place. This allows infer to be called again
+        // after the issue has been fixed (e.g. some KV cache space has been freed) to "retry" this operation.
+        if (status != DecodeResult.Ok)
         {
-            Epoch++;
-            Batch.Clear();
+            (_promptingBatch, _nextBatch) = (_nextBatch, _promptingBatch);
+            return status;
         }
 
+        // Everything was ok, advance the epoch and clear the batch we just ran inference for.
+        Epoch++;
+        batch.Clear();
+
         return status;
     }
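
The pattern is easiest to see in isolation. Below is a minimal, self-contained sketch of the same double-buffering idea, deliberately independent of LLamaSharp's types (the `DoubleBufferedQueue<T>` class and its members are illustrative names, not part of this commit): one buffer accepts new items while the other is being processed, and a failed run swaps the full buffer back into place so it can be retried.

using System;
using System.Collections.Generic;
using System.Threading.Tasks;

// Illustrative sketch only: not part of this commit, and not thread-safe
// (the real executor is likewise driven from a single thread).
public sealed class DoubleBufferedQueue<T>
{
    private List<T> _front = new();  // filled by callers (the "prompting" buffer)
    private List<T> _back = new();   // handed to the processor

    public void Add(T item) => _front.Add(item);

    public async Task<bool> ProcessAsync(Func<IReadOnlyList<T>, Task<bool>> process)
    {
        // Swap first: callers can keep adding to the (now empty) front buffer
        // while the previous contents are being processed.
        var batch = _front;
        (_front, _back) = (_back, _front);

        if (!await process(batch))
        {
            // Failure: swap back so the same items are retried next time.
            (_front, _back) = (_back, _front);
            return false;
        }

        // Success: clear the processed buffer so it is empty when swapped in again.
        batch.Clear();
        return true;
    }
}

In `BatchedExecutor` the same shape appears with `_promptingBatch`/`_nextBatch`: conversations keep writing through `Batch` (always the prompting buffer) while `DecodeAsync` runs against the batch that was just swapped out, and a non-`Ok` `DecodeResult` restores the old batch so `Infer` can be called again to retry it.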

