Skip to content

Commit

Permalink
feat: support cuda feature detection.
Browse files Browse the repository at this point in the history
  • Loading branch information
AsakusaRinne committed Nov 11, 2023
1 parent 5fe721b commit d03e1db
Show file tree
Hide file tree
Showing 5 changed files with 351 additions and 35 deletions.
18 changes: 9 additions & 9 deletions LLama/LLamaSharp.Runtime.targets
Original file line number Diff line number Diff line change
Expand Up @@ -5,39 +5,39 @@
<ItemGroup Condition="'$(IncludeBuiltInRuntimes)' == 'true'">
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama.dll</Link>
<Link>runtimes/win-x64/native/libllama.dll</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama-cuda11.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama-cuda11.dll</Link>
<Link>runtimes/win-x64/native/cuda11/libllama.dll</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama-cuda12.dll">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama-cuda12.dll</Link>
<Link>runtimes/win-x64/native/cuda12/libllama.dll</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama.so">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama.so</Link>
<Link>runtimes/linux-x64/native/libllama.so</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama-cuda11.so">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama-cuda11.so</Link>
<Link>runtimes/linux-x64/native/cuda11/libllama.so</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/libllama-cuda12.so">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>libllama-cuda12.so</Link>
<Link>runtimes/linux-x64/native/cuda12/libllama.so</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/macos-arm64/libllama.dylib">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>runtimes/macos-arm64/libllama.dylib</Link>
<Link>runtimes/osx-arm64/native/libllama.dylib</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/macos-arm64/ggml-metal.metal">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>runtimes/macos-arm64/ggml-metal.metal</Link>
<Link>runtimes/osx-arm64/native/ggml-metal.metal</Link>
</None>
<None Include="$(MSBuildThisFileDirectory)runtimes/macos-x86_64/libllama.dylib">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
<Link>runtimes/macos-x86_64/libllama.dylib</Link>
<Link>runtimes/osx-x64/native/libllama.dylib</Link>
</None>
</ItemGroup>
</Project>
2 changes: 1 addition & 1 deletion LLama/LLamaSharp.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
<Platforms>AnyCPU;x64;Arm64</Platforms>
<AllowUnsafeBlocks>True</AllowUnsafeBlocks>

<Version>0.5.0</Version>
<Version>0.7.1</Version>
<Authors>Yaohui Liu, Martin Evans, Haiping Chen</Authors>
<Company>SciSharp STACK</Company>
<GeneratePackageOnBuild>true</GeneratePackageOnBuild>
Expand Down
251 changes: 227 additions & 24 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.IO;
using System.Runtime.InteropServices;
using System.Text;
using System.Text.Json;
using LLama.Exceptions;

#pragma warning disable IDE1006 // Naming Styles
Expand Down Expand Up @@ -43,45 +46,244 @@ static NativeApi()
llama_backend_init(false);
}

private static int GetCudaMajorVersion()
{
string? cudaPath;
string version = "";
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
cudaPath = Environment.GetEnvironmentVariable("CUDA_PATH");
if(cudaPath is null)
{
return -1;
}
version = GetCudaVersionFromPath(cudaPath);
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
// Try the default first
cudaPath = "/usr/local/bin/cuda";
version = GetCudaVersionFromPath(cudaPath);
if (string.IsNullOrEmpty(version))
{
cudaPath = Environment.GetEnvironmentVariable("LD_LIBRARY_PATH");
if(cudaPath is null)
{
return -1;
}
foreach(var path in cudaPath.Split(':'))
{
version = GetCudaVersionFromPath(Path.Combine(path, ".."));
if (string.IsNullOrEmpty(version))
{
break;
}
}
}
}

if (string.IsNullOrEmpty(version))
{
return -1;
}
else
{
version = version.Split('.')[0];
bool success = int.TryParse(version, out var majorVersion);
if (success)
{
return majorVersion;
}
else
{
return -1;
}
}
}

private static string GetCudaVersionFromPath(string cudaPath)
{
try
{
string json = File.ReadAllText(Path.Combine(cudaPath, cudaVersionFile));
using (JsonDocument document = JsonDocument.Parse(json))
{
JsonElement root = document.RootElement;
JsonElement cublasNode = root.GetProperty("libcublas");
JsonElement versionNode = cublasNode.GetProperty("version");
if (versionNode.ValueKind == JsonValueKind.Undefined)
{
return string.Empty;
}
return versionNode.GetString();
}
}
catch (Exception)
{
return string.Empty;
}
}

#if NET6_0_OR_GREATER
private static string GetAvxLibraryPath(NativeLibraryConfig.AvxLevel avxLevel, string prefix, string suffix)
{
var avxStr = NativeLibraryConfig.AvxLevelToString(avxLevel);
if (!string.IsNullOrEmpty(avxStr))
{
avxStr += "/";
}
return $"{prefix}{avxStr}{libraryName}{suffix}";
}

