From b05c3154f49225900f3fc4405fff78c3c500ce03 Mon Sep 17 00:00:00 2001 From: Rinne Date: Tue, 28 Nov 2023 20:58:32 +0800 Subject: [PATCH 1/2] feat: allow customized search path for native library loading. --- LLama/ChatSession.cs | 24 +++++--- LLama/Native/NativeApi.Load.cs | 20 ++++--- LLama/Native/NativeLibraryConfig.cs | 88 ++++++++++++++++++++++++++++- 3 files changed, 115 insertions(+), 17 deletions(-) diff --git a/LLama/ChatSession.cs b/LLama/ChatSession.cs index 7ee995906..5c535a6bd 100644 --- a/LLama/ChatSession.cs +++ b/LLama/ChatSession.cs @@ -152,14 +152,24 @@ public async IAsyncEnumerable ChatAsync(string prompt, IInferenceParams? foreach (var inputTransform in InputTransformPipeline) prompt = inputTransform.Transform(prompt); - History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); - - if (_executor is InteractiveExecutor executor) + // TODO: need to be refactored. + if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun) { - InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData(); - prompt = state.IsPromptRun - ? HistoryTransform.HistoryToText(History) - : prompt; + History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt)); + var converted_prompt = HistoryTransform.HistoryToText(History); + // Avoid missing anti-prompt. + if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n")) + { + prompt = converted_prompt.Trim(); + } + else + { + prompt = converted_prompt; + } + } + else + { + History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt)); } StringBuilder sb = new(); diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 148f1735e..36ae55bf4 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -4,6 +4,7 @@ using System.Collections.Generic; using System.Diagnostics; using System.IO; +using System.Linq; using System.Runtime.InteropServices; using System.Text.Json; @@ -258,6 +259,7 @@ private static IntPtr TryLoadLibrary() enableLogging = configuration.Logging; // We move the flag to avoid loading library when the variable is called else where. NativeLibraryConfig.LibraryHasLoaded = true; + Log(configuration.ToString(), LogLevel.Information); if (!string.IsNullOrEmpty(configuration.Path)) { @@ -273,6 +275,7 @@ private static IntPtr TryLoadLibrary() var libraryTryLoadOrder = GetLibraryTryOrder(configuration); + string[] preferredPaths = configuration.SearchDirectories.OrderByDescending(kv => kv.Value).Select(kv => kv.Key).ToArray(); string[] possiblePathPrefix = new string[] { System.AppDomain.CurrentDomain.BaseDirectory, Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" @@ -280,19 +283,22 @@ private static IntPtr TryLoadLibrary() var tryFindPath = (string filename) => { - int i = 0; - while (!File.Exists(filename)) + foreach(var path in preferredPaths) { - if (i < possiblePathPrefix.Length) + if (File.Exists(Path.Combine(path, filename))) { - filename = Path.Combine(possiblePathPrefix[i], filename); - i++; + return Path.Combine(path, filename); } - else + } + + foreach(var path in possiblePathPrefix) + { + if (File.Exists(Path.Combine(path, filename))) { - break; + return Path.Combine(path, filename); } } + return filename; }; diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs index 76f893576..fcef9c34f 100644 --- a/LLama/Native/NativeLibraryConfig.cs +++ b/LLama/Native/NativeLibraryConfig.cs @@ -1,4 +1,6 @@ using System; +using System.Collections.Generic; +using System.Linq; namespace LLama.Native { @@ -27,6 +29,13 @@ public sealed class NativeLibraryConfig private bool _allowFallback = true; private bool _skipCheck = false; private bool _logging = false; + /// + /// search directory -> priority level, 0 is the lowest. + /// + private Dictionary _searchDirectories = new Dictionary() + { + { "./", 0 } + }; private static void ThrowIfLoaded() { @@ -120,13 +129,62 @@ public NativeLibraryConfig WithLogs(bool enable = true) return this; } + /// + /// Add self-defined search directories. Note that the file stucture of the added + /// directories must be the same as the default directory. Besides, the directory + /// won't be used recursively. + /// + /// The directories and corresponding priorities, in which 0 is the lowest. The default path has priority 0. + /// + public NativeLibraryConfig WithSearchDirectories(IDictionary directoriesAndPriorities) + { + ThrowIfLoaded(); + + foreach(var (directory, priority) in directoriesAndPriorities) + { + if(priority < 0) + { + throw new ArgumentException("Priority must be a positive number."); + } + _searchDirectories[directory] = priority; + } + return this; + } + + /// + /// Add self-defined search directories. Note that the file stucture of the added + /// directories must be the same as the default directory. Besides, the directory + /// won't be used recursively. + /// + /// + /// The priority of your added search path. 0 is the lowest. The default path has priority 0. + /// + public NativeLibraryConfig WithSearchDirectory(string directory, int priority) + { + ThrowIfLoaded(); + + if (priority < 0) + { + throw new ArgumentException("Priority must be a positive number."); + } + _searchDirectories[directory] = priority; + return this; + } + internal static Description CheckAndGatherDescription() { if (Instance._allowFallback && Instance._skipCheck) { throw new ArgumentException("Cannot skip the check when fallback is allowed."); } - return new Description(Instance._libraryPath, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, Instance._skipCheck, Instance._logging); + return new Description( + Instance._libraryPath, + Instance._useCuda, + Instance._avxLevel, + Instance._allowFallback, + Instance._skipCheck, + Instance._logging, + Instance._searchDirectories); } internal static string AvxLevelToString(AvxLevel level) @@ -183,7 +241,31 @@ public enum AvxLevel Avx512, } - internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging); + internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, Dictionary SearchDirectories) + { + public override string ToString() + { + string avxLevelString = AvxLevel switch + { + AvxLevel.None => "NoAVX", + AvxLevel.Avx => "AVX", + AvxLevel.Avx2 => "AVX2", + AvxLevel.Avx512 => "AVX512", + _ => "Unknown" + }; + + string searchDirectoriesString = string.Join(", ", SearchDirectories.Select(kv => $"[{kv.Key}: {kv.Value}]")); + + return $"NativeLibraryConfig Description:\n" + + $"- Path: {Path}\n" + + $"- PreferCuda: {UseCuda}\n" + + $"- PreferredAvxLevel: {avxLevelString}\n" + + $"- AllowFallback: {AllowFallback}\n" + + $"- SkipCheck: {SkipCheck}\n" + + $"- Logging: {Logging}\n" + + $"- SearchDirectories and Priorities: {searchDirectoriesString}"; + } + } } #endif - } +} From ffc347a3f302f0d0817bb56acd80a03d532f60ae Mon Sep 17 00:00:00 2001 From: Rinne Date: Wed, 29 Nov 2023 00:16:00 +0800 Subject: [PATCH 2/2] resolve comments. --- LLama/Native/NativeApi.Load.cs | 2 +- LLama/Native/NativeLibraryConfig.cs | 33 ++++++++--------------------- 2 files changed, 10 insertions(+), 25 deletions(-) diff --git a/LLama/Native/NativeApi.Load.cs b/LLama/Native/NativeApi.Load.cs index 36ae55bf4..d8a887252 100644 --- a/LLama/Native/NativeApi.Load.cs +++ b/LLama/Native/NativeApi.Load.cs @@ -275,7 +275,7 @@ private static IntPtr TryLoadLibrary() var libraryTryLoadOrder = GetLibraryTryOrder(configuration); - string[] preferredPaths = configuration.SearchDirectories.OrderByDescending(kv => kv.Value).Select(kv => kv.Key).ToArray(); + string[] preferredPaths = configuration.SearchDirectories; string[] possiblePathPrefix = new string[] { System.AppDomain.CurrentDomain.BaseDirectory, Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" diff --git a/LLama/Native/NativeLibraryConfig.cs b/LLama/Native/NativeLibraryConfig.cs index fcef9c34f..e51359707 100644 --- a/LLama/Native/NativeLibraryConfig.cs +++ b/LLama/Native/NativeLibraryConfig.cs @@ -32,10 +32,7 @@ public sealed class NativeLibraryConfig /// /// search directory -> priority level, 0 is the lowest. /// - private Dictionary _searchDirectories = new Dictionary() - { - { "./", 0 } - }; + private List _searchDirectories = new List(); private static void ThrowIfLoaded() { @@ -134,20 +131,13 @@ public NativeLibraryConfig WithLogs(bool enable = true) /// directories must be the same as the default directory. Besides, the directory /// won't be used recursively. /// - /// The directories and corresponding priorities, in which 0 is the lowest. The default path has priority 0. + /// /// - public NativeLibraryConfig WithSearchDirectories(IDictionary directoriesAndPriorities) + public NativeLibraryConfig WithSearchDirectories(IEnumerable directories) { ThrowIfLoaded(); - foreach(var (directory, priority) in directoriesAndPriorities) - { - if(priority < 0) - { - throw new ArgumentException("Priority must be a positive number."); - } - _searchDirectories[directory] = priority; - } + _searchDirectories.AddRange(directories); return this; } @@ -157,17 +147,12 @@ public NativeLibraryConfig WithSearchDirectories(IDictionary direct /// won't be used recursively. /// /// - /// The priority of your added search path. 0 is the lowest. The default path has priority 0. /// - public NativeLibraryConfig WithSearchDirectory(string directory, int priority) + public NativeLibraryConfig WithSearchDirectory(string directory) { ThrowIfLoaded(); - if (priority < 0) - { - throw new ArgumentException("Priority must be a positive number."); - } - _searchDirectories[directory] = priority; + _searchDirectories.Add(directory); return this; } @@ -184,7 +169,7 @@ internal static Description CheckAndGatherDescription() Instance._allowFallback, Instance._skipCheck, Instance._logging, - Instance._searchDirectories); + Instance._searchDirectories.Concat(new string[] { "./" }).ToArray()); } internal static string AvxLevelToString(AvxLevel level) @@ -241,7 +226,7 @@ public enum AvxLevel Avx512, } - internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, Dictionary SearchDirectories) + internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, string[] SearchDirectories) { public override string ToString() { @@ -254,7 +239,7 @@ public override string ToString() _ => "Unknown" }; - string searchDirectoriesString = string.Join(", ", SearchDirectories.Select(kv => $"[{kv.Key}: {kv.Value}]")); + string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }"; return $"NativeLibraryConfig Description:\n" + $"- Path: {Path}\n" +