
Merge pull request #223 from martindevans/batch_decoding
New Binaries, Improved Sampling API, Batch Decoding Prototype
martindevans authored Oct 31, 2023
2 parents f8b2c5d + db8f398 commit 5a9e13c
Showing 42 changed files with 814 additions and 222 deletions.
1 change: 1 addition & 0 deletions LLama.Examples/LLama.Examples.csproj
@@ -8,6 +8,7 @@
<Platforms>AnyCPU;x64</Platforms>
<!-- Set IncludeBuiltInRuntimes to false to include your own runtime libraries and not link the defaults -->
<IncludeBuiltInRuntimes>true</IncludeBuiltInRuntimes>
<AllowUnsafeBlocks>true</AllowUnsafeBlocks>
</PropertyGroup>

<PropertyGroup Condition="'$(Configuration)|$(Platform)'=='Debug|AnyCPU'">
177 changes: 177 additions & 0 deletions LLama.Examples/NewVersion/BatchedDecoding.cs
@@ -0,0 +1,177 @@
using System.Diagnostics;
using System.Security.Cryptography;
using System.Text;
using LLama.Common;
using LLama.Native;

namespace LLama.Examples.NewVersion;

/// <summary>
/// This demonstrates generating multiple replies to the same prompt, with a shared cache
/// </summary>
/// <remarks>Note that this currently uses the low-level API directly; future work will provide a safer C# wrapper over this!</remarks>
public class BatchedDecoding
{
private const int n_parallel = 8;
private const int n_len = 32;

private const int top_k = 80;
private const float top_p = 0.8f;
private const float temp = 0.5f;

public static async Task Run()
{
Console.Write("Please input your model path: ");
var modelPath = Console.ReadLine();

Console.WriteLine("Prompt (leave blank to select automatically):");
var prompt = Console.ReadLine();
if (string.IsNullOrWhiteSpace(prompt))
prompt = "Not many people know that";

// Load model
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);

// Tokenize prompt
var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;

// Create a context
parameters.ContextSize = (uint)model.ContextSize;
parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
using var context = model.CreateContext(parameters);

var n_ctx = context.ContextSize;

// make sure the KV cache is big enough to hold all the prompt and generated tokens
if (n_kv_req > n_ctx)
{
await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
return;
}

using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);

// llama_decode will output logits only for the last token of the prompt
unsafe
{
batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
}

if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("llama_decode failed");
return;
}

// assign the system KV cache to all parallel sequences
// this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
for (var i = 1; i < n_parallel; ++i)
{
NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
}

if (n_parallel > 1)
{
Console.WriteLine();
Console.WriteLine($"generating {n_parallel} sequences...");
}

// remember the batch index of the last token for each parallel sequence
// we need this to determine which logits to sample from
List<int> i_batch = new();
for (var i = 0; i < n_parallel; i++)
i_batch.Add(batch.NativeBatch.n_tokens - 1);

var n_cur = batch.NativeBatch.n_tokens;
var n_decode = 0;

var streams = new List<int>[n_parallel];
for (var i = 0; i < n_parallel; i++)
streams[i] = new();

var eos = model.EndOfSentenceToken;
var nl = model.NewlineToken;

var timer = new Stopwatch();
timer.Start();
while (n_cur <= n_len)
{
batch.LLamaBatchClear();

for (var i = 0; i < n_parallel; i++)
{
// Skip completed streams
if (i_batch[i] < 0)
continue;

var n_vocab = model.VocabCount;
LLamaTokenDataArray candidates;
unsafe
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}

candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
candidates.Temperature(context.NativeHandle, temp);
var new_token_id = candidates.SampleToken(context.NativeHandle);

if (new_token_id == eos || new_token_id == nl)
{
i_batch[i] = -1;
Console.WriteLine($"Completed Stream {i} early");
continue;
}

streams[i].Add(new_token_id);

i_batch[i] = batch.NativeBatch.n_tokens;

// push this new token for next evaluation
batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);

n_decode++;
}

// all streams are finished
if (batch.NativeBatch.n_tokens == 0)
{
break;
}

n_cur++;

// evaluate the current batch with the transformer model
if (context.NativeHandle.Decode(batch) != 0)
{
await Console.Error.WriteLineAsync("failed to eval");
return;
}
}

timer.Stop();
Console.ForegroundColor = ConsoleColor.Yellow;
Console.WriteLine();
Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");

var index = 0;
foreach (var stream in streams)
{
var text = context.DeTokenize(stream);

Console.ForegroundColor = ConsoleColor.Green;
Console.Write($"{index++}. {prompt}");
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine(text);
}
}
}
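
The example above is built on the improved sampling API named in the commit title: the logits for one position are copied into a LLamaTokenDataArray, filtered in place, and then a token is drawn. As a rough standalone sketch of that chain (assuming model, context and batch are set up as in the example; the 40 / 0.95 / 0.8 values are purely illustrative):

    // Index of the last token in the batch, i.e. the position whose logits were requested
    var lastIndex = batch.NativeBatch.n_tokens - 1;

    LLamaTokenDataArray candidates;
    unsafe
    {
        // Wrap the n_vocab logits for that position in a sortable candidate array
        var logits = NativeApi.llama_get_logits_ith(context.NativeHandle, lastIndex);
        candidates = LLamaTokenDataArray.Create(new Span<float>(logits, model.VocabCount));
    }

    // Each call filters or reweights the candidates in place before a token is sampled
    candidates.TopK(context.NativeHandle, 40);
    candidates.TopP(context.NativeHandle, 0.95f);
    candidates.Temperature(context.NativeHandle, 0.8f);
    var newToken = candidates.SampleToken(context.NativeHandle);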
5 changes: 1 addition & 4 deletions LLama.Examples/NewVersion/SemanticKernelChat.cs
@@ -14,10 +14,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();

// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
};
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
5 changes: 1 addition & 4 deletions LLama.Examples/NewVersion/SemanticKernelPrompt.cs
@@ -16,10 +16,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();

// Load weights into memory
var parameters = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);

5 changes: 1 addition & 4 deletions LLama.Examples/NewVersion/TalkToYourself.cs
@@ -13,10 +13,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();

// Load weights into memory
var @params = new ModelParams(modelPath)
{
Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
};
var @params = new ModelParams(modelPath);
using var weights = LLamaWeights.LoadFromFile(@params);

// Create 2 contexts sharing the same weights
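
SemanticKernelChat, SemanticKernelPrompt and TalkToYourself no longer seed ModelParams with RandomNumberGenerator; they simply use whatever Seed defaults to. A fixed seed can still be set explicitly when reproducible output is wanted — a small sketch (1337 is only an illustrative value):

    var parameters = new ModelParams(modelPath)
    {
        // A fixed seed should make sampling reproducible for the same prompt and settings
        Seed = 1337
    };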
5 changes: 5 additions & 0 deletions LLama.Examples/NewVersion/TestRunner.cs
@@ -22,6 +22,7 @@ public static async Task Run()
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
Console.WriteLine("15: Batch Decoding.");

while (true)
{
@@ -88,6 +89,10 @@ public static async Task Run()
{
await CodingAssistant.Run();
}
else if (choice == 15)
{
await BatchedDecoding.Run();
}
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
8 changes: 8 additions & 0 deletions LLama.Unittest/StatelessExecutorTest.cs
@@ -1,3 +1,4 @@
using System.Diagnostics;
using LLama.Common;
using Xunit.Abstractions;

@@ -34,10 +35,17 @@ public async Task Stateless()
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };

var timer = new Stopwatch();
timer.Start();

var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());

timer.Stop();
_testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");

_testOutputHelper.WriteLine(result1);
_testOutputHelper.WriteLine(result2);

// Check that it produced the exact same result both times
Assert.Equal(result1, result2);
2 changes: 1 addition & 1 deletion LLama.Web/Common/InferenceOptions.cs
@@ -23,7 +23,7 @@ public class InferenceOptions : IInferenceParams
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// path to file for saving/loading model eval state
/// </summary>
4 changes: 2 additions & 2 deletions LLama.Web/Common/ModelOptions.cs
@@ -111,12 +111,12 @@ public class ModelOptions
/// <summary>
/// RoPE base frequency
/// </summary>
public float RopeFrequencyBase { get; set; } = 10000.0f;
public float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
public float RopeFrequencyScale { get; set; } = 1.0f;
public float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
8 changes: 4 additions & 4 deletions LLama/Abstractions/IContextParams.cs
@@ -39,14 +39,14 @@ public interface IContextParams
bool EmbeddingMode { get; set; }

/// <summary>
/// RoPE base frequency
/// RoPE base frequency (null to fetch from the model)
/// </summary>
float RopeFrequencyBase { get; set; }
float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// RoPE frequency scaling factor (null to fetch from the model)
/// </summary>
float RopeFrequencyScale { get; set; }
float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
2 changes: 1 addition & 1 deletion LLama/Abstractions/IInferenceParams.cs
@@ -29,7 +29,7 @@ public interface IInferenceParams
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; }
public IReadOnlyList<string> AntiPrompts { get; set; }

/// <summary>
/// 0 or lower to use vocab size
2 changes: 1 addition & 1 deletion LLama/Common/InferenceParams.cs
@@ -28,7 +28,7 @@ public record InferenceParams : IInferenceParams
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();

/// <summary>
/// 0 or lower to use vocab size
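
AntiPrompts is tightened from IEnumerable<string> to IReadOnlyList<string> on IInferenceParams, InferenceParams and InferenceOptions. Call sites that assign an array, as the examples and tests already do, keep compiling because string[] implements IReadOnlyList<string> — a quick sketch (the stop strings are only illustrative):

    var inferenceParams = new InferenceParams
    {
        MaxTokens = 32,
        // string[] satisfies IReadOnlyList<string>, so existing array initialisers work unchanged
        AntiPrompts = new[] { "User:", "\n\n" }
    };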
6 changes: 3 additions & 3 deletions LLama/Common/ModelParams.cs
@@ -91,12 +91,12 @@ public record ModelParams
/// <summary>
/// RoPE base frequency
/// </summary>
public float RopeFrequencyBase { get; set; } = 10000.0f;
public float? RopeFrequencyBase { get; set; }

/// <summary>
/// RoPE frequency scaling factor
/// </summary>
public float RopeFrequencyScale { get; set; } = 1.0f;
public float? RopeFrequencyScale { get; set; }

/// <summary>
/// Use experimental mul_mat_q kernels
@@ -156,7 +156,7 @@ public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount =
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
bool embeddingMode = false,
float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
float? ropeFrequencyBase = null, float? ropeFrequencyScale = null, bool mulMatQ = false,
string encoding = "UTF-8")
{
ContextSize = contextSize;
4 changes: 2 additions & 2 deletions LLama/Extensions/IContextParamsExtensions.cs
@@ -27,8 +27,8 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.f16_kv = @params.UseFp16Memory;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
result.rope_freq_base = @params.RopeFrequencyBase;
result.rope_freq_scale = @params.RopeFrequencyScale;
result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
result.mul_mat_q = @params.MulMatQ;

result.n_threads = Threads(@params.Threads);
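
With RopeFrequencyBase and RopeFrequencyScale now nullable, null means "use the value stored in the model", and the extension method above maps null to 0 when filling LLamaContextParams — presumably the sentinel the native layer interprets as "take it from the model". A small sketch of both styles (the override values are only illustrative):

    // Default: leave both null so the RoPE values come from the model file
    var fromModel = new ModelParams(modelPath);

    // Explicit override, e.g. when experimenting with RoPE frequency scaling
    var overridden = new ModelParams(modelPath)
    {
        RopeFrequencyBase = 10000.0f,
        RopeFrequencyScale = 0.5f,
    };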