private static List<string> GetLibraryTryOrder(NativeLibraryConfig.Description configuration)
{
OSPlatform platform;
string prefix, suffix;
if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
{
platform = OSPlatform.Windows;
prefix = "runtimes/win-x64/native/";
suffix = ".dll";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
{
platform = OSPlatform.Linux;
prefix = "runtimes/linux-x64/native/";
suffix = ".so";
}
else if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
{
platform = OSPlatform.OSX;
suffix = ".dylib";
if (System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
{
prefix = "runtimes/osx-arm64/native/";
}
else
{
prefix = "runtimes/osx-x64/native/";
}
}
else
{
throw new RuntimeError($"Your system plarform is not supported, please open an issue in LLamaSharp.");
}

List<string> result = new();
if (configuration.UseCuda && (platform == OSPlatform.Windows || platform == OSPlatform.Linux)) // no cuda on macos
{
int cudaVersion = GetCudaMajorVersion();

// TODO: load cuda library with avx
if (cudaVersion == -1 && !configuration.AllowFallback)
{
// if check skipped, we just try to load cuda libraries one by one.
if (configuration.SkipCheck)
{
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else
{
throw new RuntimeError("Configured to load a cuda library but no cuda detected on your device.");
}
}
else if (cudaVersion == 11)
{
result.Add($"{prefix}cuda11/{libraryName}{suffix}");
}
else if (cudaVersion == 12)
{
result.Add($"{prefix}cuda12/{libraryName}{suffix}");
}
else if (cudaVersion > 0)
{
throw new RuntimeError($"Cuda version {cudaVersion} hasn't been supported by LLamaSharp, please open an issue for it.");
}
// otherwise no cuda detected but allow fallback
}

// use cpu (or mac possibly with metal)
if (!configuration.AllowFallback && platform != OSPlatform.OSX)
{
result.Add(GetAvxLibraryPath(configuration.AvxLevel, prefix, suffix));
}
else if(platform != OSPlatform.OSX) // in macos there's absolutely no avx
{
#if NET8_0_OR_GREATER
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx512)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx512, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix)));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix)));
}
else
#endif
if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx2)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx2, prefix, suffix));
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
else if (configuration.AvxLevel == NativeLibraryConfig.AvxLevel.Avx)
{
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.Avx, prefix, suffix));
}
result.Add(GetAvxLibraryPath(NativeLibraryConfig.AvxLevel.None, prefix, suffix));
}

if(platform == OSPlatform.OSX)
{
result.Add($"{prefix}{libraryName}{suffix}");
}

return result;
}
#endif

/// <summary>
/// Try to load libllama, using CPU feature detection to try and load a more specialised DLL if possible
/// </summary>
/// <returns>The library handle to unload later, or IntPtr.Zero if no library was loaded</returns>
private static IntPtr TryLoadLibrary()
{
#if NET6_0_OR_GREATER
var configuration = NativeLibraryConfig.GetInstance().Desc;

if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows))
if (!string.IsNullOrEmpty(configuration.Path))
{
// All of the Windows libraries, in order of preference
return TryLoad("cu12.1.0/libllama.dll")
?? TryLoad("cu11.7.1/libllama.dll")
#if NET8_0_OR_GREATER
?? TryLoad("avx512/libllama.dll", System.Runtime.Intrinsics.X86.Avx512.IsSupported)
#endif
?? TryLoad("avx2/libllama.dll", System.Runtime.Intrinsics.X86.Avx2.IsSupported)
?? TryLoad("avx/libllama.dll", System.Runtime.Intrinsics.X86.Avx.IsSupported)
?? IntPtr.Zero;
// When loading the user specified library, there's no fallback.
var result = TryLoad(configuration.Path, true);
if (result is null || result == IntPtr.Zero)
{
throw new RuntimeError($"Failed to load the native library [{configuration.Path}] you specified.");
}
return result ?? IntPtr.Zero;
}

if (RuntimeInformation.IsOSPlatform(OSPlatform.Linux))
var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

foreach(var libraryPath in libraryTryLoadOrder)
{
// All of the Linux libraries, in order of preference
return TryLoad("cu12.1.0/libllama.so")
?? TryLoad("cu11.7.1/libllama.so")
#if NET8_0_OR_GREATER
?? TryLoad("avx512/libllama.so", System.Runtime.Intrinsics.X86.Avx512.IsSupported)
#endif
?? TryLoad("avx2/libllama.so", System.Runtime.Intrinsics.X86.Avx2.IsSupported)
?? TryLoad("avx/libllama.so", System.Runtime.Intrinsics.X86.Avx.IsSupported)
?? IntPtr.Zero;
var result = TryLoad(libraryPath, true);
if(result is not null && result != IntPtr.Zero)
{
Console.ForegroundColor = ConsoleColor.Red;
Console.WriteLine($"[Native Library] {libraryPath} is loaded.");
Console.ResetColor();
return result ?? IntPtr.Zero;
}
else
{
Console.WriteLine($"Tried to load {libraryPath}");
}
}

if (RuntimeInformation.IsOSPlatform(OSPlatform.OSX))
if (!configuration.AllowFallback)
{
return TryLoad("runtimes/macos-arm64/libllama.dylib", System.Runtime.Intrinsics.Arm.ArmBase.Arm64.IsSupported)
?? TryLoad("runtimes/macos-x86_64/libllama.dylib")
?? IntPtr.Zero;
throw new RuntimeError("Failed to load the library that match your rule, please" +
" 1) check your rule." +
" 2) try to allow fallback." +
" 3) or open an issue if it's expected to be successful.");
}
#endif

Expand All @@ -103,6 +305,7 @@ private static IntPtr TryLoadLibrary()
}

private const string libraryName = "libllama";
private const string cudaVersionFile = "version.json";

/// <summary>
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
Expand Down
Loading

0 comments on commit d03e1db

Please sign in to comment.