From f860f88c367d00a734253048d24a0395b1e6e512 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Tue, 2 Jan 2024 03:20:21 +0000 Subject: [PATCH] Code cleanup driven by R# suggestions: - Made `NativeApi` into a `static class` (it's not intended to be instantiated) - Moved `LLamaTokenType` enum out into a separate file - Made `LLamaSeqId` and `LLamaPos` into `record struct`, convenient to have equality etc --- LLama/Abstractions/IModelParams.cs | 2 +- LLama/Common/ChatHistory.cs | 9 +-- LLama/Common/FixedSizeQueue.cs | 1 - LLama/Common/InferenceParams.cs | 2 + LLama/Extensions/DictionaryExtensions.cs | 1 + LLama/Grammars/Grammar.cs | 8 +- LLama/LLamaContext.cs | 11 +-- LLama/LLamaEmbedder.cs | 18 ++--- LLama/LLamaWeights.cs | 2 +- LLama/Native/LLamaKvCacheView.cs | 4 +- LLama/Native/LLamaPos.cs | 4 +- LLama/Native/LLamaSeqId.cs | 4 +- LLama/Native/LLamaTokenType.cs | 12 +++ LLama/Native/NativeApi.BeamSearch.cs | 2 +- LLama/Native/NativeApi.Grammar.cs | 4 +- LLama/Native/NativeApi.Load.cs | 64 ++++++---------- LLama/Native/NativeApi.Quantize.cs | 2 +- LLama/Native/NativeApi.Sampling.cs | 4 +- LLama/Native/NativeApi.cs | 96 ++++++++++++------------ LLama/Native/NativeLibraryConfig.cs | 11 ++- LLama/Native/SafeLLamaContextHandle.cs | 3 - LLama/Native/SafeLlamaModelHandle.cs | 2 - 22 files changed, 126 insertions(+), 140 deletions(-) create mode 100644 LLama/Native/LLamaTokenType.cs diff --git a/LLama/Abstractions/IModelParams.cs b/LLama/Abstractions/IModelParams.cs index 4b4236f73..c4f96c37f 100644 --- a/LLama/Abstractions/IModelParams.cs +++ b/LLama/Abstractions/IModelParams.cs @@ -214,7 +214,7 @@ public sealed record MetadataOverride /// /// Get the key being overriden by this override /// - public string Key { get; init; } + public string Key { get; } internal LLamaModelKvOverrideType Type { get; } diff --git a/LLama/Common/ChatHistory.cs b/LLama/Common/ChatHistory.cs index 3f038874f..dc7414490 100644 --- a/LLama/Common/ChatHistory.cs +++ b/LLama/Common/ChatHistory.cs @@ -1,5 +1,4 @@ using System.Collections.Generic; -using System.IO; using System.Text.Json; using System.Text.Json.Serialization; @@ -37,6 +36,7 @@ public enum AuthorRole /// public class ChatHistory { + private static readonly JsonSerializerOptions _jsonOptions = new() { WriteIndented = true }; /// /// Chat message representation @@ -96,12 +96,7 @@ public void AddMessage(AuthorRole authorRole, string content) /// public string ToJson() { - return JsonSerializer.Serialize( - this, - new JsonSerializerOptions() - { - WriteIndented = true - }); + return JsonSerializer.Serialize(this, _jsonOptions); } /// diff --git a/LLama/Common/FixedSizeQueue.cs b/LLama/Common/FixedSizeQueue.cs index 6d272f23f..8c14a1961 100644 --- a/LLama/Common/FixedSizeQueue.cs +++ b/LLama/Common/FixedSizeQueue.cs @@ -2,7 +2,6 @@ using System.Collections; using System.Collections.Generic; using System.Linq; -using LLama.Extensions; namespace LLama.Common { diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index c1f395505..0e6020ad4 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -18,11 +18,13 @@ public record InferenceParams /// number of tokens to keep from initial prompt /// public int TokensKeep { get; set; } = 0; + /// /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response /// until it complete. /// public int MaxTokens { get; set; } = -1; + /// /// logit bias for specific tokens /// diff --git a/LLama/Extensions/DictionaryExtensions.cs b/LLama/Extensions/DictionaryExtensions.cs index b3643fae1..6599f6316 100644 --- a/LLama/Extensions/DictionaryExtensions.cs +++ b/LLama/Extensions/DictionaryExtensions.cs @@ -15,6 +15,7 @@ public static TValue GetValueOrDefault(this IReadOnlyDictionary(IReadOnlyDictionary dictionary, TKey key, TValue defaultValue) { + // ReSharper disable once CanSimplifyDictionaryTryGetValueWithGetValueOrDefault (this is a shim for that method!) return dictionary.TryGetValue(key, out var value) ? value : defaultValue; } } diff --git a/LLama/Grammars/Grammar.cs b/LLama/Grammars/Grammar.cs index 5135e341e..abb65aa30 100644 --- a/LLama/Grammars/Grammar.cs +++ b/LLama/Grammars/Grammar.cs @@ -15,7 +15,7 @@ public sealed class Grammar /// /// Index of the initial rule to start from /// - public ulong StartRuleIndex { get; set; } + public ulong StartRuleIndex { get; } /// /// The rules which make up this grammar @@ -121,6 +121,12 @@ private void PrintRule(StringBuilder output, GrammarRule rule) case LLamaGrammarElementType.CHAR_ALT: case LLamaGrammarElementType.CHAR_RNG_UPPER: break; + + case LLamaGrammarElementType.END: + case LLamaGrammarElementType.ALT: + case LLamaGrammarElementType.RULE_REF: + case LLamaGrammarElementType.CHAR: + case LLamaGrammarElementType.CHAR_NOT: default: output.Append("] "); break; diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs index db0ac179c..abd8f879f 100644 --- a/LLama/LLamaContext.cs +++ b/LLama/LLamaContext.cs @@ -43,7 +43,7 @@ public sealed class LLamaContext /// /// The context params set for this context /// - public IContextParams Params { get; set; } + public IContextParams Params { get; } /// /// The native handle, which is used to be passed to the native APIs @@ -56,15 +56,6 @@ public sealed class LLamaContext /// public Encoding Encoding { get; } - internal LLamaContext(SafeLLamaContextHandle nativeContext, IContextParams @params, ILogger? logger = null) - { - Params = @params; - - _logger = logger; - Encoding = @params.Encoding; - NativeHandle = nativeContext; - } - /// /// Create a new LLamaContext for the given LLamaWeights /// diff --git a/LLama/LLamaEmbedder.cs b/LLama/LLamaEmbedder.cs index ab56280c3..c551016c0 100644 --- a/LLama/LLamaEmbedder.cs +++ b/LLama/LLamaEmbedder.cs @@ -12,17 +12,15 @@ namespace LLama public sealed class LLamaEmbedder : IDisposable { - private readonly LLamaContext _ctx; - /// /// Dimension of embedding vectors /// - public int EmbeddingSize => _ctx.EmbeddingSize; + public int EmbeddingSize => Context.EmbeddingSize; /// /// LLama Context /// - public LLamaContext Context => this._ctx; + public LLamaContext Context { get; } /// /// Create a new embedder, using the given LLamaWeights @@ -33,7 +31,7 @@ public sealed class LLamaEmbedder public LLamaEmbedder(LLamaWeights weights, IContextParams @params, ILogger? logger = null) { @params.EmbeddingMode = true; - _ctx = weights.CreateContext(@params, logger); + Context = weights.CreateContext(@params, logger); } /// @@ -72,20 +70,20 @@ public float[] GetEmbeddings(string text) /// public float[] GetEmbeddings(string text, bool addBos) { - var embed_inp_array = _ctx.Tokenize(text, addBos); + var embed_inp_array = Context.Tokenize(text, addBos); // TODO(Rinne): deal with log of prompt if (embed_inp_array.Length > 0) - _ctx.Eval(embed_inp_array, 0); + Context.Eval(embed_inp_array, 0); unsafe { - var embeddings = NativeApi.llama_get_embeddings(_ctx.NativeHandle); + var embeddings = NativeApi.llama_get_embeddings(Context.NativeHandle); if (embeddings == null) return Array.Empty(); - return new Span(embeddings, EmbeddingSize).ToArray(); + return embeddings.ToArray(); } } @@ -94,7 +92,7 @@ public float[] GetEmbeddings(string text, bool addBos) /// public void Dispose() { - _ctx.Dispose(); + Context.Dispose(); } } diff --git a/LLama/LLamaWeights.cs b/LLama/LLamaWeights.cs index 847be5515..5cb482add 100644 --- a/LLama/LLamaWeights.cs +++ b/LLama/LLamaWeights.cs @@ -64,7 +64,7 @@ public sealed class LLamaWeights /// public IReadOnlyDictionary Metadata { get; set; } - internal LLamaWeights(SafeLlamaModelHandle weights) + private LLamaWeights(SafeLlamaModelHandle weights) { NativeHandle = weights; Metadata = weights.ReadMetadata(); diff --git a/LLama/Native/LLamaKvCacheView.cs b/LLama/Native/LLamaKvCacheView.cs index ea1c1172c..65fbccba3 100644 --- a/LLama/Native/LLamaKvCacheView.cs +++ b/LLama/Native/LLamaKvCacheView.cs @@ -14,7 +14,7 @@ public struct LLamaKvCacheViewCell /// May be negative if the cell is not populated. /// public LLamaPos pos; -}; +} /// /// An updateable view of the KV cache (llama_kv_cache_view) @@ -130,7 +130,7 @@ public ref LLamaKvCacheView GetView() } } -partial class NativeApi +public static partial class NativeApi { /// /// Create an empty KV cache view. (use only for debugging purposes) diff --git a/LLama/Native/LLamaPos.cs b/LLama/Native/LLamaPos.cs index 67ede7d52..52d67d505 100644 --- a/LLama/Native/LLamaPos.cs +++ b/LLama/Native/LLamaPos.cs @@ -6,7 +6,7 @@ namespace LLama.Native; /// Indicates position in a sequence /// [StructLayout(LayoutKind.Sequential)] -public struct LLamaPos +public record struct LLamaPos { /// /// The raw value @@ -17,7 +17,7 @@ public struct LLamaPos /// Create a new LLamaPos /// /// - public LLamaPos(int value) + private LLamaPos(int value) { Value = value; } diff --git a/LLama/Native/LLamaSeqId.cs b/LLama/Native/LLamaSeqId.cs index 191a6b5ec..bcee74f13 100644 --- a/LLama/Native/LLamaSeqId.cs +++ b/LLama/Native/LLamaSeqId.cs @@ -6,7 +6,7 @@ namespace LLama.Native; /// ID for a sequence in a batch /// [StructLayout(LayoutKind.Sequential)] -public struct LLamaSeqId +public record struct LLamaSeqId { /// /// The raw value @@ -17,7 +17,7 @@ public struct LLamaSeqId /// Create a new LLamaSeqId /// /// - public LLamaSeqId(int value) + private LLamaSeqId(int value) { Value = value; } diff --git a/LLama/Native/LLamaTokenType.cs b/LLama/Native/LLamaTokenType.cs new file mode 100644 index 000000000..171e782ae --- /dev/null +++ b/LLama/Native/LLamaTokenType.cs @@ -0,0 +1,12 @@ +namespace LLama.Native; + +public enum LLamaTokenType +{ + LLAMA_TOKEN_TYPE_UNDEFINED = 0, + LLAMA_TOKEN_TYPE_NORMAL = 1, + LLAMA_TOKEN_TYPE_UNKNOWN = 2, + LLAMA_TOKEN_TYPE_CONTROL = 3, + LLAMA_TOKEN_TYPE_USER_DEFINED = 4, + LLAMA_TOKEN_TYPE_UNUSED = 5, + LLAMA_TOKEN_TYPE_BYTE = 6, +} \ No newline at end of file diff --git a/LLama/Native/NativeApi.BeamSearch.cs b/LLama/Native/NativeApi.BeamSearch.cs index 1049dbe3a..142b997bb 100644 --- a/LLama/Native/NativeApi.BeamSearch.cs +++ b/LLama/Native/NativeApi.BeamSearch.cs @@ -3,7 +3,7 @@ namespace LLama.Native; -public partial class NativeApi +public static partial class NativeApi { /// /// Type of pointer to the beam_search_callback function. diff --git a/LLama/Native/NativeApi.Grammar.cs b/LLama/Native/NativeApi.Grammar.cs index 84e298c7d..4d47872b5 100644 --- a/LLama/Native/NativeApi.Grammar.cs +++ b/LLama/Native/NativeApi.Grammar.cs @@ -5,7 +5,7 @@ namespace LLama.Native { using llama_token = Int32; - public unsafe partial class NativeApi + public static partial class NativeApi { /// /// Create a new grammar from the given set of grammar rules @@ -15,7 +15,7 @@ public unsafe partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); + public static extern unsafe IntPtr llama_grammar_init(LLamaGrammarElement** rules, ulong n_rules, ulong start_rule_index); /// /// Free all memory from the given SafeLLamaGrammarHandle diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index d8a887252..5ae02c1ac 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -4,13 +4,12 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; -using System.Linq; using System.Runtime.InteropServices; using System.Text.Json; namespace LLama.Native { - public partial class NativeApi + public static partial class NativeApi { static NativeApi() { @@ -97,22 +96,13 @@ private static int GetCudaMajorVersion() } if (string.IsNullOrEmpty(version)) - { return -1; - } - else - { - version = version.Split('.')[0]; - bool success = int.TryParse(version, out var majorVersion); - if (success) - { - return majorVersion; - } - else - { - return -1; - } - } + + version = version.Split('.')[0]; + if (int.TryParse(version, out var majorVersion)) + return majorVersion; + + return -1; } private static string GetCudaVersionFromPath(string cudaPath) @@ -129,7 +119,7 @@ private static string GetCudaVersionFromPath(string cudaPath) { return string.Empty; } - return versionNode.GetString(); + return versionNode.GetString() ?? ""; } } catch (Exception) @@ -169,18 +159,14 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c { platform = OSPlatform.OSX; suffix = ".dylib"; - if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported) - { - prefix = "runtimes/osx-arm64/native/"; - } - else - { - prefix = "runtimes/osx-x64/native/"; - } + + prefix = System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported + ? "runtimes/osx-arm64/native/" + : "runtimes/osx-x64/native/"; } else { - throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp."); + throw new RuntimeError("Your system plarform is not supported, please open an issue in LLamaSharp."); } Log($"Detected OS Platform: {platform}", LogLevel.Information); @@ -275,15 +261,15 @@ private static IntPtr TryLoadLibrary() var libraryTryLoadOrder = GetLibraryTryOrder(configuration); - string[] preferredPaths = configuration.SearchDirectories; - string[] possiblePathPrefix = new string[] { - System.AppDomain.CurrentDomain.BaseDirectory, + var preferredPaths = configuration.SearchDirectories; + var possiblePathPrefix = new[] { + AppDomain.CurrentDomain.BaseDirectory, Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" }; - var tryFindPath = (string filename) => + string TryFindPath(string filename) { - foreach(var path in preferredPaths) + foreach (var path in preferredPaths) { if (File.Exists(Path.Combine(path, filename))) { @@ -291,7 +277,7 @@ private static IntPtr TryLoadLibrary() } } - foreach(var path in possiblePathPrefix) + foreach (var path in possiblePathPrefix) { if (File.Exists(Path.Combine(path, filename))) { @@ -300,21 +286,19 @@ private static IntPtr TryLoadLibrary() } return filename; - }; + } foreach (var libraryPath in libraryTryLoadOrder) { - var fullPath = tryFindPath(libraryPath); + var fullPath = TryFindPath(libraryPath); var result = TryLoad(fullPath, true); if (result is not null && result != IntPtr.Zero) { Log($"{fullPath} is selected and loaded successfully.", LogLevel.Information); - return result ?? IntPtr.Zero; - } - else - { - Log($"Tried to load {fullPath} but failed.", LogLevel.Information); + return (IntPtr)result; } + + Log($"Tried to load {fullPath} but failed.", LogLevel.Information); } if (!configuration.AllowFallback) diff --git a/LLama/Native/NativeApi.Quantize.cs b/LLama/Native/NativeApi.Quantize.cs index d4ff5cf80..b849e38d5 100644 --- a/LLama/Native/NativeApi.Quantize.cs +++ b/LLama/Native/NativeApi.Quantize.cs @@ -2,7 +2,7 @@ namespace LLama.Native { - public partial class NativeApi + public static partial class NativeApi { /// /// Returns 0 on success diff --git a/LLama/Native/NativeApi.Sampling.cs b/LLama/Native/NativeApi.Sampling.cs index 9e7d375b6..53a6dd233 100644 --- a/LLama/Native/NativeApi.Sampling.cs +++ b/LLama/Native/NativeApi.Sampling.cs @@ -5,7 +5,7 @@ namespace LLama.Native { using llama_token = Int32; - public unsafe partial class NativeApi + public static partial class NativeApi { /// /// Repetition penalty described in CTRL academic paper https://arxiv.org/abs/1909.05858, with negative logit fix. @@ -19,7 +19,7 @@ public unsafe partial class NativeApi /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. /// Frequency and presence penalties described in OpenAI API https://platform.openai.com/docs/api-reference/parameter-details. [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, + public static extern unsafe void llama_sample_repetition_penalties(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates, llama_token* last_tokens, ulong last_tokens_size, float penalty_repeat, diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 24b9f571d..1c7715f66 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -9,17 +9,6 @@ namespace LLama.Native { using llama_token = Int32; - public enum LLamaTokenType - { - LLAMA_TOKEN_TYPE_UNDEFINED = 0, - LLAMA_TOKEN_TYPE_NORMAL = 1, - LLAMA_TOKEN_TYPE_UNKNOWN = 2, - LLAMA_TOKEN_TYPE_CONTROL = 3, - LLAMA_TOKEN_TYPE_USER_DEFINED = 4, - LLAMA_TOKEN_TYPE_UNUSED = 5, - LLAMA_TOKEN_TYPE_BYTE = 6, - } - /// /// Callback from llama.cpp with log messages /// @@ -30,7 +19,7 @@ public enum LLamaTokenType /// /// Direct translation of the llama.cpp API /// - public unsafe partial class NativeApi + public static partial class NativeApi { /// /// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded. @@ -165,7 +154,7 @@ public unsafe partial class NativeApi /// /// the number of bytes copied [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); + public static extern unsafe ulong llama_copy_state_data(SafeLLamaContextHandle ctx, byte* dest); /// /// Set the state reading from the specified address @@ -174,7 +163,7 @@ public unsafe partial class NativeApi /// /// the number of bytes read [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); + public static extern unsafe ulong llama_set_state_data(SafeLLamaContextHandle ctx, byte* src); /// /// Load session file @@ -186,7 +175,7 @@ public unsafe partial class NativeApi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out); + public static extern unsafe bool llama_load_session_file(SafeLLamaContextHandle ctx, string path_session, llama_token[] tokens_out, ulong n_token_capacity, ulong* n_token_count_out); /// /// Save session file @@ -211,7 +200,7 @@ public unsafe partial class NativeApi /// Returns 0 on success [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] [Obsolete("use llama_decode() instead")] - public static extern int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); + public static extern unsafe int llama_eval(SafeLLamaContextHandle ctx, llama_token* tokens, int n_tokens, int n_past); /// /// Convert the provided text into tokens. @@ -228,34 +217,37 @@ public unsafe partial class NativeApi /// public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encoding encoding, llama_token[] tokens, int n_max_tokens, bool add_bos, bool special) { - // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) - var byteCount = encoding.GetByteCount(text); - var array = ArrayPool.Shared.Rent(byteCount + 1); - try + unsafe { - // Convert to bytes - fixed (char* textPtr = text) - fixed (byte* arrayPtr = array) + // Calculate number of bytes in text and borrow an array that large (+1 for nul byte) + var byteCount = encoding.GetByteCount(text); + var array = ArrayPool.Shared.Rent(byteCount + 1); + try { - encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); + // Convert to bytes + fixed (char* textPtr = text) + fixed (byte* arrayPtr = array) + { + encoding.GetBytes(textPtr, text.Length, arrayPtr, array.Length); + } + + // Add a zero byte to the end to terminate the string + array[byteCount] = 0; + + // Do the actual tokenization + fixed (byte* arrayPtr = array) + fixed (llama_token* tokensPtr = tokens) + return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); + } + finally + { + ArrayPool.Shared.Return(array); } - - // Add a zero byte to the end to terminate the string - array[byteCount] = 0; - - // Do the actual tokenization - fixed (byte* arrayPtr = array) - fixed (llama_token* tokensPtr = tokens) - return llama_tokenize(ctx.ModelHandle, arrayPtr, byteCount, tokensPtr, n_max_tokens, add_bos, special); - } - finally - { - ArrayPool.Shared.Return(array); } } [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token); + public static extern unsafe byte* llama_token_get_text(SafeLlamaModelHandle model, llama_token token); [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] public static extern float llama_token_get_score(SafeLlamaModelHandle model, llama_token token); @@ -281,7 +273,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float* llama_get_logits(SafeLLamaContextHandle ctx); + public static extern unsafe float* llama_get_logits(SafeLLamaContextHandle ctx); /// /// Logits for the ith token. Equivalent to: llama_get_logits(ctx) + i*n_vocab @@ -290,16 +282,24 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); + public static extern unsafe float* llama_get_logits_ith(SafeLLamaContextHandle ctx, int i); /// /// Get the embeddings for the input - /// shape: [n_embd] (1-dimensional) /// /// /// - [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern float* llama_get_embeddings(SafeLLamaContextHandle ctx); + public static Span llama_get_embeddings(SafeLLamaContextHandle ctx) + { + unsafe + { + var ptr = llama_get_embeddings_native(ctx); + return new Span(ptr, ctx.EmbeddingSize); + } + + [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_get_embeddings")] + static extern unsafe float* llama_get_embeddings_native(SafeLLamaContextHandle ctx); + } /// /// Get the "Beginning of sentence" token @@ -426,7 +426,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// The length of the string on success, or -1 on failure [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); + public static extern unsafe int llama_model_meta_val_str(SafeLlamaModelHandle model, byte* key, byte* buf, long buf_size); /// /// Get the number of metadata key/value pairs @@ -445,7 +445,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// The length of the string on success, or -1 on failure [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + public static extern unsafe int llama_model_meta_key_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); /// /// Get metadata value as a string by index @@ -456,7 +456,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// The length of the string on success, or -1 on failure [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); + public static extern unsafe int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model, int index, byte* buf, long buf_size); /// /// Get a string describing the model type @@ -466,7 +466,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// /// The length of the string on success, or -1 on failure [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size); + public static extern unsafe int llama_model_desc(SafeLlamaModelHandle model, byte* buf, long buf_size); /// /// Get the size of the model in bytes @@ -493,7 +493,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// size of the buffer /// The length written, or if the buffer is too small a negative that indicates the length required [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); + public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length); /// /// Convert text into tokens @@ -509,7 +509,7 @@ public static int llama_tokenize(SafeLLamaContextHandle ctx, string text, Encodi /// Returns a negative number on failure - the number of tokens that would have been returned /// [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)] - public static extern int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); + public static extern unsafe int llama_tokenize(SafeLlamaModelHandle model, byte* text, int text_len, int* tokens, int n_max_tokens, bool add_bos, bool special); /// /// Register a callback to receive llama log messages diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs index aaa328c91..ad52fc816 100644 --- a/LLama/Native/NativeLibraryConfig.cs +++ b/LLama/Native/NativeLibraryConfig.cs @@ -29,10 +29,11 @@ public sealed class NativeLibraryConfig private bool _allowFallback = true; private bool _skipCheck = false; private bool _logging = false; + /// /// search directory -> priority level, 0 is the lowest. /// - private List _searchDirectories = new List(); + private readonly List _searchDirectories = new List(); private static void ThrowIfLoaded() { @@ -159,9 +160,8 @@ public NativeLibraryConfig WithSearchDirectory(string directory) internal static Description CheckAndGatherDescription() { if (Instance._allowFallback && Instance._skipCheck) - { throw new ArgumentException("Cannot skip the check when fallback is allowed."); - } + return new Description( Instance._libraryPath, Instance._useCuda, @@ -169,7 +169,8 @@ internal static Description CheckAndGatherDescription() Instance._allowFallback, Instance._skipCheck, Instance._logging, - Instance._searchDirectories.Concat(new string[] { "./" }).ToArray()); + Instance._searchDirectories.Concat(new[] { "./" }).ToArray() + ); } internal static string AvxLevelToString(AvxLevel level) @@ -204,7 +205,9 @@ private static bool CheckAVX512() if (!System.Runtime.Intrinsics.X86.X86Base.IsSupported) return false; + // ReSharper disable UnusedVariable (ebx is used when < NET8) var (_, ebx, ecx, _) = System.Runtime.Intrinsics.X86.X86Base.CpuId(7, 0); + // ReSharper restore UnusedVariable var vnni = (ecx & 0b_1000_0000_0000) != 0; diff --git a/LLama/Native/SafeLLamaContextHandle.cs b/LLama/Native/SafeLLamaContextHandle.cs index df33076f9..98b510783 100644 --- a/LLama/Native/SafeLLamaContextHandle.cs +++ b/LLama/Native/SafeLLamaContextHandle.cs @@ -1,6 +1,5 @@ using System; using System.Buffers; -using System.Collections.Generic; using System.Text; using LLama.Exceptions; @@ -51,8 +50,6 @@ public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model) _model.DangerousAddRef(ref success); if (!success) throw new RuntimeError("Failed to increment model refcount"); - - } /// diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs index d62e50417..2280250ec 100644 --- a/LLama/Native/SafeLlamaModelHandle.cs +++ b/LLama/Native/SafeLlamaModelHandle.cs @@ -214,7 +214,6 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) /// Get the metadata key for the given index /// /// The index to get - /// A temporary buffer to store key characters in. Must be large enough to contain the key. /// The key, null if there is no such key or if the buffer was too small public Memory? MetadataKeyByIndex(int index) { @@ -243,7 +242,6 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params) /// Get the metadata value for the given index /// /// The index to get - /// A temporary buffer to store value characters in. Must be large enough to contain the value. /// The value, null if there is no such value or if the buffer was too small public Memory? MetadataValueByIndex(int index) {