diff --git a/LLama.Examples/LLama.Examples.csproj b/LLama.Examples/LLama.Examples.csproj index a4ebb604e..7fea45623 100644 --- a/LLama.Examples/LLama.Examples.csproj +++ b/LLama.Examples/LLama.Examples.csproj @@ -8,6 +8,7 @@ AnyCPU;x64 true + true diff --git a/LLama.Examples/NewVersion/BatchedDecoding.cs b/LLama.Examples/NewVersion/BatchedDecoding.cs new file mode 100644 index 000000000..ad29703a3 --- /dev/null +++ b/LLama.Examples/NewVersion/BatchedDecoding.cs @@ -0,0 +1,177 @@ +using System.Diagnostics; +using System.Security.Cryptography; +using System.Text; +using LLama.Common; +using LLama.Native; + +namespace LLama.Examples.NewVersion; + +/// +/// This demonstrates generating multiple replies to the same prompt, with a shared cache +/// +/// Note that this is currently using the low level API directly, future work will provide a safer C# wrapper over this! +public class BatchedDecoding +{ + private const int n_parallel = 8; + private const int n_len = 32; + + private const int top_k = 80; + private const float top_p = 0.8f; + private const float temp = 0.5f; + + public static async Task Run() + { + Console.Write("Please input your model path: "); + var modelPath = Console.ReadLine(); + + Console.WriteLine("Prompt (leave blank to select automatically):"); + var prompt = Console.ReadLine(); + if (string.IsNullOrWhiteSpace(prompt)) + prompt = "Not many people know that"; + + // Load model + var parameters = new ModelParams(modelPath); + using var model = LLamaWeights.LoadFromFile(parameters); + + // Tokenize prompt + var prompt_tokens = model.NativeHandle.Tokenize(prompt, true, false, Encoding.UTF8); + var n_kv_req = prompt_tokens.Length + (n_len - prompt_tokens.Length) * n_parallel; + + // Create a context + parameters.ContextSize = (uint)model.ContextSize; + parameters.BatchSize = (uint)Math.Max(n_len, n_parallel); + using var context = model.CreateContext(parameters); + + var n_ctx = context.ContextSize; + + // make sure the KV cache is big enough to hold all the prompt and generated tokens + if (n_kv_req > n_ctx) + { + await Console.Error.WriteLineAsync($"error: n_kv_req ({n_kv_req}) > n_ctx, the required KV cache size is not big enough\n"); + await Console.Error.WriteLineAsync(" either reduce n_parallel or increase n_ctx\n"); + return; + } + + using var batch = LLamaBatchSafeHandle.Create(Math.Max(prompt_tokens.Length, n_parallel), 0, 1); + + // evaluate the initial prompt + for (var i = 0; i < prompt_tokens.Length; i++) + batch.LLamaBatchAdd(prompt_tokens[i], i, new[] { (LLamaSeqId)0 }, false); + Debug.Assert(batch.NativeBatch.n_tokens == prompt_tokens.Length); + + // llama_decode will output logits only for the last token of the prompt + unsafe + { + batch.NativeBatch.logits[batch.NativeBatch.n_tokens - 1] = 1; + } + + if (context.NativeHandle.Decode(batch) != 0) + { + await Console.Error.WriteLineAsync("llama_decode failed"); + return; + } + + // assign the system KV cache to all parallel sequences + // this way, the parallel sequences will "reuse" the prompt tokens without having to copy them + for (var i = 1; i < n_parallel; ++i) + { + NativeApi.llama_kv_cache_seq_cp(context.NativeHandle, (LLamaSeqId)0, (LLamaSeqId)i, 0, batch.NativeBatch.n_tokens); + } + + if (n_parallel > 1) + { + Console.WriteLine(); + Console.WriteLine($"generating {n_parallel} sequences..."); + } + + // remember the batch index of the last token for each parallel sequence + // we need this to determine which logits to sample from + List i_batch = new(); + for (var i = 0; i < n_parallel; i++) + 
i_batch.Add(batch.NativeBatch.n_tokens - 1); + + var n_cur = batch.NativeBatch.n_tokens; + var n_decode = 0; + + var streams = new List[n_parallel]; + for (var i = 0; i < n_parallel; i++) + streams[i] = new(); + + var eos = model.EndOfSentenceToken; + var nl = model.NewlineToken; + + var timer = new Stopwatch(); + timer.Start(); + while (n_cur <= n_len) + { + batch.LLamaBatchClear(); + + for (var i = 0; i < n_parallel; i++) + { + // Skip completed streams + if (i_batch[i] < 0) + continue; + + var n_vocab = model.VocabCount; + LLamaTokenDataArray candidates; + unsafe + { + candidates = LLamaTokenDataArray.Create(new Span(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab)); + } + + candidates.TopK(context.NativeHandle, top_k); + candidates.TopP(context.NativeHandle, top_p); + candidates.Temperature(context.NativeHandle, temp); + var new_token_id = candidates.SampleToken(context.NativeHandle); + + if (new_token_id == eos || new_token_id == nl) + { + i_batch[i] = -1; + Console.WriteLine($"Completed Stream {i} early"); + continue; + } + + streams[i].Add(new_token_id); + + i_batch[i] = batch.NativeBatch.n_tokens; + + // push this new token for next evaluation + batch.LLamaBatchAdd(new_token_id, n_cur, new[] { (LLamaSeqId)i }, true); + + n_decode++; + } + + // all streams are finished + if (batch.NativeBatch.n_tokens == 0) + { + break; + } + + n_cur++; + + // evaluate the current batch with the transformer model + if (context.NativeHandle.Decode(batch) != 0) + { + await Console.Error.WriteLineAsync("failed to eval"); + return; + } + } + + timer.Stop(); + Console.ForegroundColor = ConsoleColor.Yellow; + Console.WriteLine(); + Console.WriteLine($"Decoded {n_decode} tokens in {timer.ElapsedMilliseconds}ms"); + Console.WriteLine($"Rate: {n_decode / timer.Elapsed.TotalSeconds:##.000} tokens/second"); + + var index = 0; + foreach (var stream in streams) + { + var text = context.DeTokenize(stream); + + Console.ForegroundColor = ConsoleColor.Green; + Console.Write($"{index++}. 
{prompt}"); + Console.ForegroundColor = ConsoleColor.Red; + Console.WriteLine(text); + } + } +} \ No newline at end of file diff --git a/LLama.Examples/NewVersion/SemanticKernelChat.cs b/LLama.Examples/NewVersion/SemanticKernelChat.cs index 9fd59058f..4f27bf659 100644 --- a/LLama.Examples/NewVersion/SemanticKernelChat.cs +++ b/LLama.Examples/NewVersion/SemanticKernelChat.cs @@ -14,10 +14,7 @@ public static async Task Run() var modelPath = Console.ReadLine(); // Load weights into memory - var parameters = new ModelParams(modelPath) - { - Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)), - }; + var parameters = new ModelParams(modelPath); using var model = LLamaWeights.LoadFromFile(parameters); using var context = model.CreateContext(parameters); var ex = new InteractiveExecutor(context); diff --git a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs index 52591f918..d744b7630 100644 --- a/LLama.Examples/NewVersion/SemanticKernelPrompt.cs +++ b/LLama.Examples/NewVersion/SemanticKernelPrompt.cs @@ -16,10 +16,7 @@ public static async Task Run() var modelPath = Console.ReadLine(); // Load weights into memory - var parameters = new ModelParams(modelPath) - { - Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) - }; + var parameters = new ModelParams(modelPath); using var model = LLamaWeights.LoadFromFile(parameters); var ex = new StatelessExecutor(model, parameters); diff --git a/LLama.Examples/NewVersion/TalkToYourself.cs b/LLama.Examples/NewVersion/TalkToYourself.cs index 4c412c93f..80d129ca1 100644 --- a/LLama.Examples/NewVersion/TalkToYourself.cs +++ b/LLama.Examples/NewVersion/TalkToYourself.cs @@ -13,10 +13,7 @@ public static async Task Run() var modelPath = Console.ReadLine(); // Load weights into memory - var @params = new ModelParams(modelPath) - { - Seed = unchecked((uint)RandomNumberGenerator.GetInt32(int.MaxValue)) - }; + var @params = new ModelParams(modelPath); using var weights = LLamaWeights.LoadFromFile(@params); // Create 2 contexts sharing the same weights diff --git a/LLama.Examples/NewVersion/TestRunner.cs b/LLama.Examples/NewVersion/TestRunner.cs index d88087bb4..2f698f803 100644 --- a/LLama.Examples/NewVersion/TestRunner.cs +++ b/LLama.Examples/NewVersion/TestRunner.cs @@ -22,6 +22,7 @@ public static async Task Run() Console.WriteLine("12: Semantic Kernel Chat."); Console.WriteLine("13: Semantic Kernel Memory."); Console.WriteLine("14: Coding Assistant."); + Console.WriteLine("15: Batch Decoding."); while (true) { @@ -88,6 +89,10 @@ public static async Task Run() { await CodingAssistant.Run(); } + else if (choice == 15) + { + await BatchedDecoding.Run(); + } else { Console.WriteLine("Cannot parse your choice. Please select again."); diff --git a/LLama.Unittest/StatelessExecutorTest.cs b/LLama.Unittest/StatelessExecutorTest.cs index 19f618af2..195cc4a28 100644 --- a/LLama.Unittest/StatelessExecutorTest.cs +++ b/LLama.Unittest/StatelessExecutorTest.cs @@ -1,3 +1,4 @@ +using System.Diagnostics; using LLama.Common; using Xunit.Abstractions; @@ -34,10 +35,17 @@ public async Task Stateless() const string question = "Question. what is a cat?\nAnswer: "; var @params = new InferenceParams { MaxTokens = 32, AntiPrompts = new[] { "." 
} }; + var timer = new Stopwatch(); + timer.Start(); + var result1 = string.Join("", await executor.InferAsync(question, @params).ToListAsync()); var result2 = string.Join("", await executor.InferAsync(question, @params).ToListAsync()); + timer.Stop(); + _testOutputHelper.WriteLine($"{timer.ElapsedMilliseconds}ms"); + _testOutputHelper.WriteLine(result1); + _testOutputHelper.WriteLine(result2); // Check that it produced the exact same result both times Assert.Equal(result1, result2); diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs index c2420af37..0b38a04d8 100644 --- a/LLama.Web/Common/InferenceOptions.cs +++ b/LLama.Web/Common/InferenceOptions.cs @@ -23,7 +23,7 @@ public class InferenceOptions : IInferenceParams /// /// Sequences where the model will stop generating further tokens. /// - public IEnumerable AntiPrompts { get; set; } = Array.Empty(); + public IReadOnlyList AntiPrompts { get; set; } = Array.Empty(); /// /// path to file for saving/loading model eval state /// diff --git a/LLama.Web/Common/ModelOptions.cs b/LLama.Web/Common/ModelOptions.cs index 20a3e348a..6a63ccc31 100644 --- a/LLama.Web/Common/ModelOptions.cs +++ b/LLama.Web/Common/ModelOptions.cs @@ -111,12 +111,12 @@ public class ModelOptions /// /// RoPE base frequency /// - public float RopeFrequencyBase { get; set; } = 10000.0f; + public float? RopeFrequencyBase { get; set; } /// /// RoPE frequency scaling factor /// - public float RopeFrequencyScale { get; set; } = 1.0f; + public float? RopeFrequencyScale { get; set; } /// /// Use experimental mul_mat_q kernels diff --git a/LLama/Abstractions/IContextParams.cs b/LLama/Abstractions/IContextParams.cs index 201a9b9ad..8ff6d7ccf 100644 --- a/LLama/Abstractions/IContextParams.cs +++ b/LLama/Abstractions/IContextParams.cs @@ -39,14 +39,14 @@ public interface IContextParams bool EmbeddingMode { get; set; } /// - /// RoPE base frequency + /// RoPE base frequency (null to fetch from the model) /// - float RopeFrequencyBase { get; set; } + float? RopeFrequencyBase { get; set; } /// - /// RoPE frequency scaling factor + /// RoPE frequency scaling factor (null to fetch from the model) /// - float RopeFrequencyScale { get; set; } + float? RopeFrequencyScale { get; set; } /// /// Use experimental mul_mat_q kernels diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs index 93a9b52ba..a21c7306b 100644 --- a/LLama/Abstractions/IInferenceParams.cs +++ b/LLama/Abstractions/IInferenceParams.cs @@ -29,7 +29,7 @@ public interface IInferenceParams /// /// Sequences where the model will stop generating further tokens. /// - public IEnumerable AntiPrompts { get; set; } + public IReadOnlyList AntiPrompts { get; set; } /// /// 0 or lower to use vocab size diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index bef64631a..d0217e2f8 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -28,7 +28,7 @@ public record InferenceParams : IInferenceParams /// /// Sequences where the model will stop generating further tokens. 
/// - public IEnumerable AntiPrompts { get; set; } = Array.Empty(); + public IReadOnlyList AntiPrompts { get; set; } = Array.Empty(); /// /// 0 or lower to use vocab size diff --git a/LLama/Common/ModelParams.cs b/LLama/Common/ModelParams.cs index 78f51c6cb..ee5bd3e4c 100644 --- a/LLama/Common/ModelParams.cs +++ b/LLama/Common/ModelParams.cs @@ -91,12 +91,12 @@ public record ModelParams /// /// RoPE base frequency /// - public float RopeFrequencyBase { get; set; } = 10000.0f; + public float? RopeFrequencyBase { get; set; } /// /// RoPE frequency scaling factor /// - public float RopeFrequencyScale { get; set; } = 1.0f; + public float? RopeFrequencyScale { get; set; } /// /// Use experimental mul_mat_q kernels @@ -156,7 +156,7 @@ public ModelParams(string modelPath, uint contextSize = 512, int gpuLayerCount = bool useMemorymap = true, bool useMemoryLock = false, bool perplexity = false, string loraAdapter = "", string loraBase = "", int threads = -1, uint batchSize = 512, bool embeddingMode = false, - float ropeFrequencyBase = 10000.0f, float ropeFrequencyScale = 1f, bool mulMatQ = false, + float? ropeFrequencyBase = null, float? ropeFrequencyScale = null, bool mulMatQ = false, string encoding = "UTF-8") { ContextSize = contextSize; diff --git a/LLama/Extensions/IContextParamsExtensions.cs b/LLama/Extensions/IContextParamsExtensions.cs index 7ca508a2b..fcc9d372a 100644 --- a/LLama/Extensions/IContextParamsExtensions.cs +++ b/LLama/Extensions/IContextParamsExtensions.cs @@ -27,8 +27,8 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo result.f16_kv = @params.UseFp16Memory; result.logits_all = @params.Perplexity; result.embedding = @params.EmbeddingMode; - result.rope_freq_base = @params.RopeFrequencyBase; - result.rope_freq_scale = @params.RopeFrequencyScale; + result.rope_freq_base = @params.RopeFrequencyBase ?? 0; + result.rope_freq_scale = @params.RopeFrequencyScale ?? 0; result.mul_mat_q = @params.MulMatQ; result.n_threads = Threads(@params.Threads); diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index 47240dc71..0ccae76b4 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -235,13 +235,13 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu if (grammar != null) { - SamplingApi.llama_sample_grammar(NativeHandle, candidates, grammar); + candidates.ApplyGrammar(NativeHandle, grammar); } if (temperature <= 0) { // Greedy sampling - id = SamplingApi.llama_sample_token_greedy(NativeHandle, candidates); + id = candidates.SampleTokenGreedy(NativeHandle); } else { @@ -250,32 +250,28 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? 
mirostat_mu if (mirostat == MirostatType.Mirostat) { const int mirostat_m = 100; - SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat(NativeHandle, candidates, mirostatTau, mirostatEta, mirostat_m, ref mu); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleTokenMirostat(NativeHandle, mirostatTau, mirostatEta, mirostat_m, ref mu); } else if (mirostat == MirostatType.Mirostat2) { - SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); - id = SamplingApi.llama_sample_token_mirostat_v2(NativeHandle, candidates, mirostatTau, mirostatEta, ref mu); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleTokenMirostat2(NativeHandle, mirostatTau, mirostatEta, ref mu); } else { - // Temperature sampling - SamplingApi.llama_sample_top_k(NativeHandle, candidates, topK, 1); - SamplingApi.llama_sample_tail_free(NativeHandle, candidates, tfsZ, 1); - SamplingApi.llama_sample_typical(NativeHandle, candidates, typicalP, 1); - SamplingApi.llama_sample_top_p(NativeHandle, candidates, topP, 1); - SamplingApi.llama_sample_temperature(NativeHandle, candidates, temperature); - id = SamplingApi.llama_sample_token(NativeHandle, candidates); + candidates.TopK(NativeHandle, topK); + candidates.TailFree(NativeHandle, tfsZ); + candidates.LocallyTypical(NativeHandle, typicalP); + candidates.TopP(NativeHandle, topP); + candidates.Temperature(NativeHandle, temperature); + id = candidates.SampleToken(NativeHandle); } } mirostat_mu = mu; } - if (grammar != null) - { - NativeApi.llama_grammar_accept_token(NativeHandle, grammar, id); - } + grammar?.AcceptToken(NativeHandle, id); return id; } @@ -305,7 +301,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic } // Save the newline logit value - var nl_token = NativeApi.llama_token_nl(NativeHandle); + var nl_token = NativeApi.llama_token_nl(NativeHandle.ModelHandle); var nl_logit = logits[nl_token]; // Convert logits into token candidates @@ -316,8 +312,7 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable lastTokens, Dic var last_n_array = lastTokens.TakeLast(last_n_repeat).ToArray(); // Apply penalties to candidates - SamplingApi.llama_sample_repetition_penalty(NativeHandle, candidates_p, last_n_array, repeatPenalty); - SamplingApi.llama_sample_frequency_and_presence_penalties(NativeHandle, candidates_p, last_n_array, alphaFrequency, alphaPresence); + candidates_p.RepetitionPenalty(NativeHandle, last_n_array, repeatPenalty, alphaFrequency, alphaPresence); // Restore newline token logit value if necessary if (!penalizeNL) @@ -369,7 +364,7 @@ public int Eval(List tokens, int pastTokensCount) try { tokens.CopyTo(rented, 0); - return Eval(rented, pastTokensCount); + return Eval(rented.AsSpan(0, tokens.Count), pastTokensCount); } finally { diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs index 9e4292ea6..80c6f5420 100644 --- a/LLama/LLamaInstructExecutor.cs +++ b/LLama/LLamaInstructExecutor.cs @@ -163,7 +163,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) } } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) { args.WaitForInput = true; } diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs index d3d4a9e39..5078648b1 100644 --- a/LLama/LLamaInteractExecutor.cs +++ 
b/LLama/LLamaInteractExecutor.cs @@ -30,7 +30,7 @@ public class InteractiveExecutor : StatefulExecutorBase public InteractiveExecutor(LLamaContext context, ILogger? logger = null) : base(context, logger) { - _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle); + _llama_token_newline = NativeApi.llama_token_nl(Context.NativeHandle.ModelHandle); } /// @@ -141,7 +141,7 @@ protected override Task PreprocessInputs(string text, InferStateArgs args) return (true, Array.Empty()); } - if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle)) + if (_embeds.Count > 0 && _embeds.Last() == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) { return (true, new[] { " [end of text]\n" }); } @@ -202,7 +202,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In _last_n_tokens.Enqueue(id); - if (id == NativeApi.llama_token_eos(Context.NativeHandle)) + if (id == NativeApi.llama_token_eos(Context.NativeHandle.ModelHandle)) { id = _llama_token_newline; if (args.Antiprompts is not null && args.Antiprompts.Count > 0) diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs index ab1e9bbc0..4bdeaa3f2 100644 --- a/LLama/LLamaStatelessExecutor.cs +++ b/LLama/LLamaStatelessExecutor.cs @@ -6,7 +6,6 @@ using System.Runtime.CompilerServices; using System.Threading; using System.Threading.Tasks; -using LLama.Extensions; using LLama.Native; using Microsoft.Extensions.Logging; @@ -47,68 +46,66 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger? } /// - public async IAsyncEnumerable InferAsync(string text, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) + public async IAsyncEnumerable InferAsync(string prompt, IInferenceParams? inferenceParams = null, [EnumeratorCancellation] CancellationToken cancellationToken = default) { - using var context = _weights.CreateContext(_params, _logger); - Context = context; - + // Ensure the context from last time is disposed (it always hould be) if (!Context.NativeHandle.IsClosed) Context.Dispose(); - Context = _weights.CreateContext(Context.Params, _logger); - - var decoder = new StreamingTokenDecoder(Context); - var antiprocessor = new AntipromptProcessor(inferenceParams?.AntiPrompts ?? 
Array.Empty()); - - if (inferenceParams != null) - { - if (inferenceParams.TokensKeep > Context.ContextSize) - throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); - } - cancellationToken.ThrowIfCancellationRequested(); + // Create an inference context which will be disposed when this method exits + using var context = _weights.CreateContext(_params, _logger); + Context = context; + // Sanity check inference params inferenceParams ??= new InferenceParams(); + if (inferenceParams.TokensKeep > Context.ContextSize) + throw new ArgumentOutOfRangeException(nameof(inferenceParams), $"TokensKeep ({inferenceParams.TokensKeep}) cannot be larger than ContextSize ({Context.ContextSize})"); + + // Create decoders for the token stream + var decoder = new StreamingTokenDecoder(Context); + var antiprocessor = new AntipromptProcessor(inferenceParams.AntiPrompts); - var lastTokens = new List(inferenceParams.RepeatLastTokensCount); - for (var i = 0; i < inferenceParams.RepeatLastTokensCount; i++) + // Keep track of the last N tokens emitted + var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount); + var lastTokens = new List(repeat_last_n); + for (var i = 0; i < repeat_last_n; i++) lastTokens.Add(0); - var tokens = Context.Tokenize(text).ToList(); + // Tokenize the prompt + var tokens = Context.Tokenize(prompt).ToList(); + lastTokens.AddRange(tokens); + var n_past = 1 + tokens.Count; + // Evaluate the prompt await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken) .ConfigureAwait(false); - lastTokens.AddRange(tokens); - var n_past = 1 + tokens.Count; - + // Begin loop, evaluating one token at a time var mu = (float?)null; var max_tokens = inferenceParams.MaxTokens < 0 ? int.MaxValue : inferenceParams.MaxTokens; - for(var i = 0; i < max_tokens; i++) + for(var i = 0; i < max_tokens && !cancellationToken.IsCancellationRequested; i++) { - if (cancellationToken.IsCancellationRequested) - break; - - var repeat_last_n = inferenceParams.RepeatLastTokensCount < 0 ? 
Context.ContextSize : inferenceParams.RepeatLastTokensCount; - + // Penalize the generated tokens by various penalties var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n, inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL); + // Sample a single token var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau, inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar); - lastTokens.Add(id); - + // Decode this token into text decoder.Add(id); var decoded = decoder.Read(); yield return decoded; - tokens.Clear(); - tokens.Add(id); - // Check if any of the antiprompts have been generated if (antiprocessor.Add(decoded)) break; + lastTokens.Add(id); + tokens.Clear(); + tokens.Add(id); + // when run out of context // based on this logic: https://github.com/ggerganov/llama.cpp/blob/master/examples/main/main.cpp#L497 if (n_past + tokens.Count >= Context.ContextSize) diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 64878e2ab..7ae104a5c 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -38,6 +38,21 @@ public sealed class LLamaWeights /// public ulong ParameterCount => NativeHandle.ParameterCount; + /// + /// Get the newline token for this model + /// + public int NewlineToken => NativeApi.llama_token_nl(NativeHandle); + + /// + /// Get the "end of sentence" token for this model + /// + public int EndOfSentenceToken => NativeApi.llama_token_eos(NativeHandle); + + /// + /// Get the "beginning of sentence" token for this model + /// + public int BeginningOfSentenceToken => NativeApi.llama_token_bos(NativeHandle); + /// /// Dimension of embedding vectors /// diff --git a/LLama/Native/LLamaBatchSafeHandle.cs b/LLama/Native/LLamaBatchSafeHandle.cs index 1f7ef2c53..30e703946 100644 --- a/LLama/Native/LLamaBatchSafeHandle.cs +++ b/LLama/Native/LLamaBatchSafeHandle.cs @@ -15,7 +15,7 @@ public sealed class LLamaBatchSafeHandle /// /// Get the native llama_batch struct /// - public LLamaNativeBatch NativeBatch { get; private set; } + public LLamaNativeBatch NativeBatch; /// /// the token ids of the input (used when embd is NULL) @@ -113,10 +113,11 @@ public LLamaBatchSafeHandle(LLamaNativeBatch batch, int embd) /// /// /// + /// /// - public static LLamaBatchSafeHandle Create(int n_tokens, int embd) + public static LLamaBatchSafeHandle Create(int n_tokens, int embd, int n_seq_max) { - var batch = NativeApi.llama_batch_init(n_tokens, embd); + var batch = NativeApi.llama_batch_init(n_tokens, embd, n_seq_max); return new LLamaBatchSafeHandle(batch, embd); } @@ -128,4 +129,32 @@ protected override bool ReleaseHandle() SetHandle(IntPtr.Zero); return true; } + + /// + /// https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2 + /// + public void LLamaBatchAdd(int token, LLamaPos pos, ReadOnlySpan sequences, bool logits) + { + unsafe + { + NativeBatch.token[NativeBatch.n_tokens] = token; + NativeBatch.pos[NativeBatch.n_tokens] = pos; + NativeBatch.n_seq_id[NativeBatch.n_tokens] = sequences.Length; + + for (var i = 0; i < sequences.Length; i++) + NativeBatch.seq_id[NativeBatch.n_tokens][i] = sequences[i]; + + NativeBatch.logits[NativeBatch.n_tokens] = Convert.ToByte(logits); + + NativeBatch.n_tokens++; + } + } + + /// + /// 
https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L825 + /// + public void LLamaBatchClear() + { + NativeBatch.n_tokens = 0; + } } \ No newline at end of file diff --git a/LLama/Native/LLamaBeamView.cs b/LLama/Native/LLamaBeamView.cs index e6a6c39f5..e6bc504ee 100644 --- a/LLama/Native/LLamaBeamView.cs +++ b/LLama/Native/LLamaBeamView.cs @@ -11,13 +11,13 @@ namespace LLama.Native; [StructLayout(LayoutKind.Sequential)] public struct LLamaBeamView { - private readonly unsafe llama_token* tokens; - private readonly nint n_tokens; + private unsafe llama_token* tokens; + private nint n_tokens; /// /// Cumulative beam probability (renormalized relative to all beams) /// - public readonly float CumulativeProbability; + public float CumulativeProbability; /// /// Callback should set this to true when a beam is at end-of-beam. diff --git a/LLama/Native/LLamaBeamsState.cs b/LLama/Native/LLamaBeamsState.cs index 6f0a447d7..f78c45b97 100644 --- a/LLama/Native/LLamaBeamsState.cs +++ b/LLama/Native/LLamaBeamsState.cs @@ -9,27 +9,27 @@ namespace LLama.Native; /// (e.g. beams[0]) as they will be removed (shifted) from all beams in all subsequent callbacks. /// [StructLayout(LayoutKind.Sequential)] -public readonly struct LLamaBeamsState +public struct LLamaBeamsState { /// /// The state of each individual beam /// - private readonly unsafe LLamaBeamView* beam_views; + private unsafe LLamaBeamView* beam_views; /// /// Number of elements in beam_views /// - private readonly nint n_beams; + private nint n_beams; /// /// Current max length of prefix tokens shared by all beams. /// - public readonly ulong CommonPrefixLength; + public ulong CommonPrefixLength; /// /// True iff this is the last callback invocation. /// - public readonly bool LastCall; + public bool LastCall; /// /// The current state of each beam diff --git a/LLama/Native/LLamaGrammarElement.cs b/LLama/Native/LLamaGrammarElement.cs index 5a7f27bb2..96313f239 100644 --- a/LLama/Native/LLamaGrammarElement.cs +++ b/LLama/Native/LLamaGrammarElement.cs @@ -52,18 +52,18 @@ public enum LLamaGrammarElementType /// [StructLayout(LayoutKind.Sequential)] [DebuggerDisplay("{Type} {Value}")] - public readonly struct LLamaGrammarElement + public struct LLamaGrammarElement : IEquatable { /// /// The type of this element /// - public readonly LLamaGrammarElementType Type; + public LLamaGrammarElementType Type; /// /// Unicode code point or rule ID /// - public readonly uint Value; + public uint Value; /// /// Construct a new LLamaGrammarElement diff --git a/LLama/Native/LLamaModelQuantizeParams.cs b/LLama/Native/LLamaModelQuantizeParams.cs index 4d0b9e7fa..8f10b2ff3 100644 --- a/LLama/Native/LLamaModelQuantizeParams.cs +++ b/LLama/Native/LLamaModelQuantizeParams.cs @@ -1,10 +1,12 @@ using System; +using System.Runtime.InteropServices; namespace LLama.Native { /// /// Quantizer parameters used in the native API /// + [StructLayout(LayoutKind.Sequential)] public struct LLamaModelQuantizeParams { /// diff --git a/LLama/Native/LLamaNativeBatch.cs b/LLama/Native/LLamaNativeBatch.cs index 576f8b279..d46f8b99e 100644 --- a/LLama/Native/LLamaNativeBatch.cs +++ b/LLama/Native/LLamaNativeBatch.cs @@ -11,35 +11,47 @@ namespace LLama.Native; /// The provided arrays (i.e. token, embd, pos, etc.) must have size of n_tokens /// [StructLayout(LayoutKind.Sequential)] -public readonly unsafe struct LLamaNativeBatch +public unsafe struct LLamaNativeBatch { /// /// The number of items pointed at by pos, seq_id and logits. 
/// - public readonly int n_tokens; + public int n_tokens; /// /// Either `n_tokens` of `llama_token`, or `NULL`, depending on how this batch was created /// - public readonly llama_token* token; + public llama_token* token; /// /// Either `n_tokens * embd * sizeof(float)` or `NULL`, depending on how this batch was created /// - public readonly float* embd; + public float* embd; /// /// the positions of the respective token in the sequence /// - public readonly LLamaPos* pos; + public LLamaPos* pos; + + /// + /// https://github.com/ggerganov/llama.cpp/blob/master/llama.h#L139 ??? + /// + public int* n_seq_id; /// /// the sequence to which the respective token belongs /// - public readonly LLamaSeqId* seq_id; + public LLamaSeqId** seq_id; /// /// if zero, the logits for the respective token will not be output /// - public readonly byte* logits; + public byte* logits; + + // Note from llama.cpp: + // > helpers for smooth API transition - can be deprecated in the future + // > for future-proof code, use the above fields instead and ignore everything below + private LLamaPos _all_pos_0; + private LLamaPos _all_pos_1; + private LLamaSeqId _all_seq_id; } \ No newline at end of file diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs index 4deae57b6..67ede7d52 100644 --- a/LLama/Native/LLamaPos.cs +++ b/LLama/Native/LLamaPos.cs @@ -1,14 +1,26 @@ -namespace LLama.Native; +using System.Runtime.InteropServices; + +namespace LLama.Native; /// /// Indicates position in a sequence /// -public readonly record struct LLamaPos(int Value) +[StructLayout(LayoutKind.Sequential)] +public struct LLamaPos { /// /// The raw value /// - public readonly int Value = Value; + public int Value; + + /// + /// Create a new LLamaPos + /// + /// + public LLamaPos(int value) + { + Value = value; + } /// /// Convert a LLamaPos into an integer (extract the raw value) diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs index 4a665a2c3..191a6b5ec 100644 --- a/LLama/Native/LLamaSeqId.cs +++ b/LLama/Native/LLamaSeqId.cs @@ -1,15 +1,26 @@ -namespace LLama.Native; +using System.Runtime.InteropServices; + +namespace LLama.Native; /// /// ID for a sequence in a batch /// -/// -public record struct LLamaSeqId(int Value) +[StructLayout(LayoutKind.Sequential)] +public struct LLamaSeqId { /// /// The raw value /// - public int Value = Value; + public int Value; + + /// + /// Create a new LLamaSeqId + /// + /// + public LLamaSeqId(int value) + { + Value = value; + } /// /// Convert a LLamaSeqId into an integer (extract the raw value) diff --git a/LLama/Native/LLamaTokenData.cs b/LLama/Native/LLamaTokenData.cs index 1ea6820dd..45edd4542 100644 --- a/LLama/Native/LLamaTokenData.cs +++ b/LLama/Native/LLamaTokenData.cs @@ -5,24 +5,34 @@ namespace LLama.Native; /// /// A single token along with probability of this token being selected /// -/// -/// -/// [StructLayout(LayoutKind.Sequential)] -public record struct LLamaTokenData(int id, float logit, float p) +public struct LLamaTokenData { /// /// token id /// - public int id = id; + public int id; /// /// log-odds of the token /// - public float logit = logit; + public float logit; /// /// probability of the token /// - public float p = p; + public float p; + + /// + /// Create a new LLamaTokenData + /// + /// + /// + /// + public LLamaTokenData(int id, float logit, float p) + { + this.id = id; + this.logit = logit; + this.p = p; + } } \ No newline at end of file diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index 
7a2965ed8..8f20a73ac 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -45,6 +45,199 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits) return new LLamaTokenDataArray(candidates); } + + #region sampling + /// + /// Apply grammar rules to candidate tokens + /// + /// + /// + public void ApplyGrammar(SafeLLamaContextHandle ctx, SafeLLamaGrammarHandle grammar) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_grammar(ctx, ref st, grammar); + sorted = st.sorted; + } + } + + /// + /// Top-K sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// + /// + /// Number of tokens to keep + /// Minimum number to keep + public void TopK(SafeLLamaContextHandle context, int k, ulong minKeep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_top_k(context, ref st, k, minKeep); + sorted = st.sorted; + } + } + + /// + /// Nucleus sampling described in academic paper "The Curious Case of Neural Text Degeneration" https://arxiv.org/abs/1904.09751 + /// + /// + /// + /// + public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_top_p(context, ref st, p, minKeep); + sorted = st.sorted; + } + } + + /// + /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/. + /// + /// + /// + /// + public void TailFree(SafeLLamaContextHandle context, float z, ulong min_keep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_tail_free(context, ref st, z, min_keep); + sorted = st.sorted; + } + } + + /// + /// Locally Typical Sampling implementation described in the paper https://arxiv.org/abs/2202.00666. + /// + /// + /// + /// + public void LocallyTypical(SafeLLamaContextHandle context, float p, ulong min_keep = 1) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_typical(context, ref st, p, min_keep); + sorted = st.sorted; + } + } + + /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// + /// + /// + /// + /// + /// + public void RepetitionPenalty(SafeLLamaContextHandle context, Memory last_tokens, float penalty_repeat, float penalty_freq, float penalty_present) + { + unsafe + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + using (var last_tokens_handle = last_tokens.Pin()) + { + NativeApi.llama_sample_repetition_penalties(context, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty_repeat, penalty_freq, penalty_present); + sorted = st.sorted; + } + } + } + + /// + /// Sample with temperature. + /// As temperature increases, the prediction becomes more diverse but also vulnerable to hallucinations -- generating tokens that are sensible but not factual + /// + /// + /// + public void Temperature(SafeLLamaContextHandle context, float temp) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_temperature(context, ref st, temp); + sorted = st.sorted; + } + } + + /// + /// Sorts candidate tokens by their logits in descending order and calculate probabilities based on logits. 
+ /// + /// + public void Softmax(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + NativeApi.llama_sample_softmax(context, ref st); + sorted = st.sorted; + } + } + + /// + /// Randomly selects a token from the candidates based on their probabilities. + /// + /// + /// + public int SampleToken(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token(context, ref st); + sorted = st.sorted; + return token; + } + } + + /// + /// Selects the token with the highest probability. + /// + /// + /// + public int SampleTokenGreedy(SafeLLamaContextHandle context) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_greedy(context, ref st); + sorted = st.sorted; + return token; + } + } + + /// + /// Mirostat 1.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. + /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. + /// + public int SampleTokenMirostat(SafeLLamaContextHandle context, float tau, float eta, int m, ref float mu) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_mirostat(context, ref st, tau, eta, m, ref mu); + sorted = st.sorted; + return token; + } + } + + /// + /// Mirostat 2.0 algorithm described in the paper https://arxiv.org/abs/2007.14966. Uses tokens instead of words. + /// + /// + /// The target cross-entropy (or surprise) value you want to achieve for the generated text. A higher value corresponds to more surprising or less predictable text, while a lower value corresponds to less surprising or more predictable text. + /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. + /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. 
+ /// + public int SampleTokenMirostat2(SafeLLamaContextHandle context, float tau, float eta, ref float mu) + { + using (LLamaTokenDataArrayNative.Create(this, out var st)) + { + var token = NativeApi.llama_sample_token_mirostat_v2(context, ref st, tau, eta, ref mu); + sorted = st.sorted; + return token; + } + } + #endregion } /// diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 80e682cfe..e7ee32ba1 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -9,26 +9,22 @@ public unsafe partial class NativeApi { /// /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty); - - /// /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// /// /// Pointer to LLamaTokenDataArray /// /// - /// - /// + /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. + /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence); + public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, + ref LLamaTokenDataArrayNative candidates, + llama_token* last_tokens, ulong last_tokens_size, + float penalty_repeat, + float penalty_freq, + float penalty_present); /// /// Apply classifier-free guidance to the logits as described in academic paper "Stay on topic with Classifier-Free Guidance" https://arxiv.org/abs/2306.17806 diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 41f9ee670..e3b182bd4 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -348,9 +348,10 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// Logits for the ith token. 
Equivalent to: llama_get_logits(ctx) + i*n_vocab /// /// + /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx); + public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); /// /// Get the embeddings for the input @@ -366,21 +367,21 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_bos(SafeLLamaContextHandle ctx); + public static extern llama_token llama_token_bos(SafeLlamaModelHandle model); /// /// Get the "End of sentence" token /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_eos(SafeLLamaContextHandle ctx); + public static extern llama_token llama_token_eos(SafeLlamaModelHandle model); /// /// Get the "new line" token /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern llama_token llama_token_nl(SafeLLamaContextHandle ctx); + public static extern llama_token llama_token_nl(SafeLlamaModelHandle model); /// /// Print out timing information for this context @@ -530,6 +531,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// Allocates a batch of tokens on the heap + /// Each token can be assigned up to n_seq_max sequence ids /// The batch has to be freed with llama_batch_free() /// If embd != 0, llama_batch.embd will be allocated with size of n_tokens * embd * sizeof(float) /// Otherwise, llama_batch.token will be allocated to store n_tokens llama_token @@ -538,8 +540,9 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// /// + /// Each token can be assigned up to n_seq_max sequence ids [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd); + public static extern LLamaNativeBatch llama_batch_init(int n_tokens, int embd, int n_seq_max); /// /// Frees a batch of tokens allocated with llama_batch_init() diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index 7fb5edf74..59a5bfd9e 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -114,6 +114,22 @@ public Span GetLogits() } } + /// + /// Logits for the ith token. 
Equivalent to: llama_get_logits(ctx) + i*n_vocab + /// + /// + /// + public Span GetLogitsIth(int i) + { + var model = ThrowIfDisposed(); + + unsafe + { + var logits = NativeApi.llama_get_logits_ith(this, i); + return new Span(logits, model.VocabCount); + } + } + #region tokens /// /// Convert the given text into tokens diff --git a/LLama/Native/SafeLLamaGrammarHandle.cs b/LLama/Native/SafeLLamaGrammarHandle.cs index ed1c15c83..f430b7c3f 100644 --- a/LLama/Native/SafeLLamaGrammarHandle.cs +++ b/LLama/Native/SafeLLamaGrammarHandle.cs @@ -102,5 +102,15 @@ public static unsafe SafeLLamaGrammarHandle Create(LLamaGrammarElement** rules, return new(grammar_ptr); } #endregion + + /// + /// Accepts the sampled token into the grammar + /// + /// + /// + public void AcceptToken(SafeLLamaContextHandle ctx, int token) + { + NativeApi.llama_grammar_accept_token(ctx, this, token); + } } } diff --git a/LLama/Native/SamplingApi.cs b/LLama/Native/SamplingApi.cs index e26bf971a..41709def3 100644 --- a/LLama/Native/SamplingApi.cs +++ b/LLama/Native/SamplingApi.cs @@ -9,7 +9,7 @@ namespace LLama.Native /// /// Direct translation of the llama.cpp sampling API /// - public unsafe class SamplingApi + public class SamplingApi { /// /// Apply grammar rules to candidate tokens @@ -17,70 +17,10 @@ public unsafe class SamplingApi /// /// /// + [Obsolete("use LLamaTokenDataArray ApplyGrammar method")] public static void llama_sample_grammar(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, SafeLLamaGrammarHandle grammar) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_grammar(ctx, ref st, grammar); - } - - /// - /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - [Obsolete("last_tokens_size parameter is no longer needed")] - public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float penalty) - { - llama_sample_repetition_penalty(ctx, candidates, last_tokens, penalty); - } - - /// - /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - public static void llama_sample_repetition_penalty(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float penalty) - { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - using var last_tokens_handle = last_tokens.Pin(); - - NativeApi.llama_sample_repetition_penalty(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, penalty); - } - - /// - /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. - /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - /// - [Obsolete("last_tokens_size parameter is no longer needed")] - public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, ulong last_tokens_size, float alpha_frequency, float alpha_presence) - { - llama_sample_frequency_and_presence_penalties(ctx, candidates, last_tokens, alpha_frequency, alpha_presence); - } - - /// - /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. 
- /// - /// - /// Pointer to LLamaTokenDataArray - /// - /// - /// - public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, Memory last_tokens, float alpha_frequency, float alpha_presence) - { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - using var last_tokens_handle = last_tokens.Pin(); - - NativeApi.llama_sample_frequency_and_presence_penalties(ctx, ref st, (int*)last_tokens_handle.Pointer, (ulong)last_tokens.Length, alpha_frequency, alpha_presence); + candidates.ApplyGrammar(ctx, grammar); } /// @@ -88,10 +28,10 @@ public static void llama_sample_frequency_and_presence_penalties(SafeLLamaContex /// /// /// Pointer to LLamaTokenDataArray + [Obsolete("use LLamaTokenDataArray Softmax method")] public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_softmax(ctx, ref st); + candidates.Softmax(ctx); } /// @@ -101,10 +41,10 @@ public static void llama_sample_softmax(SafeLLamaContextHandle ctx, LLamaTokenDa /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TopK method")] public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, int k, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_top_k(ctx, ref st, k, min_keep); + candidates.TopK(ctx, k, min_keep); } /// @@ -114,10 +54,10 @@ public static void llama_sample_top_k(SafeLLamaContextHandle ctx, LLamaTokenData /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TopP method")] public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_top_p(ctx, ref st, p, min_keep); + candidates.TopP(ctx, p, min_keep); } /// @@ -127,10 +67,10 @@ public static void llama_sample_top_p(SafeLLamaContextHandle ctx, LLamaTokenData /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray TailFree method")] public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float z, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_tail_free(ctx, ref st, z, min_keep); + candidates.TailFree(ctx, z, min_keep); } /// @@ -140,10 +80,10 @@ public static void llama_sample_tail_free(SafeLLamaContextHandle ctx, LLamaToken /// Pointer to LLamaTokenDataArray /// /// + [Obsolete("use LLamaTokenDataArray LocallyTypical method")] public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float p, ulong min_keep) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_typical(ctx, ref st, p, min_keep); + candidates.LocallyTypical(ctx, p, min_keep); } /// @@ -153,10 +93,10 @@ public static void llama_sample_typical(SafeLLamaContextHandle ctx, LLamaTokenDa /// /// /// + [Obsolete("use LLamaTokenDataArray Temperature() method")] public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float temp) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - NativeApi.llama_sample_temperature(ctx, ref st, temp); + 
candidates.Temperature(ctx, temp); } /// @@ -169,10 +109,10 @@ public static void llama_sample_temperature(SafeLLamaContextHandle ctx, LLamaTok /// The number of tokens considered in the estimation of `s_hat`. This is an arbitrary value that is used to calculate `s_hat`, which in turn helps to calculate the value of `k`. In the paper, they use `m = 100`, but you can experiment with different values to see how it affects the performance of the algorithm. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// + [Obsolete("use LLamaTokenDataArray SampleTokenMirostat() method")] public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, int m, ref float mu) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_mirostat(ctx, ref st, tau, eta, m, ref mu); + return candidates.SampleTokenMirostat(ctx, tau, eta, m, ref mu); } /// @@ -184,10 +124,10 @@ public static llama_token llama_sample_token_mirostat(SafeLLamaContextHandle ctx /// The learning rate used to update `mu` based on the error between the target and observed surprisal of the sampled word. A larger learning rate will cause `mu` to be updated more quickly, while a smaller learning rate will result in slower updates. /// Maximum cross-entropy. This value is initialized to be twice the target cross-entropy (`2 * tau`) and is updated in the algorithm based on the error between the target and observed surprisal. /// + [Obsolete("use LLamaTokenDataArray SampleTokenMirostat2() method")] public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates, float tau, float eta, ref float mu) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_mirostat_v2(ctx, ref st, tau, eta, ref mu); + return candidates.SampleTokenMirostat2(ctx, tau, eta, ref mu); } /// @@ -196,10 +136,10 @@ public static llama_token llama_sample_token_mirostat_v2(SafeLLamaContextHandle /// /// Pointer to LLamaTokenDataArray /// + [Obsolete("Use LLamaTokenDataArray SampleTokenGreedy() method")] public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token_greedy(ctx, ref st); + return candidates.SampleTokenGreedy(ctx); } /// @@ -208,10 +148,10 @@ public static llama_token llama_sample_token_greedy(SafeLLamaContextHandle ctx, /// /// Pointer to LLamaTokenDataArray /// + [Obsolete("use LLamaTokenDataArray SampleToken() method")] public static llama_token llama_sample_token(SafeLLamaContextHandle ctx, LLamaTokenDataArray candidates) { - using var handle = LLamaTokenDataArrayNative.Create(candidates, out var st); - return NativeApi.llama_sample_token(ctx, ref st); + return candidates.SampleToken(ctx); } } } diff --git a/LLama/runtimes/ggml-metal.metal b/LLama/runtimes/ggml-metal.metal index 99b9fd7a7..f4b460564 100644 --- a/LLama/runtimes/ggml-metal.metal +++ b/LLama/runtimes/ggml-metal.metal @@ -18,6 +18,21 @@ typedef struct { uint8_t qs[QK4_1 / 2]; // nibbles / quants } block_q4_1; +#define QK5_0 32 +typedef struct { + half d; // delta + uint8_t qh[4]; // 5-th bit of quants + uint8_t qs[QK5_0 / 2]; // nibbles / quants 
+} block_q5_0;
+
+#define QK5_1 32
+typedef struct {
+    half d;                // delta
+    half m;                // min
+    uint8_t qh[4];         // 5-th bit of quants
+    uint8_t qs[QK5_1 / 2]; // nibbles / quants
+} block_q5_1;
+
 #define QK8_0 32
 typedef struct {
     half d;            // delta
@@ -110,9 +125,17 @@ kernel void kernel_mul_row(
 }
 
 kernel void kernel_scale(
+        device const float * src0,
+        device float * dst,
+        constant float & scale,
+        uint tpig[[thread_position_in_grid]]) {
+    dst[tpig] = src0[tpig] * scale;
+}
+
+kernel void kernel_scale_4(
         device const float4 * src0,
         device float4 * dst,
-        constant float & scale,
+        constant float & scale,
         uint tpig[[thread_position_in_grid]]) {
     dst[tpig] = src0[tpig] * scale;
 }
@@ -399,8 +422,11 @@ kernel void kernel_rms_norm(
 // that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
 inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
+
     float2 acc = 0.f;
+
     device const uint16_t * qs = ((device const uint16_t *)qb_curr + 1 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -417,8 +443,11 @@ inline float block_q_n_dot_y(device const block_q4_0 * qb_curr, float sumy, thre
 inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thread float * yl, int il) {
     float d = qb_curr->d;
     float m = qb_curr->m;
-    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 2 + il/2);
+
     for (int i = 0; i < 8; i+=2) {
         acc[0] += yl[i + 0] * (qs[i / 2] & 0x000F)
                 + yl[i + 1] * (qs[i / 2] & 0x0F00);
@@ -428,6 +457,49 @@ inline float block_q_n_dot_y(device const block_q4_1 * qb_curr, float sumy, thre
     return d * (acc[0] + acc[1]) + sumy * m;
 }
 
+// function for calculate inner product between half a q5_0 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_0/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_0 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 3 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (sumy * -16.f + acc[0] + acc[1]);
+}
+
+// function for calculate inner product between half a q5_1 block and 16 floats (yl), sumy is SUM(yl[i])
+// il indicates where the q5 quants begin (0 or QK5_1/4)
+// we assume that the yl's have been multiplied with the appropriate scale factor
+// that corresponds to the missing bit shifts (1, 1/16, 1/256, 1/4096)
+inline float block_q_n_dot_y(device const block_q5_1 * qb_curr, float sumy, thread float * yl, int il) {
+    float d = qb_curr->d;
+    float m = qb_curr->m;
+
+    float2 acc = 0.f;
+
+    device const uint16_t * qs = ((device const uint16_t *)qb_curr + 4 + il/2);
+    const uint32_t qh = *((device const uint32_t *)qb_curr->qh);
+
+    for (int i = 0; i < 8; i+=2) {
+        acc[0] += yl[i + 0] * ((qs[i / 2] & 0x000F) | ((qh >> (i+0+il        ) << 4 ) & 0x00010))
+                + yl[i + 1] * ((qs[i / 2] & 0x0F00) | ((qh >> (i+1+il        ) << 12) & 0x01000));
+        acc[1] += yl[i + 8] * ((qs[i / 2] & 0x00F0) | ((qh >> (i+0+il+QK5_0/2) << 8 ) & 0x00100))
+                + yl[i + 9] * ((qs[i / 2] & 0xF000) | ((qh >> (i+1+il+QK5_0/2) << 16) & 0x10000));
+    }
+    return d * (acc[0] + acc[1]) + sumy * m;
+}
+
 // putting them in the kernel cause a significant performance penalty
 #define N_DST 4 // each SIMD group works on 4 rows
 #define N_SIMDGROUP 2 // number of SIMD groups in a thread group
@@ -525,6 +597,43 @@ kernel void kernel_mul_mv_q4_1_f32(
     mul_vec_q_n_f32<block_q4_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
 }
 
+kernel void kernel_mul_mv_q5_0_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_0, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+kernel void kernel_mul_mv_q5_1_f32(
+        device const void * src0,
+        device const float * src1,
+        device float * dst,
+        constant int64_t & ne00,
+        constant int64_t & ne01[[buffer(4)]],
+        constant int64_t & ne02[[buffer(5)]],
+        constant int64_t & ne10[[buffer(9)]],
+        constant int64_t & ne12[[buffer(11)]],
+        constant int64_t & ne0[[buffer(15)]],
+        constant int64_t & ne1[[buffer(16)]],
+        constant uint & gqa[[buffer(17)]],
+        uint3 tgpig[[threadgroup_position_in_grid]],
+        uint tiisg[[thread_index_in_simdgroup]],
+        uint sgitg[[simdgroup_index_in_threadgroup]]) {
+    mul_vec_q_n_f32<block_q5_1, N_DST, N_SIMDGROUP, N_SIMDWIDTH>(src0,src1,dst,ne00,ne01,ne02,ne10,ne12,ne0,ne1,gqa,tgpig,tiisg,sgitg);
+}
+
+
 #define NB_Q8_0 8
 
 kernel void kernel_mul_mv_q8_0_f32(
@@ -2149,6 +2258,62 @@ void dequantize_q4_1(device const block_q4_1 *xb, short il, thread type4x4 & reg
     }
 }
 
+template <typename type4x4>
+void dequantize_q5_0(device const block_q5_0 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 3);
+    const float d = xb->d;
+    const float md = -16.h * xb->d;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + md;
+        reg[i/2][2*(i%2)+1] = d * x1 + md;
+    }
+}
+
+template <typename type4x4>
+void dequantize_q5_1(device const block_q5_1 *xb, short il, thread type4x4 & reg) {
+    device const uint16_t * qs = ((device const uint16_t *)xb + 4);
+    const float d = xb->d;
+    const float m = xb->m;
+    const ushort mask = il ? 0x00F0 : 0x000F;
+
+    const uint32_t qh = *((device const uint32_t *)xb->qh);
+
+    const int x_mv = il ? 4 : 0;
+
+    const int gh_mv = il ? 12 : 0;
+    const int gh_bk = il ?  0 : 4;
+
+    for (int i = 0; i < 8; i++) {
+        // extract the 5-th bits for x0 and x1
+        const uint8_t xh_0 = ((qh >> (gh_mv + 2*i  )) << gh_bk) & 0x10;
+        const uint8_t xh_1 = ((qh >> (gh_mv + 2*i+1)) << gh_bk) & 0x10;
+
+        // combine the 4-bits from qs with the 5th bit
+        const int32_t x0 = ((((qs[i]     ) & mask) >> x_mv) | xh_0);
+        const int32_t x1 = ((((qs[i] >> 8) & mask) >> x_mv) | xh_1);
+
+        reg[i/2][2*(i%2)+0] = d * x0 + m;
+        reg[i/2][2*(i%2)+1] = d * x1 + m;
+    }
+}
+
 template <typename type4x4>
 void dequantize_q8_0(device const block_q8_0 *xb, short il, thread type4x4 & reg) {
     device const int8_t * qs = ((device const int8_t *)xb->qs);
@@ -2490,6 +2655,8 @@ template [[host_name("kernel_get_rows_f32")]] kernel get_rows_t kernel_get_rows
 template [[host_name("kernel_get_rows_f16")]] kernel get_rows_t kernel_get_rows<half4x4, 1, dequantize_f16>;
 template [[host_name("kernel_get_rows_q4_0")]] kernel get_rows_t kernel_get_rows<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_get_rows_q4_1")]] kernel get_rows_t kernel_get_rows<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_get_rows_q5_0")]] kernel get_rows_t kernel_get_rows<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_get_rows_q5_1")]] kernel get_rows_t kernel_get_rows<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_get_rows_q8_0")]] kernel get_rows_t kernel_get_rows<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_get_rows_q2_K")]] kernel get_rows_t kernel_get_rows<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_get_rows_q3_K")]] kernel get_rows_t kernel_get_rows<block_q3_K, QK_NL, dequantize_q3_K>;
@@ -2518,6 +2685,8 @@ template [[host_name("kernel_mul_mm_f32_f32")]] kernel mat_mm_t kernel_mul_mm;
 template [[host_name("kernel_mul_mm_q4_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_0, 2, dequantize_q4_0>;
 template [[host_name("kernel_mul_mm_q4_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q4_1, 2, dequantize_q4_1>;
+template [[host_name("kernel_mul_mm_q5_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_0, 2, dequantize_q5_0>;
+template [[host_name("kernel_mul_mm_q5_1_f32")]] kernel mat_mm_t kernel_mul_mm<block_q5_1, 2, dequantize_q5_1>;
 template [[host_name("kernel_mul_mm_q8_0_f32")]] kernel mat_mm_t kernel_mul_mm<block_q8_0, 2, dequantize_q8_0>;
 template [[host_name("kernel_mul_mm_q2_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q2_K, QK_NL, dequantize_q2_K>;
 template [[host_name("kernel_mul_mm_q3_K_f32")]] kernel mat_mm_t kernel_mul_mm<block_q3_K, QK_NL, dequantize_q3_K>;
diff --git a/LLama/runtimes/libllama-cuda11.dll b/LLama/runtimes/libllama-cuda11.dll
index e5fc7dad5..70bb5b07a 100644
Binary files a/LLama/runtimes/libllama-cuda11.dll and b/LLama/runtimes/libllama-cuda11.dll differ
diff --git a/LLama/runtimes/libllama-cuda11.so b/LLama/runtimes/libllama-cuda11.so
index 3532fe998..45bac80be 100644
Binary files a/LLama/runtimes/libllama-cuda11.so and b/LLama/runtimes/libllama-cuda11.so differ
diff --git a/LLama/runtimes/libllama-cuda12.dll b/LLama/runtimes/libllama-cuda12.dll
index 89f27e243..7f64e0e38 100644
Binary files a/LLama/runtimes/libllama-cuda12.dll and b/LLama/runtimes/libllama-cuda12.dll differ
diff --git a/LLama/runtimes/libllama-cuda12.so b/LLama/runtimes/libllama-cuda12.so
index 81b4aa991..4a1e4380b 100644
Binary files a/LLama/runtimes/libllama-cuda12.so and b/LLama/runtimes/libllama-cuda12.so differ
diff --git a/LLama/runtimes/libllama.dll b/LLama/runtimes/libllama.dll
index 62d071ec8..00b93ba0f 100644
Binary files a/LLama/runtimes/libllama.dll and b/LLama/runtimes/libllama.dll differ
diff --git a/LLama/runtimes/libllama.dylib b/LLama/runtimes/libllama.dylib
index c2ca7ec8e..3f36bb359 100755
Binary files a/LLama/runtimes/libllama.dylib and b/LLama/runtimes/libllama.dylib differ
diff --git a/LLama/runtimes/libllama.so b/LLama/runtimes/libllama.so
index b9ef4c1da..5240d696d 100644
Binary files a/LLama/runtimes/libllama.so and b/LLama/runtimes/libllama.so differ
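All of the new Q5 kernels above decode the same bit layout: a 32-element Q5_0 block stores a half-precision scale d, the 32 "missing" 5th bits packed into the four qh bytes, and 16 bytes of 4-bit nibbles in qs (Q5_1 additionally stores a half-precision min m). As a reading aid, here is a minimal CPU-side sketch in C# of that unpacking for a single Q5_0 block. It is illustrative only; the class and method names are invented for this example and are not part of this change or of the LLamaSharp API, but the nibble-plus-5th-bit combination mirrors dequantize_q5_0 above, where the decoded value is d * (q - 16) (Q5_1 would instead use d * q + m).

using System;

// Illustrative CPU-side reference for the Q5_0 layout decoded by the Metal kernels above.
// Not part of this PR; the field order mirrors block_q5_0 (half d, uint8 qh[4], uint8 qs[16]).
public static class Q5Reference
{
    public const int QK5_0 = 32;

    // block: 2 bytes fp16 delta, 4 bytes of packed 5th bits, 16 bytes of nibbles (22 bytes total)
    public static float[] DequantizeQ5_0Block(ReadOnlySpan<byte> block)
    {
        float d = (float)BitConverter.ToHalf(block.Slice(0, 2));
        uint qh = BitConverter.ToUInt32(block.Slice(2, 4));
        ReadOnlySpan<byte> qs = block.Slice(6, QK5_0 / 2);

        var values = new float[QK5_0];
        for (int i = 0; i < QK5_0 / 2; i++)
        {
            // low nibble -> element i, high nibble -> element i + 16,
            // each combined with its 5th bit taken from qh
            int x0 = (qs[i] & 0x0F) | (int)(((qh >> i) & 1) << 4);
            int x1 = (qs[i] >> 4)   | (int)(((qh >> (i + 16)) & 1) << 4);

            // Q5_0 is symmetric: value = d * (q - 16); Q5_1 would use d * q + m instead
            values[i]      = d * (x0 - 16);
            values[i + 16] = d * (x1 - 16);
        }
        return values;
    }
}

The same layout explains the pointer arithmetic in the Metal code: qs is read as uint16_t starting at offset 3 for Q5_0 (2-byte d plus 4-byte qh is 6 bytes, i.e. three 16-bit units) and at offset 4 for Q5_1 (d, m and qh together are 8 bytes).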
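One detail of block_q_n_dot_y for Q5_0 worth calling out: the per-quant offset of -16 is never applied element by element. Because the offset is identical for every quant in the block, it can be pulled out of the sum and folded into the precomputed activation sum sumy, which is exactly why the function returns d * (sumy * -16.f + acc[0] + acc[1]). The small C# check below (illustrative only, names invented for this example, not code from this change) confirms that identity numerically; the other trick in the kernel, pre-scaling yl by 1, 1/16, 1/256 and 1/4096 to stand in for the nibble shifts, is separate from it.

using System;

// Sanity check for the folded offset used by block_q_n_dot_y (illustrative only):
//   sum_i d * (q_i - 16) * y_i  ==  d * (sum_i q_i * y_i - 16 * sum_i y_i)
public static class Q5DotCheck
{
    public static void Main()
    {
        var rng = new Random(1234);
        int n = 32;
        float d = 0.037f;

        var q = new int[n];    // unsigned 5-bit quants, 0..31
        var y = new float[n];  // activations
        for (int i = 0; i < n; i++)
        {
            q[i] = rng.Next(0, 32);
            y[i] = (float)(rng.NextDouble() * 2.0 - 1.0);
        }

        float direct = 0f, sumQy = 0f, sumY = 0f;
        for (int i = 0; i < n; i++)
        {
            direct += d * (q[i] - 16) * y[i]; // offset applied per element
            sumQy  += q[i] * y[i];
            sumY   += y[i];
        }
        float folded = d * (sumQy - 16f * sumY); // offset folded into the activation sum, as in the kernel

        Console.WriteLine($"direct = {direct}, folded = {folded}");
    }
}

Run as a console program, the two values agree up to floating-point rounding.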