Commit
* Added the ability to save and load individual conversations in a batched executor.
  - New example
  - Added `BatchedExecutor.Load(filepath)` method
  - Added `Conversation.Save(filepath)` method
  - Added new (currently internal) `SaveState`/`LoadState` methods in LLamaContext which can stash some extra binary data in the header
* Added ability to save/load a `Conversation` to an in-memory state, instead of to file.
* Moved the new save/load methods out to an extension class specifically for the batched executor.
* Removed unnecessary spaces
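A rough sketch (not part of the commit itself) of how the new file-based save/load surface is intended to be used, pieced together from the notes above and the example in the diff below; the model path, prompt text, and state-file name are placeholders:

using LLama;
using LLama.Batched;
using LLama.Common;

// Load weights and create a batched executor (placeholder model path)
var parameters = new ModelParams("model.gguf");
using var model = LLamaWeights.LoadFromFile(parameters);
using var executor = new BatchedExecutor(model, parameters);

// Start a conversation and evaluate the prompt (sampling/decoding omitted for brevity)
var conversation = executor.Create();
conversation.Prompt("Not many people know that");
await executor.Infer();

// Persist the conversation to disk and release it
conversation.Save("conversation.state");            // Conversation.Save(filepath), added in this commit
conversation.Dispose();

// Later: rebuild the conversation from the saved file and keep generating
conversation = executor.Load("conversation.state"); // BatchedExecutor.Load(filepath) extension, added in this commit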
1 parent f01c13e · commit ccc49eb
Showing 6 changed files with 438 additions and 12 deletions.
@@ -0,0 +1,108 @@
using LLama.Batched;
using LLama.Common;
using LLama.Native;
using LLama.Sampling;
using Spectre.Console;

namespace LLama.Examples.Examples;

/// <summary>
/// This demonstrates saving and loading the state of individual conversations in a batched executor
/// </summary>
public class BatchedExecutorSaveAndLoad
{
    private const int n_len = 18;

    public static async Task Run()
    {
        string modelPath = UserSettings.GetModelPath();

        var parameters = new ModelParams(modelPath);
        using var model = LLamaWeights.LoadFromFile(parameters);

        var prompt = AnsiConsole.Ask("Prompt (or ENTER for default):", "Not many people know that");

        // Create an executor that can evaluate a batch of conversations together
        using var executor = new BatchedExecutor(model, parameters);

        // Print some info
        var name = executor.Model.Metadata.GetValueOrDefault("general.name", "unknown model name");
        Console.WriteLine($"Created executor with model: {name}");

        // Create a conversation
        var conversation = executor.Create();
        conversation.Prompt(prompt);

        // Run inference loop
        var decoder = new StreamingTokenDecoder(executor.Context);
        var sampler = new DefaultSamplingPipeline();
        var lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save this conversation to a file and dispose it
        conversation.Save("demo_conversation.state");
        conversation.Dispose();
        AnsiConsole.WriteLine($"Saved state: {new FileInfo("demo_conversation.state").Length} bytes");

        // Now create a new conversation by loading that state
        conversation = executor.Load("demo_conversation.state");
        AnsiConsole.WriteLine("Loaded state");

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        lastToken = await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Can't save a conversation while RequiresInference is true
        if (conversation.RequiresInference)
            await executor.Infer();

        // Save the conversation again, this time into system memory
        using (var state = conversation.Save())
        {
            conversation.Dispose();
            AnsiConsole.WriteLine($"Saved state to memory: {state.Size} bytes");

            // Now create a new conversation by loading state back from the file saved earlier
            conversation = executor.Load("demo_conversation.state");
            AnsiConsole.WriteLine("Loaded state");
        }

        // Prompt it again with the last token, so we can continue generating
        conversation.Rewind(1);
        conversation.Prompt(lastToken);

        // Continue generating text
        await GenerateTokens(executor, conversation, sampler, decoder, n_len);

        // Display final output
        AnsiConsole.MarkupLine($"[red]{prompt}{decoder.Read()}[/]");
    }

    private static async Task<LLamaToken> GenerateTokens(BatchedExecutor executor, Conversation conversation, ISamplingPipeline sampler, StreamingTokenDecoder decoder, int count = 15)
    {
        var token = (LLamaToken)0;

        for (var i = 0; i < count; i++)
        {
            // Run inference
            await executor.Infer();

            // Use sampling pipeline to pick a token
            token = sampler.Sample(executor.Context.NativeHandle, conversation.Sample(), ReadOnlySpan<LLamaToken>.Empty);

            // Add it to the decoder, so it can be converted into text later
            decoder.Add(token);

            // Prompt the conversation with the token
            conversation.Prompt(token);
        }

        return token;
    }
}