diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj
index a4ebb604e..7fea45623 100644
--- a/LLama.Examples/LLama.Examples.csproj
+++ b/LLama.Examples/LLama.Examples.csproj
@@ -8,6 +8,7 @@
AnyCPU;x64
true
+ <AllowUnsafeBlocks>true</AllowUnsafeBlocks>
diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs
new file mode 100644
index 000000000..ad29703a3
--- /dev/null
+++ b/LLama.Examples/NewVersion/BatchedDecoding.cs
@@ -0,0 +1,177 @@
+using System.Diagnostics;
+using System.Security.Cryptography;
+using System.Text;
+using LLama.Common;
+using LLama.Native;
+
+namespace LLama.Examples.NewVersion;
+
+/// <summary>
+/// This demonstrates generating multiple replies to the same prompt, with a shared cache
+/// </summary>
+/// <remarks>Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this!</remarks>
+public class BatchedDecoding
+{
+ private const int n_parallel = 8;
+ private const int n_len = 32;
+
+ private const int top_k = 80;
+ private const float top_p = 0.8f;
+ private const float temp = 0.5f;
+
+ public static async Task Run()
+ {
+ Console.Write("Please input your model path: ");
+ var modelPath = Console.ReadLine();
+
+ Console.WriteLine("Prompt (leave blank to select automatically):");
+ var prompt = Console.ReadLine();
+ if (string.IsNullOrWhiteSpace(prompt))
+ prompt = "Not many people know that";
+
+ // Load model
+ var parameters = new ModelParams(modelPath);
+ using var model = LLamaWeights.LoadFromFile(parameters);
+
+ // Tokenize prompt
+ var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8);
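+ // KV cache requirement: the prompt is stored once and shared, then each of the n_parallel sequences generates up to (n_len - prompt length) tokens of its own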
+ var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel;
+
+ // Create a context
+ parameters.ContextSize = (uint)model.ContextSize;
+ parameters.BatchSize = (uint)Math.Max(n_len, n_parallel);
+ using var context = model.CreateContext(parameters);
+
+ var n_ctx = context.ContextSize;
+
+ // make sure the KV cache is big enough to hold all the prompt and generated tokens
+ if (n_kv_req > n_ctx)
+ {
+ await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n");
+ await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n");
+ return;
+ }
+
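+ // allocate a native batch big enough for the whole prompt now, and for one token per parallel sequence later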
+ using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1);
+
+ // evaluate the initial prompt
+ for (var i = 0; i < prompt_tokens.Length; i++)
+ batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false);
+ Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length);
+
+ // llama_decode will output logits only for the last token of the prompt
+ unsafe
+ {
+ batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1;
+ }
+
+ if (context.NativeHandle.Decode(batch) != 0)
+ {
+ await Console.Error.WriteLineAsync("llama_decode failed");
+ return;
+ }
+
+ // assign the system KV cache to all parallel sequences
+ // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them
+ for (var i = 1; i < n_parallel; ++i)
+ {
+ NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens);
+ }
+
+ if (n_parallel > 1)
+ {
+ Console.WriteLine();
+ Console.WriteLine($"generating {n_parallel} sequences...");
+ }
+
+ // remember the batch index of the last token for each parallel sequence
+ // we need this to determine which logits to sample from
+ List<int> i_batch = new();
+ for (var i = 0; i < n_parallel; i++)
+ i_batch.Add(batch.NativeBatch.n_tokens - 1);
+
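+ // n_cur tracks the current position in the sequences, n_decode counts how many tokens have been sampled in total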
+ var n_cur = batch.NativeBatch.n_tokens;
+ var n_decode = 0;
+
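+ // one token buffer per parallel sequence, detokenized at the end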
+ var streams = new List<int>[n_parallel];
+ for (var i = 0; i < n_parallel; i++)
+ streams[i] = new();
+
+ var eos = model.EndOfSentenceToken;
+ var nl = model.NewlineToken;
+
+ var timer = new Stopwatch();
+ timer.Start();
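+ // generation loop: continue until the sequences reach n_len tokens or every stream has ended early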
+ while (n_cur <= n_len)
+ {
+ batch.LLamaBatchClear();
+
+ for (var i = 0; i < n_parallel; i++)
+ {
+ // Skip completed streams
+ if (i_batch[i] < 0)
+ continue;
+
+ var n_vocab = model.VocabCount;
+ LLamaTokenDataArray candidates;
+ unsafe
+ {
+ candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
+ }
+
+ candidates.TopK(context.NativeHandle, top_k);
+ candidates.TopP(context.NativeHandle, top_p);
+ candidates.Temperature(context.NativeHandle, temp);
+ var new_token_id = candidates.SampleToken(context.NativeHandle);
+
+ if (new_token_id == eos || new_token_id == nl)
+ {
+ i_batch[i] = -1;
+ Console.WriteLine($"Completed Stream {i} early");
+ continue;
+ }
+
+ streams[i].Add(new_token_id);
+
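+ // remember which batch slot this sequence's next logits will come from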
+ i_batch[i] = batch.NativeBatch.n_tokens;
+
+ // push this new token for next evaluation
+ batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true);
+
+ n_decode++;
+ }
+
+ // all streams are finished
+ if (batch.NativeBatch.n_tokens == 0)
+ {
+ break;
+ }
+
+ n_cur++;
+
+ // evaluate the current batch with the transformer model
+ if (context.NativeHandle.Decode(batch) != 0)
+ {
+ await Console.Error.WriteLineAsync("failed to eval");
+ return;
+ }
+ }
+
+ timer.Stop();
+ Console.ForegroundColor = ConsoleColor.Yellow;
+ Console.WriteLine();
+ Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms");
+ Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second");
+
+ var index = 0;
+ foreach (var stream in streams)
+ {
+ var text = context.DeTokenize(stream);
+
+ Console.ForegroundColor = ConsoleColor.Green;
+ Console.Write($"{index++}. {prompt}");
+ Console.ForegroundColor = ConsoleColor.Red;
+ Console.WriteLine(text);
+ }
+ }
+}
\ No newline at end of file
diff --git a/LLama.Examples/NewVersion/SemanticKernelChat.cs b/LLama.Examples/NewVersion/SemanticKernelChat.cs
index 9fd59058f..4f27bf659 100644
--- a/LLama.Examples/NewVersion/SemanticKernelChat.cs
+++ b/LLama.Examples/NewVersion/SemanticKernelChat.cs
@@ -14,10 +14,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();
// Load weights into memory
- var parameters = new ModelParams(modelPath)
- {
- Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)),
- };
+ var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
using var context = model.CreateContext(parameters);
var ex = new InteractiveExecutor(context);
diff --git a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs
index 52591f918..d744b7630 100644
--- a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs
+++ b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs
@@ -16,10 +16,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();
// Load weights into memory
- var parameters = new ModelParams(modelPath)
- {
- Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
- };
+ var parameters = new ModelParams(modelPath);
using var model = LLamaWeights.LoadFromFile(parameters);
var ex = new StatelessExecutor(model, parameters);
diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs
index 4c412c93f..80d129ca1 100644
--- a/LLama.Examples/NewVersion/TalkToYourself.cs
+++ b/LLama.Examples/NewVersion/TalkToYourself.cs
@@ -13,10 +13,7 @@ public static async Task Run()
var modelPath = Console.ReadLine();
// Load weights into memory
- var @params = new ModelParams(modelPath)
- {
- Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue))
- };
+ var @params = new ModelParams(modelPath);
using var weights = LLamaWeights.LoadFromFile(@params);
// Create 2 contexts sharing the same weights
diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs
index d88087bb4..2f698f803 100644
--- a/LLama.Examples/NewVersion/TestRunner.cs
+++ b/LLama.Examples/NewVersion/TestRunner.cs
@@ -22,6 +22,7 @@ public static async Task Run()
Console.WriteLine("12: Semantic Kernel Chat.");
Console.WriteLine("13: Semantic Kernel Memory.");
Console.WriteLine("14: Coding Assistant.");
+ Console.WriteLine("15: Batch Decoding.");
while (true)
{
@@ -88,6 +89,10 @@ public static async Task Run()
{
await CodingAssistant.Run();
}
+ else if (choice == 15)
+ {
+ await BatchedDecoding.Run();
+ }
else
{
Console.WriteLine("Cannot parse your choice. Please select again.");
diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs
index 19f618af2..195cc4a28 100644
--- a/LLama.Unittest/StatelessExecutorTest.cs
+++ b/LLama.Unittest/StatelessExecutorTest.cs
@@ -1,3 +1,4 @@
+using System.Diagnostics;
using LLama.Common;
using Xunit.Abstractions;
@@ -34,10 +35,17 @@ public async Task Stateless()
const string question = "Question. what is a cat?\nAnswer: ";
var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." } };
+ var timer = new Stopwatch();
+ timer.Start();
+
var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync());
+ timer.Stop();
+ _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms");
+
_testOutputHelper.WriteLine(result1);
+ _testOutputHelper.WriteLine(result2);
// Check that it produced the exact same result both times
Assert.Equal(result1, result2);
diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index c2420af37..0b38a04d8 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -23,7 +23,7 @@ public class InferenceOptions : IInferenceParams
///
/// Sequences where the model will stop generating further tokens.
///
- public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+ public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
///
/// path to file for saving/loading model eval state
///
diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs
index 20a3e348a..6a63ccc31 100644
--- a/LLama.Web/Common/ModelOptions.cs
+++ b/LLama.Web/Common/ModelOptions.cs
@@ -111,12 +111,12 @@ public class ModelOptions
///
/// RoPE base frequency
///
- public float RopeFrequencyBase { get; set; } = 10000.0f;
+ public float? RopeFrequencyBase { get; set; }
///
/// RoPE frequency scaling factor
///
- public float RopeFrequencyScale { get; set; } = 1.0f;
+ public float? RopeFrequencyScale { get; set; }
///
/// Use experimental mul_mat_q kernels
diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs
index 201a9b9ad..8ff6d7ccf 100644
--- a/LLama/Abstractions/IContextParams.cs
+++ b/LLama/Abstractions/IContextParams.cs
@@ -39,14 +39,14 @@ public interface IContextParams
bool EmbeddingMode { get; set; }
///
- /// RoPE base frequency
+ /// RoPE base frequency (null to fetch from the model)
///
- float RopeFrequencyBase { get; set; }
+ float? RopeFrequencyBase { get; set; }
///
- /// RoPE frequency scaling factor
+ /// RoPE frequency scaling factor (null to fetch from the model)
///
- float RopeFrequencyScale { get; set; }
+ float? RopeFrequencyScale { get; set; }
///
/// Use experimental mul_mat_q kernels
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index 93a9b52ba..a21c7306b 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -29,7 +29,7 @@ public interface IInferenceParams
///
/// Sequences where the model will stop generating further tokens.
///
- public IEnumerable<string> AntiPrompts { get; set; }
+ public IReadOnlyList<string> AntiPrompts { get; set; }
///
/// 0 or lower to use vocab size
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index bef64631a..d0217e2f8 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -28,7 +28,7 @@ public record InferenceParams : IInferenceParams
///
/// Sequences where the model will stop generating further tokens.
///
- public IEnumerable<string> AntiPrompts { get; set; } = Array.Empty<string>();
+ public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
///
/// 0 or lower to use vocab size
diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs
index 78f51c6cb..ee5bd3e4c 100644
--- a/LLama/Common/ModelParams.cs
+++ b/LLama/Common/ModelParams.cs
@@ -91,12 +91,12 @@ public record ModelParams
///
/// RoPE base frequency
///
- public float RopeFrequencyBase { get; set; } = 10000.0f;
+ public float? RopeFrequencyBase { get; set; }
///
/// RoPE frequency scaling factor
///
- public float RopeFrequencyScale { get; set; } = 1.0f;
+ public float? RopeFrequencyScale { get; set; }
///
/// Use experimental mul_mat_q kernels
@@ -156,7 +156,7 @@ public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount =
bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false,
string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512,
bool embeddingMode = false,
- float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false,
+ float? ropeFrequencyBase = null, float? ropeFrequencyScale = null, bool mulMatQ = false,
string encoding = "UTF-8")
{
ContextSize = contextSize;
diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs
index 7ca508a2b..fcc9d372a 100644
--- a/LLama/Extensions/IContextParamsExtensions.cs
+++ b/LLama/Extensions/IContextParamsExtensions.cs
@@ -27,8 +27,8 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.f16_kv = @params.UseFp16Memory;
result.logits_all = @params.Perplexity;
result.embedding = @params.EmbeddingMode;
- result.rope_freq_base = @params.RopeFrequencyBase;
- result.rope_freq_scale = @params.RopeFrequencyScale;
+ result.rope_freq_base = @params.RopeFrequencyBase ?? 0;
+ result.rope_freq_scale = @params.RopeFrequencyScale ?? 0;
result.mul_mat_q = @params.MulMatQ;
result.n_threads = Threads(@params.Threads);
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 47240dc71..0ccae76b4 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -235,13 +235,13 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
if (grammar != null)
{
- SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar);
+ candidates.ApplyGrammar(NativeHandle, grammar);
}
if (temperature <= 0)
{
// Greedy sampling
- id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates);
+ id = candidates.SampleTokenGreedy(NativeHandle);
}
else
{
@@ -250,32 +250,28 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
if (mirostat == MirostatType.Mirostat)
{
const int mirostat_m = 100;
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu);
}
else if (mirostat == MirostatType.Mirostat2)
{
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu);
}
else
{
- // Temperature sampling
- SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1);
- SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1);
- SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1);
- SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1);
- SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature);
- id = SamplingApi.llama_sample_token(NativeHandle, candidates);
+ candidates.TopK(NativeHandle, topK);
+ candidates.TailFree(NativeHandle, tfsZ);
+ candidates.LocallyTypical(NativeHandle, typicalP);
+ candidates.TopP(NativeHandle, topP);
+ candidates.Temperature(NativeHandle, temperature);
+ id = candidates.SampleToken(NativeHandle);
}
}
mirostat_mu = mu;
}
- if (grammar != null)
- {
- NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id);
- }
+ grammar?.AcceptToken(NativeHandle, id);
return id;
}
@@ -305,7 +301,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic
}
// Save the newline logit value
- var nl_token = NativeApi.llama_token_nl(NativeHandle);
+ var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle);
var nl_logit = logits[nl_token];
// Convert logits into token candidates
@@ -316,8 +312,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic
var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray();
// Apply penalties to candidates
- SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty);
- SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence);
+ candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence);
// Restore newline token logit value if necessary
if (!penalizeNL)
@@ -369,7 +364,7 @@ public int Eval(List tokens, int pastTokensCount)
try
{
tokens.CopyTo(rented, 0);
- return Eval(rented, pastTokensCount);
+ return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount);
}
finally
{
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 9e4292ea6..80c6f5420 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -163,7 +163,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
}
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
args.WaitForInput = true;
}
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index d3d4a9e39..5078648b1 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -30,7 +30,7 @@ public class InteractiveExecutor : StatefulExecutorBase
public InteractiveExecutor(LLamaContext context, ILogger? logger = null)
: base(context, logger)
{
- _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle);
+ _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle);
}
///
@@ -141,7 +141,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args)
return (true, Array.Empty<string>());
}
- if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
return (true, new[] { " [end of text]\n" });
}
@@ -202,7 +202,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
_last_n_tokens.Enqueue(id);
- if (id == NativeApi.llama_token_eos(Context.NativeHandle))
+ if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle))
{
id = _llama_token_newline;
if (args.Antiprompts is not null && args.Antiprompts.Count > 0)
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index ab1e9bbc0..4bdeaa3f2 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -6,7 +6,6 @@
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
-using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.Logging;
@@ -47,68 +46,66 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
}
///
- public async IAsyncEnumerable<string> InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
+ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default)
{
- using var context = _weights.CreateContext(_params, _logger);
- Context = context;
-
+ // Ensure the context from last time is disposed (it always should be)
if (!Context.NativeHandle.IsClosed)
Context.Dispose();
- Context = _weights.CreateContext(Context.Params, _logger);
-
- var decoder = new StreamingTokenDecoder(Context);
- var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? Array.Empty<string>());
-
- if (inferenceParams != null)
- {
- if (inferenceParams.TokensKeep > Context.ContextSize)
- throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
- }
- cancellationToken.ThrowIfCancellationRequested();
+ // Create an inference context which will be disposed when this method exits
+ using var context = _weights.CreateContext(_params, _logger);
+ Context = context;
+ // Sanity check inference params
inferenceParams ??= new InferenceParams();
+ if (inferenceParams.TokensKeep > Context.ContextSize)
+ throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})");
+
+ // Create decoders for the token stream
+ var decoder = new StreamingTokenDecoder(Context);
+ var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts);
- var lastTokens = new List<llama_token>(inferenceParams.RepeatLastTokensCount);
- for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++)
+ // Keep track of the last N tokens emitted
+ var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
+ var lastTokens = new List<llama_token>(repeat_last_n);
+ for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add(0);
- var tokens = Context.Tokenize(text).ToList();
+ // Tokenize the prompt
+ var tokens = Context.Tokenize(prompt).ToList();
+ lastTokens.AddRange(tokens);
+ var n_past = 1 + tokens.Count;
+ // Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
- lastTokens.AddRange(tokens);
- var n_past = 1 + tokens.Count;
-
+ // Begin loop, evaluating one token at a time
var mu = (float?)null;
var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens;
- for(var i = 0; i < max_tokens; i++)
+ for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++)
{
- if (cancellationToken.IsCancellationRequested)
- break;
-
- var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? Context.ContextSize : inferenceParams.RepeatLastTokensCount;
-
+ // Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
+ // Sample a single token
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
- lastTokens.Add(id);
-
+ // Decode this token into text
decoder.Add(id);
var decoded = decoder.Read();
yield return decoded;
- tokens.Clear();
- tokens.Add(id);
-
// Check if any of the antiprompts have been generated
if (antiprocessor.Add(decoded))
break;
+ lastTokens.Add(id);
+ tokens.Clear();
+ tokens.Add(id);
+
// when run out of context
// based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497
if (n_past + tokens.Count >= Context.ContextSize)
diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs
index 64878e2ab..7ae104a5c 100644
--- a/LLama/LLamaWeights.cs
+++ b/LLama/LLamaWeights.cs
@@ -38,6 +38,21 @@ public sealed class LLamaWeights
///
public ulong ParameterCount => NativeHandle.ParameterCount;
+ /// <summary>
+ /// Get the newline token for this model
+ /// </summary>
+ public int NewlineToken => NativeApi.llama_token_nl(NativeHandle);
+
+ /// <summary>
+ /// Get the "end of sentence" token for this model
+ /// </summary>
+ public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle);
+
+ /// <summary>
+ /// Get the "beginning of sentence" token for this model
+ /// </summary>
+ public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle);
+
///
/// Dimension of embedding vectors
///
diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs
index 1f7ef2c53..30e703946 100644
--- a/LLama/Native/LLamaBatchSafeHandle.cs
+++ b/LLama/Native/LLamaBatchSafeHandle.cs
@@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle
///
/// Get the native llama_batch struct
///
- public LLamaNativeBatch NativeBatch { get; private set; }
+ public LLamaNativeBatch NativeBatch;
///
/// the token ids of the input (used when embd is NULL)
@@ -113,10 +113,11 @@ public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd)
///
///
///
+ ///
///
- public static LLamaBatchSafeHandle Create(int n_tokens, int embd)
+ public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max)
{
- var batch = NativeApi.llama_batch_init(n_tokens, embd);
+ var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max);
return new LLamaBatchSafeHandle(batch, embd);
}
@@ -128,4 +129,32 @@ protected override bool ReleaseHandle()
SetHandle(IntPtr.Zero);
return true;
}
+
+ /// <summary>
+ /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2
+ /// </summary>
+ public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
+ {
+ unsafe
+ {
+ NativeBatch.token[NativeBatch.n_tokens] = token;
+ NativeBatch.pos[NativeBatch.n_tokens] = pos;
+ NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length;
+
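+ // record every sequence id this token belongs to in the per-token seq_id array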
+ for (var i = 0; i < sequences.Length; i++)
+ NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i];
+
+ NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits);
+
+ NativeBatch.n_tokens++;
+ }
+ }
+
+ /// <summary>
+ /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825
+ /// </summary>
+ public void LLamaBatchClear()
+ {
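+ // resetting the count is enough to "empty" the batch; the underlying native buffers are reused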
+ NativeBatch.n_tokens = 0;
+ }
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs
index e6a6c39f5..e6bc504ee 100644
--- a/LLama/Native/LLamaBeamView.cs
+++ b/LLama/Native/LLamaBeamView.cs
@@ -11,13 +11,13 @@ namespace LLama.Native;
[StructLayout(LayoutKind.Sequential)]
public struct LLamaBeamView
{
- private readonly unsafe llama_token* tokens;
- private readonly nint n_tokens;
+ private unsafe llama_token* tokens;
+ private nint n_tokens;
///
/// Cumulative beam probability (renormalized relative to all beams)
///
- public readonly float CumulativeProbability;
+ public float CumulativeProbability;
///
/// Callback should set this to true when a beam is at end-of-beam.
diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs
index 6f0a447d7..f78c45b97 100644
--- a/LLama/Native/LLamaBeamsState.cs
+++ b/LLama/Native/LLamaBeamsState.cs
@@ -9,27 +9,27 @@ namespace LLama.Native;
/// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks.
///
[StructLayout(LayoutKind.Sequential)]
-public readonly struct LLamaBeamsState
+public struct LLamaBeamsState
{
///
/// The state of each individual beam
///
- private readonly unsafe LLamaBeamView* beam_views;
+ private unsafe LLamaBeamView* beam_views;
///
/// Number of elements in beam_views
///
- private readonly nint n_beams;
+ private nint n_beams;
///
/// Current max length of prefix tokens shared by all beams.
///
- public readonly ulong CommonPrefixLength;
+ public ulong CommonPrefixLength;
///
/// True iff this is the last callback invocation.
///
- public readonly bool LastCall;
+ public bool LastCall;
///
/// The current state of each beam
diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs
index 5a7f27bb2..96313f239 100644
--- a/LLama/Native/LLamaGrammarElement.cs
+++ b/LLama/Native/LLamaGrammarElement.cs
@@ -52,18 +52,18 @@ public enum LLamaGrammarElementType
///
[StructLayout(LayoutKind.Sequential)]
[DebuggerDisplay("{Type} {Value}")]
- public readonly struct LLamaGrammarElement
+ public struct LLamaGrammarElement
: IEquatable<LLamaGrammarElement>
{
///
/// The type of this element
///
- public readonly LLamaGrammarElementType Type;
+ public LLamaGrammarElementType Type;
///
/// Unicode code point or rule ID
///
- public readonly uint Value;
+ public uint Value;
///
/// Construct a new LLamaGrammarElement
diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs
index 4d0b9e7fa..8f10b2ff3 100644
--- a/LLama/Native/LLamaModelQuantizeParams.cs
+++ b/LLama/Native/LLamaModelQuantizeParams.cs
@@ -1,10 +1,12 @@
using System;
+using System.Runtime.InteropServices;
namespace LLama.Native
{
///
/// Quantizer parameters used in the native API
///
+ [StructLayout(LayoutKind.Sequential)]
public struct LLamaModelQuantizeParams
{
///
diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs
index 576f8b279..d46f8b99e 100644
--- a/LLama/Native/LLamaNativeBatch.cs
+++ b/LLama/Native/LLamaNativeBatch.cs
@@ -11,35 +11,47 @@ namespace LLama.Native;
/// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens
///
[StructLayout(LayoutKind.Sequential)]
-public readonly unsafe struct LLamaNativeBatch
+public unsafe struct LLamaNativeBatch
{
///
/// The number of items pointed at by pos, seq_id and logits.
///
- public readonly int n_tokens;
+ public int n_tokens;
///
/// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created
///
- public readonly llama_token* token;
+ public llama_token* token;
///
/// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created
///
- public readonly float* embd;
+ public float* embd;
///
/// the positions of the respective token in the sequence
///
- public readonly LLamaPos* pos;
+ public LLamaPos* pos;
+
+ ///
+ /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ???
+ ///
+ public int* n_seq_id;
///
/// the sequence to which the respective token belongs
///
- public readonly LLamaSeqId* seq_id;
+ public LLamaSeqId** seq_id;
///
/// if zero, the logits for the respective token will not be output
///
- public readonly byte* logits;
+ public byte* logits;
+
+ // Note from llama.cpp:
+ // > helpers for smooth API transition - can be deprecated in the future
+ // > for future-proof code, use the above fields instead and ignore everything below
+ private LLamaPos _all_pos_0;
+ private LLamaPos _all_pos_1;
+ private LLamaSeqId _all_seq_id;
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs
index 4deae57b6..67ede7d52 100644
--- a/LLama/Native/LLamaPos.cs
+++ b/LLama/Native/LLamaPos.cs
@@ -1,14 +1,26 @@
-namespace LLama.Native;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
///
/// Indicates position in a sequence
///
-public readonly record struct LLamaPos(int Value)
+[StructLayout(LayoutKind.Sequential)]
+public struct LLamaPos
{
///
/// The raw value
///
- public readonly int Value = Value;
+ public int Value;
+
+ ///
+ /// Create a new LLamaPos
+ ///
+ ///
+ public LLamaPos(int value)
+ {
+ Value = value;
+ }
///
/// Convert a LLamaPos into an integer (extract the raw value)
diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs
index 4a665a2c3..191a6b5ec 100644
--- a/LLama/Native/LLamaSeqId.cs
+++ b/LLama/Native/LLamaSeqId.cs
@@ -1,15 +1,26 @@
-namespace LLama.Native;
+using System.Runtime.InteropServices;
+
+namespace LLama.Native;
///
/// ID for a sequence in a batch
///
-///
-public record struct LLamaSeqId(int Value)
+[StructLayout(LayoutKind.Sequential)]
+public struct LLamaSeqId
{
///
/// The raw value
///
- public int Value = Value;
+ public int Value;
+
+ ///
+ /// Create a new LLamaSeqId
+ ///
+ ///
+ public LLamaSeqId(int value)
+ {
+ Value = value;
+ }
///
/// Convert a LLamaSeqId into an integer (extract the raw value)
diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs
index 1ea6820dd..45edd4542 100644
--- a/LLama/Native/LLamaTokenData.cs
+++ b/LLama/Native/LLamaTokenData.cs
@@ -5,24 +5,34 @@ namespace LLama.Native;
///
/// A single token along with probability of this token being selected
///
-///
-///
-///
[StructLayout(LayoutKind.Sequential)]
-public record struct LLamaTokenData(int id, float logit, float p)
+public struct LLamaTokenData
{
///
/// token id
///
- public int id = id;
+ public int id;
///
/// log-odds of the token
///
- public float logit = logit;
+ public float logit;
///
/// probability of the token
///
- public float p = p;
+ public float p;
+
+ ///
+ /// Create a new LLamaTokenData
+ ///
+ ///
+ ///
+ ///
+ public LLamaTokenData(int id, float logit, float p)
+ {
+ this.id = id;
+ this.logit = logit;
+ this.p = p;
+ }
}
\ No newline at end of file
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 7a2965ed8..8f20a73ac 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -45,6 +45,199 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits)
return new LLamaTokenDataArray(candidates);
}
+
+ #region sampling
+ /// <summary>
+ /// Apply grammar rules to candidate tokens
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="grammar"></param>
+ public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_grammar(ctx, ref st, grammar);
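+ // the native call may re-sort the candidates, so copy the sorted flag back into the managed struct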
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="k">Number of tokens to keep</param>
+ /// <param name="minKeep">Minimum number to keep</param>
+ public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_top_k(context, ref st, k, minKeep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="p"></param>
+ /// <param name="minKeep"></param>
+ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_top_p(context, ref st, p, minKeep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="z"></param>
+ /// <param name="min_keep"></param>
+ public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_tail_free(context, ref st, z, min_keep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="p"></param>
+ /// <param name="min_keep"></param>
+ public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_typical(context, ref st, p, min_keep);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+ /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="last_tokens"></param>
+ /// <param name="penalty_repeat"></param>
+ /// <param name="penalty_freq"></param>
+ /// <param name="penalty_present"></param>
+ public void RepetitionPenalty(SafeLLamaContextHandle context, Memory<llama_token> last_tokens, float penalty_repeat, float penalty_freq, float penalty_present)
+ {
+ unsafe
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ using (var last_tokens_handle = last_tokens.Pin())
+ {
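+ // pin the managed token history so a raw pointer can be passed to the native sampler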
+ NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present);
+ sorted = st.sorted;
+ }
+ }
+ }
+
+ /// <summary>
+ /// Sample with temperature.
+ /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="temp"></param>
+ public void Temperature(SafeLLamaContextHandle context, float temp)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_temperature(context, ref st, temp);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits.
+ /// </summary>
+ /// <param name="context"></param>
+ public void Softmax(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_softmax(context, ref st);
+ sorted = st.sorted;
+ }
+ }
+
+ /// <summary>
+ /// Randomly selects a token from the candidates based on their probabilities.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
+ public int SampleToken(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token(context, ref st);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Selects the token with the highest probability.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <returns></returns>
+ public int SampleTokenGreedy(SafeLLamaContextHandle context)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_greedy(context, ref st);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+ /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+ /// <param name="m">The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.</param>
+ /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+ /// <returns></returns>
+ public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+
+ /// <summary>
+ /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words.
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="tau">The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text.</param>
+ /// <param name="eta">The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.</param>
+ /// <param name="mu">Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.</param>
+ /// <returns></returns>
+ public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu);
+ sorted = st.sorted;
+ return token;
+ }
+ }
+ #endregion
}
///
diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs
index 80e682cfe..e7ee32ba1 100644
--- a/LLama/Native/NativeApi.Sampling.cs
+++ b/LLama/Native/NativeApi.Sampling.cs
@@ -9,26 +9,22 @@ public unsafe partial class NativeApi
{
///
/// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- ///
- ///
- /// Pointer to LLamaTokenDataArray
- ///
- ///
- ///
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty);
-
- ///
/// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
///
///
/// Pointer to LLamaTokenDataArray
///
///
- ///
- ///
+ /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
+ /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
+ /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence);
+ public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx,
+ ref LLamaTokenDataArrayNative candidates,
+ llama_token* last_tokens, ulong last_tokens_size,
+ float penalty_repeat,
+ float penalty_freq,
+ float penalty_present);
///
/// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806
diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs
index 41f9ee670..e3b182bd4 100644
--- a/LLama/Native/NativeApi.cs
+++ b/LLama/Native/NativeApi.cs
@@ -348,9 +348,10 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
/// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
///
///
+ ///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx);
+ public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i);
///
/// Get the embeddings for the input
@@ -366,21 +367,21 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_bos(SafeLlamaModelHandle model);
///
/// Get the "End of sentence" token
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_eos(SafeLlamaModelHandle model);
///
/// Get the "new line" token
///
///
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx);
+ public static extern llama_token llama_token_nl(SafeLlamaModelHandle model);
///
/// Print out timing information for this context
@@ -530,6 +531,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
/// Allocates a batch of tokens on the heap
+ /// Each token can be assigned up to n_seq_max sequence ids
/// The batch has to be freed with llama_batch_free()
/// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float)
/// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token
@@ -538,8 +540,9 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi
///
///
///
+ /// <param name="n_seq_max">Each token can be assigned up to n_seq_max sequence ids</param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd);
+ public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd, int n_seq_max);
///
/// Frees a batch of tokens allocated with llama_batch_init()
diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs
index 7fb5edf74..59a5bfd9e 100644
--- a/LLama/Native/SafeLLamaContextHandle.cs
+++ b/LLama/Native/SafeLLamaContextHandle.cs
@@ -114,6 +114,22 @@ public Span GetLogits()
}
}
+ /// <summary>
+ /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab
+ /// </summary>
+ /// <param name="i"></param>
+ /// <returns></returns>
+ public Span<float> GetLogitsIth(int i)
+ {
+ var model = ThrowIfDisposed();
+
+ unsafe
+ {
+ var logits = NativeApi.llama_get_logits_ith(this, i);
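+ // the native pointer refers to a row of VocabCount floats owned by the context; wrap it without copying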
+ return new Span<float>(logits, model.VocabCount);
+ }
+ }
+
#region tokens
///
/// Convert the given text into tokens
diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs
index ed1c15c83..f430b7c3f 100644
--- a/LLama/Native/SafeLLamaGrammarHandle.cs
+++ b/LLama/Native/SafeLLamaGrammarHandle.cs
@@ -102,5 +102,15 @@ public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules,
return new(grammar_ptr);
}
#endregion
+
+ /// <summary>
+ /// Accepts the sampled token into the grammar
+ /// </summary>
+ /// <param name="ctx"></param>
+ /// <param name="token"></param>
+ public void AcceptToken(SafeLLamaContextHandle ctx, int token)
+ {
+ NativeApi.llama_grammar_accept_token(ctx, this, token);
+ }
}
}
diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs
index e26bf971a..41709def3 100644
--- a/LLama/Native/SamplingApi.cs
+++ b/LLama/Native/SamplingApi.cs
@@ -9,7 +9,7 @@ namespace LLama.Native
///
/// Direct translation of the llama.cpp sampling API
///
- public unsafe class SamplingApi
+ public class SamplingApi
{
///
/// Apply grammar rules to candidate tokens
@@ -17,70 +17,10 @@ public unsafe class SamplingApi
///
///
///
+ [Obsolete("use LLamaTokenDataArray ApplyGrammar method")]
public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_grammar(ctx, ref st, grammar);
- }
-
- ///
- /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- ///
- ///
- /// Pointer to LLamaTokenDataArray
- ///
- ///
- ///
- [Obsolete("last_tokens_size parameter is no longer needed")]
- public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float penalty)
- {
- llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty);
- }
-
- ///
- /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix.
- ///
- ///
- /// Pointer to LLamaTokenDataArray
- ///
- ///
- public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float penalty)
- {
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- using var last_tokens_handle = last_tokens.Pin();
-
- NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty);
- }
-
- ///
- /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
- ///
- ///
- /// Pointer to LLamaTokenDataArray
- ///
- ///
- ///
- ///
- [Obsolete("last_tokens_size parameter is no longer needed")]
- public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence)
- {
- llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence);
- }
-
- ///
- /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details.
- ///
- ///
- /// Pointer to LLamaTokenDataArray
- ///
- ///
- ///
- public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory<llama_token> last_tokens, float alpha_frequency, float alpha_presence)
- {
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- using var last_tokens_handle = last_tokens.Pin();
-
- NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence);
+ candidates.ApplyGrammar(ctx, grammar);
}
///
@@ -88,10 +28,10 @@ public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContex
///
///
/// Pointer to LLamaTokenDataArray
+ [Obsolete("use LLamaTokenDataArray Softmax method")]
public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_softmax(ctx, ref st);
+ candidates.Softmax(ctx);
}
///
@@ -101,10 +41,10 @@ public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDa
/// Pointer to LLamaTokenDataArray
///
///
+ [Obsolete("use LLamaTokenDataArray TopK method")]
public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep);
+ candidates.TopK(ctx, k, min_keep);
}
///
@@ -114,10 +54,10 @@ public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenData
/// Pointer to LLamaTokenDataArray
///
///
+ [Obsolete("use LLamaTokenDataArray TopP method")]
public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep);
+ candidates.TopP(ctx, p, min_keep);
}
///
@@ -127,10 +67,10 @@ public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenData
/// Pointer to LLamaTokenDataArray
///
///
+ [Obsolete("use LLamaTokenDataArray TailFree method")]
public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep);
+ candidates.TailFree(ctx, z, min_keep);
}
///
@@ -140,10 +80,10 @@ public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaToken
/// Pointer to LLamaTokenDataArray
///
///
+ [Obsolete("use LLamaTokenDataArray LocallyTypical method")]
public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_typical(ctx, ref st, p, min_keep);
+ candidates.LocallyTypical(ctx, p, min_keep);
}
///
@@ -153,10 +93,10 @@ public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDa
///
///
///
+ [Obsolete("use LLamaTokenDataArray Temperature() method")]
public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- NativeApi.llama_sample_temperature(ctx, ref st, temp);
+ candidates.Temperature(ctx, temp);
}
///
@@ -169,10 +109,10 @@ public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTok
/// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm.
/// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
///
+ [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")]
public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu);
+ return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu);
}
///
@@ -184,10 +124,10 @@ public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx
/// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates.
/// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal.
///
+ [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")]
public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu);
+ return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu);
}
///
@@ -196,10 +136,10 @@ public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle
///
/// Pointer to LLamaTokenDataArray
///
+ [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")]
public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token_greedy(ctx, ref st);
+ return candidates.SampleTokenGreedy(ctx);
}
///
@@ -208,10 +148,10 @@ public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx,
///
/// Pointer to LLamaTokenDataArray
///
+ [Obsolete("use LLamaTokenDataArray SampleToken() method")]
public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates)
{
- using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st);
- return NativeApi.llama_sample_token(ctx, ref st);
+ return candidates.SampleToken(ctx);
}
}
}
diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal
index 99b9fd7a7..f4b460564 100644
--- a/LLama/runtimes/ggml-metal.metal
+++ b/LLama/runtimes/ggml-metal.metal
@@ -18,6 +18,21 @@ typedef struct {
uint8_t qs[QK4_1 / 2]; // nibbles / quants
} block_q4_1;
+#define QK5_0 32
+typedef struct {
+ half d; // delta
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_0 / 2]; // nibbles / quants
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+ half d; // delta
+ half m; // min
+ uint8_t qh[4]; // 5-th bit of quants
+ uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
#define QK8_0 32
typedef struct {
half d; // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
}
kernel void kernel_scale(
+ device const float * src0,
+ device float * dst,
+ constant float & scale,
+ uint tpig[[thread_position_in_grid]]) {
+ dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
device const float4 * src0,
device float4 * dst,
- constant float & scale,
+ constant float & scale,
uint tpig[[thread_position_in_grid]]) {
dst[tpig] = src0[tpig] * scale;
}
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm(
// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
+
float2 acc = 0.f;
+
device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
float d = qb_curr->d;
float m = qb_curr->m;
- device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
for (int i = 0; i < 8; i+=2) {
acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
+ yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
return d * (acc[0] + acc[1]) + sumy * m;
}
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+ float d = qb_curr->d;
+ float m = qb_curr->m;
+
+ float2 acc = 0.f;
+
+ device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+ const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+ for (int i = 0; i < 8; i+=2) {
+ acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il ) << 4 ) & 0x00010))
+ + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il ) << 12) & 0x01000));
+ acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+ + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+ }
+ return d * (acc[0] + acc[1]) + sumy * m;
+}
+
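The "missing bit shifts" mentioned in the comments above are worth spelling out: masking a nibble out of the packed 16-bit word leaves it scaled by 1, 16, 256 or 4096 depending on its position, so the kernel pre-divides the corresponding y values instead of shifting each quant down. A tiny stand-alone C# check of that identity (illustrative only, not library code):

ushort packed = 0xA3C7;                            // four packed 4-bit quants: 7, C, 3, A (low to high)
float y = 0.5f;

float viaShift = y * ((packed >> 8) & 0xF);        // explicit shift-and-mask: y * 3
float viaMask  = (y / 256f) * (packed & 0x0F00);   // masked form used in the kernels: (y/256) * (3 << 8)
Console.WriteLine($"{viaShift} == {viaMask}");     // prints "1.5 == 1.5"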
// putting them in the kernel cause a significant performance penalty
#define N_DST 4 // each SIMD group works on 4 rows
#define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
 mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);

}
+kernel void kernel_mul_mv_q5_0_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+ device const void * src0,
+ device const float * src1,
+ device float * dst,
+ constant int64_t & ne00,
+ constant int64_t & ne01[[buffer(4)]],
+ constant int64_t & ne02[[buffer(5)]],
+ constant int64_t & ne10[[buffer(9)]],
+ constant int64_t & ne12[[buffer(11)]],
+ constant int64_t & ne0[[buffer(15)]],
+ constant int64_t & ne1[[buffer(16)]],
+ constant uint & gqa[[buffer(17)]],
+ uint3 tgpig[[threadgroup_position_in_grid]],
+ uint tiisg[[thread_index_in_simdgroup]],
+ uint sgitg[[simdgroup_index_in_threadgroup]]) {
+ mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
#define NB_Q8_0 8
kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
}
}
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+ const float d = xb->d;
+ const float md = -16.h * xb->d;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + md;
+ reg[i/2][2*(i%2)+1] = d * x1 + md;
+ }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+ device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+ const float d = xb->d;
+ const float m = xb->m;
+ const ushort mask = il ? 0x00F0 : 0x000F;
+
+ const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+ const int x_mv = il ? 4 : 0;
+
+ const int gh_mv = il ? 12 : 0;
+ const int gh_bk = il ? 0 : 4;
+
+ for (int i = 0; i < 8; i++) {
+ // extract the 5-th bits for x0 and x1
+ const uint8_t xh_0 = ((qh >> (gh_mv + 2*i )) << gh_bk) & 0x10;
+ const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+ // combine the 4-bits from qs with the 5th bit
+ const int32_t x0 = ((((qs[i] ) & mask) >> x_mv) | xh_0);
+ const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+ reg[i/2][2*(i%2)+0] = d * x0 + m;
+ reg[i/2][2*(i%2)+1] = d * x1 + m;
+ }
+}
+
template <typename type4x4>
void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows;
template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows;
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm;
template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm;
diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll
index e5fc7dad5..70bb5b07a 100644
Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ
diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so
index 3532fe998..45bac80be 100644
Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ
diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll
index 89f27e243..7f64e0e38 100644
Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ
diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so
index 81b4aa991..4a1e4380b 100644
Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ
diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll
index 62d071ec8..00b93ba0f 100644
Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
index c2ca7ec8e..3f36bb359 100755
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ
diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so
index b9ef4c1da..5240d696d 100644
Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