From e11df78467bc57ab989aeeb73a86320e6fe3d5f6 Mon Sep 17 00:00:00 2001 From: Yaohui Liu Date: Sun, 12 Nov 2023 01:48:18 +0800 Subject: [PATCH] fix: cannot load library under some conditions. --- LLama/Native/NativeApi.cs | 51 +++++++++++++++++++++++++++++---------- 1 file changed, 38 insertions(+), 13 deletions(-) diff --git a/LLama/Native/NativeApi.cs b/LLama/Native/NativeApi.cs index 3819980fe..9ff1c4ede 100644 --- a/LLama/Native/NativeApi.cs +++ b/LLama/Native/NativeApi.cs @@ -3,6 +3,7 @@ using System.Collections.Generic; using System.IO; using System.Runtime.InteropServices; +using System.Security.Cryptography.X509Certificates; using System.Text; using System.Text.Json; using LLama.Exceptions; @@ -132,7 +133,7 @@ private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, s { avxStr += "/"; } - return $"{prefix}{avxStr}{libraryName}{suffix}"; + return $"{prefix}{avxStr}"; } private static List GetLibraryTryOrder(NativeLibraryConfig.Description configuration) @@ -180,8 +181,8 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c // if check skipped, we just try to load cuda libraries one by one. if (configuration.SkipCheck) { - result.Add($"{prefix}cuda12/{libraryName}{suffix}"); - result.Add($"{prefix}cuda11/{libraryName}{suffix}"); + result.Add($"{prefix}cuda12/"); + result.Add($"{prefix}cuda11/"); } else { @@ -190,11 +191,11 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c } else if (cudaVersion == 11) { - result.Add($"{prefix}cuda11/{libraryName}{suffix}"); + result.Add($"{prefix}cuda11/"); } else if (cudaVersion == 12) { - result.Add($"{prefix}cuda12/{libraryName}{suffix}"); + result.Add($"{prefix}cuda12/"); } else if (cudaVersion > 0) { @@ -233,7 +234,7 @@ private static List GetLibraryTryOrder(NativeLibraryConfig.Description c if(platform == OSPlatform.OSX) { - result.Add($"{prefix}{libraryName}{suffix}"); + result.Add($"{prefix}"); } return result; @@ -252,29 +253,53 @@ private static IntPtr TryLoadLibrary() if (!string.IsNullOrEmpty(configuration.Path)) { // When loading the user specified library, there's no fallback. - var result = TryLoad(configuration.Path, true); - if (result is null || result == IntPtr.Zero) + var result = NativeLibrary.Load(configuration.Path); + if (result == IntPtr.Zero) { throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified."); } - return result ?? IntPtr.Zero; + return result; } var libraryTryLoadOrder = GetLibraryTryOrder(configuration); + string[] possiblePathPrefix = new string[] { + System.AppDomain.CurrentDomain.BaseDirectory, + Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? "" + }; + + var tryFindPath = (string filename) => + { + int i = 0; + while (!File.Exists(filename)) + { + if(i < possiblePathPrefix.Length) + { + filename = Path.Combine(possiblePathPrefix[i], filename); + i++; + } + else + { + break; + } + } + return filename; + }; + foreach(var libraryPath in libraryTryLoadOrder) { - var result = TryLoad(libraryPath, true); + var fullPath = tryFindPath(libraryPath); + var result = TryLoad(fullPath, true); if(result is not null && result != IntPtr.Zero) { Console.ForegroundColor = ConsoleColor.Red; - Console.WriteLine($"[Native Library] {libraryPath} is loaded."); + Console.WriteLine($"[Native Library] {fullPath} is loaded."); Console.ResetColor(); return result ?? IntPtr.Zero; } else { - Console.WriteLine($"Tried to load {libraryPath}"); + Console.WriteLine($"Tried to load {fullPath}"); } } @@ -296,7 +321,7 @@ private static IntPtr TryLoadLibrary() if (!supported) return null; - if (NativeLibrary.TryLoad(path, out var handle)) + if (NativeLibrary.TryLoad(libraryName, System.Reflection.Assembly.GetExecutingAssembly(), )) return handle; return null;