BatchedExecutor Save/Load (#681)
* Added the ability to save and load individual conversations in a batched executor.
 - New example
 - Added `BatchedExecutor.Load(filepath)` method
 - Added `Conversation.Save(filepath)` method
 - Added new (currently internal) `SaveState`/`LoadState` methods in LLamaContext which can stash some extra binary data in the header

* Added ability to save/load a `Conversation` to an in-memory state, instead of to file.

* Moved the new save/load methods out to an extension class specifically for the batched executor.

* Removed unnecessary spaces
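
In practice the file-based round trip looks roughly like this (a condensed sketch of the new example below; the file name is illustrative):

    // Save the conversation to disk, dispose it, then restore it into a fresh Conversation
    conversation.Save("demo_conversation.state");
    conversation.Dispose();

    // The loaded conversation must be prompted again before further inference
    conversation = executor.Load("demo_conversation.state");

The in-memory route is equivalent: Conversation.Save() returns a disposable Conversation.State which can be handed back to BatchedExecutor.Load(state).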
martindevans authored Apr 23, 2024
1 parent f01c13e commit ccc49eb
Showing 6 changed files with 438 additions and 12 deletions.
1 change: 1 addition & 0 deletions LLama.Examples/ExampleRunner.cs
@@ -26,6 +26,7 @@ public class ExampleRunner
{ "Semantic Kernel: Prompt", SemanticKernelPrompt.Run },
{ "Semantic Kernel: Chat", SemanticKernelChat.Run },
{ "Semantic Kernel: Store", SemanticKernelMemory.Run },
{ "Batched Executor: Save/Load", BatchedExecutorSaveAndLoad.Run },
{ "Batched Executor: Fork", BatchedExecutorFork.Run },
{ "Batched Executor: Rewind", BatchedExecutorRewind.Run },
{ "Batched Executor: Guidance", BatchedExecutorGuidance.Run },
108 changes: 108 additions & 0 deletions LLama.Examples/Examples/BatchedExecutorSaveAndLoad.cs
@@ -0,0 +1,108 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving and loading the state of a conversation in a batched executor, both to a file and to an in-memory state object
/// </summary>
public class BatchedExecutorSaveAndLoad
{
private const int n_len = 18;

public static async Task Run()
{
string modelPath = UserSettings.GetModelPath();

var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

// Create an executor that can evaluate a batch of conversations together
using var executor = new BatchedExecutor(model, parameters);

// Print some info
var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
Console.WriteLine($"Created executor with model: {name}");

// Create a conversation
var conversation = executor.Create();
conversation.Prompt(prompt);

// Run inference loop
var decoder = new StreamingTokenDecoder(executor.Context);
var sampler = new DefaultSamplingPipeline();
var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Can't save a conversation while RequiresInference is true
if (conversation.RequiresInference)
await executor.Infer();

// Save this conversation to a file and dispose it
conversation.Save("demo_conversation.state");
conversation.Dispose();
AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

// Now create a new conversation by loading that state
conversation = executor.Load("demo_conversation.state");
AnsiConsole.WriteLine("Loaded state");

// Prompt it again with the last token, so we can continue generating
conversation.Rewind(1);
conversation.Prompt(lastToken);

// Continue generating text
lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Can't save a conversation while RequiresInference is true
if (conversation.RequiresInference)
await executor.Infer();

// Save the conversation again, this time into system memory
using (var state = conversation.Save())
{
conversation.Dispose();
AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

// Now create a new conversation by loading that in-memory state
conversation = executor.Load(state);
AnsiConsole.WriteLine("Loaded state");
}

// Prompt it again with the last token, so we can continue generating
conversation.Rewind(1);
conversation.Prompt(lastToken);

// Continue generating text
await GenerateTokens(executor, conversation, sampler, decoder, n_len);

// Display final output
AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
}

private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
{
var token = (LLamaToken)0;

for (var i = 0; i < count; i++)
{
// Run inference
await executor.Infer();

// Use sampling pipeline to pick a token
token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

// Add it to the decoder, so it can be converted into text later
decoder.Add(token);

// Prompt the conversation with the token
conversation.Prompt(token);
}

return token;
}
}
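
A detail worth calling out from the example above: a conversation with queued tokens cannot be saved (Save throws CannotSaveWhileRequiresInferenceException), which is why the code checks RequiresInference and runs Infer() before each save, and after loading it rewinds one position and re-prompts with the last sampled token so generation can continue. A small save-side helper along these lines (hypothetical, not part of this commit, assuming the same usings as the example) would be:

    // Hypothetical helper: flush any queued tokens, then save the conversation to a file.
    private static async Task SaveWhenReadyAsync(BatchedExecutor executor, Conversation conversation, string path)
    {
        // Save() would throw CannotSaveWhileRequiresInferenceException while tokens are pending
        if (conversation.RequiresInference)
            await executor.Infer();

        conversation.Save(path);
    }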
33 changes: 33 additions & 0 deletions LLama/Batched/BatchedExecutor.cs
@@ -84,6 +84,39 @@ public Conversation Create()
return new Conversation(this, GetNextSequenceId());
}

/// <summary>
/// Load a conversation that was previously saved to a file. Once loaded, the conversation will
/// need to be prompted.
/// </summary>
/// <param name="filepath">Path of the file to load the conversation state from</param>
/// <returns>A new conversation containing the loaded state</returns>
/// <exception cref="ObjectDisposedException"></exception>
public Conversation Load(string filepath)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var conversation = Create();
conversation.Load(filepath);
return conversation;
}

/// <summary>
/// Load a conversation that was previously saved into memory. Once loaded, the conversation will need to be prompted.
/// </summary>
/// <param name="state">Previously saved conversation state</param>
/// <returns>A new conversation containing the loaded state</returns>
/// <exception cref="ObjectDisposedException"></exception>
public Conversation Load(Conversation.State state)
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(BatchedExecutor));

var conversation = Create();
conversation.Load(state);
return conversation;
}

/// <summary>
/// Run inference for all conversations in the batch which have pending tokens.
///
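
Both overloads go through Create(), so every restored conversation gets a fresh sequence id and several saved conversations can be restored into the same executor and batched together. A hypothetical sketch (file paths are illustrative; each restored conversation still needs to be prompted before the next Infer() call, as the doc comments note):

    // Hypothetical: rebuild a batch from several previously saved conversation files.
    private static Conversation[] RestoreAll(BatchedExecutor executor, params string[] paths)
    {
        var restored = new Conversation[paths.Length];
        for (var i = 0; i < paths.Length; i++)
            restored[i] = executor.Load(paths[i]);   // each load creates a new conversation/sequence
        return restored;
    }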
173 changes: 169 additions & 4 deletions LLama/Batched/Conversation.cs
@@ -2,6 +2,7 @@
using System.Buffers;
using System.Collections.Generic;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;

namespace LLama.Batched;
@@ -14,7 +15,7 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchIndex;
private int _batchSampleIndex;
private bool _disposed;
private bool _forked;

@@ -107,7 +108,7 @@ public Conversation Fork()
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
_batchIndex = _batchIndex,
_batchSampleIndex = _batchSampleIndex,
_forked = true,

_end = _end,
@@ -140,7 +141,7 @@ public Span<float> Sample()
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();

var span = Executor.Context.NativeHandle.GetLogitsIth(_batchIndex);
var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);

// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
@@ -220,7 +221,7 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)

// Add the prompt to the batch
for (var i = 0; i < tokens.Length; i++)
_batchIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
_batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);

// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
@@ -350,4 +351,168 @@ public void Divide(LLamaPos start, LLamaPos end, int divisor)
/// <returns>The new end token position</returns>
public delegate LLamaPos ModifyKvCache(LLamaPos end, KvAccessor kv);
#endregion

#region save/load
private void AssertCanLoad()
{
AssertNotDisposed();
if (_end.Value > 0)
throw new InvalidOperationException("Cannot load into a non-empty conversation");
}

private void AssertCanSave()
{
AssertNotDisposed();
if (RequiresInference)
throw new CannotSaveWhileRequiresInferenceException();
}


/// <summary>
/// Save the complete state of this conversation to a file. If the file already exists, it will be overwritten.
/// </summary>
/// <param name="filepath">Path of the file to save the state to</param>
/// <exception cref="CannotSaveWhileRequiresInferenceException"></exception>
public void Save(string filepath)
{
AssertCanSave();

// Prepare extra state to put into file header
var state = GetState();
var bytes = JsonSerializer.SerializeToUtf8Bytes(state);

// Save extra state along with the KV cache
Executor.Context.SaveState(filepath, ConversationId, bytes);
}

/// <summary>
/// Save the complete state of this conversation in system memory.
/// </summary>
/// <returns>An in-memory object holding the saved state of this conversation</returns>
public State Save()
{
AssertCanSave();

return new PrivateState(
Executor.Context.GetState(ConversationId),
GetState()
);
}


/// <summary>
/// Load state from a file.
/// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
/// </summary>
/// <param name="filepath">Path of the file to load the state from</param>
/// <exception cref="InvalidOperationException"></exception>
internal void Load(string filepath)
{
AssertCanLoad();

// Load the state from file into the KV cache
Executor.Context.LoadState(filepath, ConversationId, out var header);

// deserialize the extra state in the file header
var state = JsonSerializer.Deserialize<SerializableConversationState>(header);
if (state == null)
{
Dispose();
throw new InvalidOperationException("Failed to deserialize - deserialized header state was null");
}

Load(state);
}

/// <summary>
/// Load state that was previously saved into memory.
/// This should only ever be called by the BatchedExecutor, on a newly created conversation object!
/// </summary>
/// <param name="state">Previously saved in-memory state to load</param>
internal void Load(State state)
{
AssertCanLoad();

// There is only one class that extends State and it is PrivateState, so this cast is safe.
var priv = (PrivateState)state;

// Load the state from memory into the KV cache
Executor.Context.LoadState(priv.SequenceState, ConversationId);

Load(priv.ConversationState);
}


private void Load(SerializableConversationState state)
{
if (state.Version != 1)
throw new InvalidOperationException("Failed to deserialize - mismatched version number");

// Load extra conversation state
_end = state.TokenCount;
}

private SerializableConversationState GetState()
{
return new SerializableConversationState(
Version: 1,
TokenCount: TokenCount
);
}


private record SerializableConversationState(int Version, int TokenCount);

private sealed class PrivateState
: State
{
public readonly LLamaContext.SequenceState SequenceState;
public readonly SerializableConversationState ConversationState;

public override ulong Size => SequenceState.Size;

public PrivateState(LLamaContext.SequenceState sequenceState, SerializableConversationState conversationState)
{
SequenceState = sequenceState;
ConversationState = conversationState;
}

/// <inheritdoc />
public override void Dispose()
{
if (IsDisposed)
throw new ObjectDisposedException(nameof(State));
IsDisposed = true;

SequenceState.Dispose();
}
}

/// <summary>
/// In-memory saved state of a <see cref="Conversation"/>
/// </summary>
public abstract class State
: IDisposable
{
/// <summary>
/// Indicates if this state has been disposed
/// </summary>
public bool IsDisposed { get; protected set; }

/// <summary>
/// Get the size in bytes of this state object
/// </summary>
public abstract ulong Size { get; }

/// <inheritdoc />
public abstract void Dispose();

/// <summary>
/// Internal constructor to prevent anyone outside of LLamaSharp from extending this class
/// </summary>
internal State()
{
}
}
#endregion
}
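
The in-memory path returns a Conversation.State that wraps a copy of the sequence's KV cache (LLamaContext.SequenceState) plus the conversation's own bookkeeping (a version number and the token count), so it should be disposed once it is no longer needed. A minimal sketch of the round trip, assuming an executor and conversation set up as in the example:

    // Sketch: save into memory, discard the old conversation, restore into a new one.
    using (var state = conversation.Save())
    {
        conversation.Dispose();
        Console.WriteLine($"Saved {state.Size} bytes in memory");

        // Copies the saved sequence state back into the executor's KV cache
        conversation = executor.Load(state);
    }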