
Commit

Added support for embeddings prompting to BatchedExecutor, allowing support for Llava!

 - Switched batch queue in `BatchedExecutor` to have 2 possible types - token batches and embeddings batches.
 - Switched the inference lock to an integer manipulated with Interlocked, which interacts better with async code.
 - Added `Conversation.Prompt` method for `Span<float>` (raw embeddings) and `SafeLlavaImageEmbedHandle`.
 - Added `LLamaBatchEmbeddings`, equivalent to `LLamaBatch` for embeddings instead of for tokens.
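
For context, here is a minimal sketch of how the new embedding-prompting API fits together, condensed from the `BatchedExecutorLLava` example added in this commit. The model, mmproj and image paths are placeholders; the sampling and decoding setup uses existing LLamaSharp types exactly as the example does.

```csharp
using System;
using System.IO;
using System.Text;
using LLama;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;

// Load the language model and the LLava (CLIP) projection weights. Paths are placeholders.
var parameters = new ModelParams("path/to/model.gguf");
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var llava = await LLavaWeights.LoadFromFileAsync("path/to/mmproj.gguf");

// Create a batched executor with a single conversation.
using var executor = new BatchedExecutor(model, parameters);
using var conversation = executor.Create();

// Embed the image with CLIP and prompt the conversation with the resulting image embedding.
using var image = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync("path/to/image.jpg"));
conversation.Prompt(image);

// Run inference until all queued embedding batches have been processed.
while (executor.BatchedTokenCount > 0)
    await executor.Infer();

// Follow up with a normal text prompt in the same conversation.
var prompt = model.Tokenize("\nUSER: Provide a full description of the image.\nASSISTANT: ", true, false, Encoding.UTF8);
conversation.Prompt(prompt);

// Generate a short response.
var decoder = new StreamingTokenDecoder(executor.Context);
var sampler = new DefaultSamplingPipeline();
for (var i = 0; i < 64; i++)
{
    await executor.Infer();

    var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
    if (executor.Context.NativeHandle.ModelHandle.Tokens.IsEndOfGeneration(token))
        break;

    decoder.Add(token);
    conversation.Prompt(token);
}

Console.WriteLine(decoder.Read());
```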
martindevans committed May 31, 2024
1 parent be2c4fe commit 7a3a1cd
Showing 8 changed files with 580 additions and 17 deletions.
3 changes: 2 additions & 1 deletion LLama.Examples/ExampleRunner.cs
@@ -1,4 +1,4 @@
using Spectre.Console;
using Spectre.Console;
using LLama.Examples.Examples;

public class ExampleRunner
@@ -31,6 +31,7 @@ public class ExampleRunner
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
{ "Batched Executor: LLava", BatchedExecutorLLava.Run },
{ "Speech Chat: Integration with Whisper.net", SpeechChat.Run },
{ "Exit", () => { Environment.Exit(0); return Task.CompletedTask; } }
};
91 changes: 91 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorLLava.cs
@@ -0,0 +1,91 @@
using System.Text;
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// Demonstrates using LLava (image embeddings) with the batched executor.
/// </summary>
public class BatchedExecutorLLava
{
/// <summary>
/// How many tokens of response to generate
/// </summary>
public const int TokenCount = 64;

public static async Task Run()
{
// Load model weights
var parameters = new ModelParams(UserSettings.GetModelPath());
using var model = await LLamaWeights.LoadFromFileAsync(parameters);
using var llava = await LLavaWeights.LoadFromFileAsync(UserSettings.GetMMProjPath());

// Decide on the prompt
var prompt = model.Tokenize(AnsiConsole.Ask("Prompt (or ENTER for default):", "\nUSER: Provide a full description of the image.\nASSISTANT: "), true, false, Encoding.UTF8);

// Get image and show it
var image = UserSettings.GetImagePath();
AnsiConsole.Write(new CanvasImage(image));

// Create an executor with one conversation
using var executor = new BatchedExecutor(model, parameters);
using var conversation = executor.Create();

// Embed the image
SafeLlavaImageEmbedHandle embedding = null!;
await AnsiConsole
.Status()
.StartAsync("[yellow]Embedding image with CLIP[/]", async _ =>
{
// ReSharper disable once AccessToDisposedClosure
embedding = llava.CreateImageEmbeddings(await File.ReadAllBytesAsync(image));
});

// Pass in the image and run inference until the entire image has been processed
await AnsiConsole
.Status()
.StartAsync("[yellow]Processing image embedding with language model[/]", async _ =>
{
conversation.Prompt(embedding);
while (executor.BatchedTokenCount > 0)
await executor.Infer();
});

// Prompt with the text prompt
conversation.Prompt(prompt);

// Run inference loop
var decoder = new StreamingTokenDecoder(executor.Context);
var sampler = new DefaultSamplingPipeline();
await AnsiConsole
.Progress()
.StartAsync(async ctx =>
{
var task = ctx.AddTask("Generating Response");
task.MaxValue = TokenCount;

// Run a normal inference loop
for (var i = 0; i < TokenCount; i++)
{
task.Increment(1);

await executor.Infer();

var token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), Array.Empty<LLamaToken>());
if (executor.Context.NativeHandle.ModelHandle.Tokens.IsEndOfGeneration(token))
break;

decoder.Add(token);
conversation.Prompt(token);
}
});

// Print final result
var str = decoder.Read();
AnsiConsole.MarkupInterpolated($"[green]{str}[/]");
}
}
111 changes: 95 additions & 16 deletions LLama/Batched/BatchedExecutor.cs
@@ -1,5 +1,6 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
using System.Linq;
using System.Threading;
using System.Threading.Tasks;
@@ -15,12 +16,12 @@ public sealed class BatchedExecutor
: IDisposable
{
private int _nextSequenceId;
private readonly List<LLamaBatch> _batchQueue = [ ];
private readonly List<IBatch> _batchQueue = [ ];

/// <summary>
/// Held while inference is running
/// Set to 1 using interlocked exchange while inference is running
/// </summary>
private readonly object _inferenceLock = new();
private int _inferenceLock = 0;

/// <summary>
/// Epoch is incremented twice every time Infer is called. Conversations can use this to keep track of
@@ -41,7 +42,12 @@ public sealed class BatchedExecutor
/// <summary>
/// Get the number of tokens in the batch, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchedTokenCount => _batchQueue.Sum(a => a.TokenCount);
public int BatchedTokenCount => _batchQueue.Sum(a => a.ItemCount);

/// <summary>
/// Number of batches in the queue, waiting for <see cref="Infer"/> to be called
/// </summary>
public int BatchQueueCount => _batchQueue.Count;

/// <summary>
/// Check if this executor has been disposed.
@@ -120,9 +126,11 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
var next = GetNextBatch();
if (next == null)
return DecodeResult.Ok;

// Take the inference lock, if this fails it's because inference is already running.
if (!Monitor.TryEnter(_inferenceLock))

// This acts as a "lock" on inference, ensuring two inferences cannot run at once. First set the "_inferenceLock" field
// to the "key" value iff it is currently 0. If it is not currently 0 this will throw an exception.
var key = (int)(DateTime.UtcNow.Ticks & 0xFFFF_FFFF);
if (Interlocked.CompareExchange(ref _inferenceLock, key, 0) != 0)
throw new InvalidOperationException("Cannot start inference while it is already running");
try
{
@@ -133,7 +141,7 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
Epoch++;

// Run the actual inference. This is the slow bit!
var status = await Context.DecodeAsync(next, cancellation);
var status = await next.DecodeAsync(Context, cancellation);

// If there was an error then early exit without incrementing the epoch. This allows infer to be called
// again after the issue has been fixed (e.g. some KV cache space has been freed) to retry this operation.
@@ -143,18 +151,20 @@ public async Task<DecodeResult> Infer(CancellationToken cancellation = default)
return status;
}

// Everything was ok, advance the epoch and clear the batch we just ran inference for.
// Everything was ok, advance the epoch
Epoch++;
next.Clear();

return status;
}
finally
{
Monitor.Exit(_inferenceLock);
// Set "_inferenceLock" field back to zero iff it is currently the "key" value we set earlier. It should be
// impossible for this to ever fail!
var old = Interlocked.CompareExchange(ref _inferenceLock, 0, key);
Debug.Assert(old == key);
}

LLamaBatch? GetNextBatch()
IBatch? GetNextBatch()
{
if (_batchQueue.Count == 0)
return null;
@@ -194,17 +204,86 @@ internal LLamaSeqId GetNextSequenceId()
// Find a batch with space for at least minCapacity tokens
for (var i = 0; i < _batchQueue.Count; i++)
{
var capacity = Context.BatchSize - _batchQueue[i].TokenCount;
var item = _batchQueue[i];
if (item is not TokenBatch { Batch: var batch })
continue;

var capacity = Context.BatchSize - batch.TokenCount;
if (capacity < minCapacity)
continue;

if (_batchQueue[i].TokenCount < Context.BatchSize)
return (_batchQueue[i], Epoch + (uint)(i + 1) * 2);
if (batch.TokenCount < Context.BatchSize)
return (batch, Epoch + (uint)(i + 1) * 2);
}

// Add a new batch to the end of the queue
var end = new LLamaBatch();
_batchQueue.Add(end);
_batchQueue.Add(new TokenBatch(end));
return (end, Epoch + (uint)_batchQueue.Count * 2);
}

/// <summary>
/// Get a reference to a batch that embeddings can be added to.
/// </summary>
/// <param name="minCapacity"></param>
/// <returns></returns>
/// <exception cref="ArgumentOutOfRangeException"></exception>
internal (LLamaBatchEmbeddings batch, ulong epoch) GetEmbeddingBatch(int minCapacity = 1)
{
if (minCapacity > Context.BatchSize)
throw new ArgumentOutOfRangeException(nameof(minCapacity), $"Requested batch capacity must be less than or equal to BatchSize ({Context.BatchSize})");

// Find a batch with space for at least minCapacity embeddings
for (var i = 0; i < _batchQueue.Count; i++)
{
var item = _batchQueue[i];
if (item is not EmbeddingBatch { Batch: var batch })
continue;

var capacity = Context.BatchSize - batch.EmbeddingsCount;
if (capacity < minCapacity)
continue;

if (batch.EmbeddingsCount < Context.BatchSize)
return (batch, Epoch + (uint)(i + 1) * 2);
}

// Add a new batch to the end of the queue
var end = new LLamaBatchEmbeddings(Context.EmbeddingSize);
_batchQueue.Add(new EmbeddingBatch(end));
return (end, Epoch + (uint)_batchQueue.Count * 2);
}

#region batches
private interface IBatch
{
int ItemCount { get; }

Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token);
}

private class TokenBatch(LLamaBatch batch)
: IBatch
{
public readonly LLamaBatch Batch = batch;
public int ItemCount => Batch.TokenCount;

public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
{
return ctx.DecodeAsync(Batch, token);
}
}

private class EmbeddingBatch(LLamaBatchEmbeddings batch)
: IBatch
{
public readonly LLamaBatchEmbeddings Batch = batch;
public int ItemCount => Batch.EmbeddingsCount;

public Task<DecodeResult> DecodeAsync(LLamaContext ctx, CancellationToken token)
{
return ctx.DecodeAsync(Batch, token);
}
}
#endregion
}
67 changes: 67 additions & 0 deletions LLama/Batched/Conversation.cs
@@ -278,6 +278,73 @@ public void Prompt(LLamaToken token)
Span<LLamaToken> span = [ token ];
Prompt(span);
}

/// <summary>
/// Prompt this conversation with an image embedding
/// </summary>
/// <param name="embedding"></param>
public void Prompt(SafeLlavaImageEmbedHandle embedding)
{
AssertCanBePrompted();

if (embedding.Model.EmbeddingDimensions != Executor.Model.EmbeddingSize)
throw new ArgumentException($"Embedding dimension mismatch between image embedding ({embedding.Model.EmbeddingDimensions}) and model ({Executor.Model.EmbeddingSize})");

// Get a temporary array large enough to hold one embedding item
var tempArr = ArrayPool<float>.Shared.Rent(embedding.Model.EmbeddingDimensions);
var tempSpan = tempArr.AsSpan(0, embedding.Model.EmbeddingDimensions);
try
{
for (var i = 0; i < embedding.Model.PatchCount; i++)
{
// Get a batch with space
(var batch, _requiredEpoch) = Executor.GetEmbeddingBatch();

batch.Add(
(i, embedding),
static (Span<float> dest, (int index, SafeLlavaImageEmbedHandle embedding) tup) => tup.embedding.GetEmbedding(dest, tup.index),
_end++,
ConversationId,
i == embedding.Model.PatchCount - 1
);
}
}
finally
{
ArrayPool<float>.Shared.Return(tempArr);
}
}

/// <summary>
/// Prompt this conversation with embeddings
/// </summary>
/// <param name="embeddings">The raw values of the embeddings. This span must divide equally by the embedding size of this model.</param>
public void Prompt(ReadOnlySpan<float> embeddings)
{
AssertCanBePrompted();

var dim = Executor.Model.EmbeddingSize;
var count = embeddings.Length / dim;
if (count * dim != embeddings.Length)
throw new ArgumentException($"Incorrect embeddings span size, length ({embeddings.Length}) must be divisible by embedding dimensions ({Executor.Model.EmbeddingSize})");

while (embeddings.Length > 0)
{
// Get a batch with space
(var batch, _requiredEpoch) = Executor.GetEmbeddingBatch();

// Add 1 embedding to the batch
batch.Add(
embeddings.Slice(0, dim),
_end++,
ConversationId,
embeddings.Length == dim
);

// Advance to next embedding
embeddings = embeddings.Slice(dim);
}
}
#endregion

#region modify
22 changes: 22 additions & 0 deletions LLama/LLamaContext.cs
@@ -558,6 +558,28 @@ public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancel
{
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
/// Decode a batch of embeddings
/// </summary>
/// <param name="batch"></param>
public DecodeResult Decode(LLamaBatchEmbeddings batch)
{
if (batch.EmbeddingsCount == 0)
return 0;
if (batch.EmbeddingsCount > Params.BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
}

/// <summary>
/// Decode a batch of embeddings asynchronously
/// </summary>
/// <param name="batch"></param>
/// <param name="cancellationToken"></param>
public Task<DecodeResult> DecodeAsync(LLamaBatchEmbeddings batch, CancellationToken cancellationToken = default)
{
return Task.Run(() => Decode(batch), cancellationToken);
}
#endregion

/// <inheritdoc />
