Skip to content

Commit

Permalink
Merge pull request #333 from AsakusaRinne/master
Browse files Browse the repository at this point in the history
feat: allow customized search path for native library loading.
  • Loading branch information
AsakusaRinne authored Nov 28, 2023
2 parents 6dfda5e + ffc347a commit 1f97ad8
Show file tree
Hide file tree
Showing 3 changed files with 100 additions and 17 deletions.
24 changes: 17 additions & 7 deletions LLama/ChatSession.cs
Original file line number Diff line number Diff line change
Expand Up @@ -152,14 +152,24 @@ public async IAsyncEnumerable<string> ChatAsync(string prompt, IInferenceParams?
foreach (var inputTransform in InputTransformPipeline)
prompt = inputTransform.Transform(prompt);

History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));

if (_executor is InteractiveExecutor executor)
// TODO: need to be refactored.
if (_executor is InteractiveExecutor executor && ((InteractiveExecutorState)executor.GetStateData()).IsPromptRun)
{
InteractiveExecutorState state = (InteractiveExecutorState)executor.GetStateData();
prompt = state.IsPromptRun
? HistoryTransform.HistoryToText(History)
: prompt;
History.Messages.Add(new ChatHistory.Message(AuthorRole.System, prompt));
var converted_prompt = HistoryTransform.HistoryToText(History);
// Avoid missing anti-prompt.
if (!prompt.EndsWith("\n") && !prompt.EndsWith("\r\n"))
{
prompt = converted_prompt.Trim();
}
else
{
prompt = converted_prompt;
}
}
else
{
History.Messages.Add(new ChatHistory.Message(AuthorRole.User, prompt));
}

StringBuilder sb = new();
Expand Down
20 changes: 13 additions & 7 deletions LLama/Native/NativeApi.Load.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Diagnostics;
using System.IO;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;

Expand Down Expand Up @@ -258,6 +259,7 @@ private static IntPtr TryLoadLibrary()
enableLogging = configuration.Logging;
// We move the flag to avoid loading library when the variable is called else where.
NativeLibraryConfig.LibraryHasLoaded = true;
Log(configuration.ToString(), LogLevel.Information);

if (!string.IsNullOrEmpty(configuration.Path))
{
Expand All @@ -273,26 +275,30 @@ private static IntPtr TryLoadLibrary()

var libraryTryLoadOrder = GetLibraryTryOrder(configuration);

string[] preferredPaths = configuration.SearchDirectories;
string[] possiblePathPrefix = new string[] {
System.AppDomain.CurrentDomain.BaseDirectory,
Path.GetDirectoryName(System.Reflection.Assembly.GetExecutingAssembly().Location) ?? ""
};

var tryFindPath = (string filename) =>
{
int i = 0;
while (!File.Exists(filename))
foreach(var path in preferredPaths)
{
if (i < possiblePathPrefix.Length)
if (File.Exists(Path.Combine(path, filename)))
{
filename = Path.Combine(possiblePathPrefix[i], filename);
i++;
return Path.Combine(path, filename);
}
else
}
foreach(var path in possiblePathPrefix)
{
if (File.Exists(Path.Combine(path, filename)))
{
break;
return Path.Combine(path, filename);
}
}
return filename;
};

Expand Down
73 changes: 70 additions & 3 deletions LLama/Native/NativeLibraryConfig.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,6 @@
using System;
using System.Collections.Generic;
using System.Linq;

namespace LLama.Native
{
Expand Down Expand Up @@ -27,6 +29,10 @@ public sealed class NativeLibraryConfig
private bool _allowFallback = true;
private bool _skipCheck = false;
private bool _logging = false;
/// <summary>
/// search directory -> priority level, 0 is the lowest.
/// </summary>
private List<string> _searchDirectories = new List<string>();

private static void ThrowIfLoaded()
{
Expand Down Expand Up @@ -120,13 +126,50 @@ public NativeLibraryConfig WithLogs(bool enable = true)
return this;
}

/// <summary>
/// Add self-defined search directories. Note that the file stucture of the added
/// directories must be the same as the default directory. Besides, the directory
/// won't be used recursively.
/// </summary>
/// <param name="directories"></param>
/// <returns></returns>
public NativeLibraryConfig WithSearchDirectories(IEnumerable<string> directories)
{
ThrowIfLoaded();

_searchDirectories.AddRange(directories);
return this;
}

/// <summary>
/// Add self-defined search directories. Note that the file stucture of the added
/// directories must be the same as the default directory. Besides, the directory
/// won't be used recursively.
/// </summary>
/// <param name="directory"></param>
/// <returns></returns>
public NativeLibraryConfig WithSearchDirectory(string directory)
{
ThrowIfLoaded();

_searchDirectories.Add(directory);
return this;
}

internal static Description CheckAndGatherDescription()
{
if (Instance._allowFallback && Instance._skipCheck)
{
throw new ArgumentException("Cannot skip the check when fallback is allowed.");
}
return new Description(Instance._libraryPath, Instance._useCuda, Instance._avxLevel, Instance._allowFallback, Instance._skipCheck, Instance._logging);
return new Description(
Instance._libraryPath,
Instance._useCuda,
Instance._avxLevel,
Instance._allowFallback,
Instance._skipCheck,
Instance._logging,
Instance._searchDirectories.Concat(new string[] { "./" }).ToArray());
}

internal static string AvxLevelToString(AvxLevel level)
Expand Down Expand Up @@ -183,7 +226,31 @@ public enum AvxLevel
Avx512,
}

internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging);
internal record Description(string Path, bool UseCuda, AvxLevel AvxLevel, bool AllowFallback, bool SkipCheck, bool Logging, string[] SearchDirectories)
{
public override string ToString()
{
string avxLevelString = AvxLevel switch
{
AvxLevel.None => "NoAVX",
AvxLevel.Avx => "AVX",
AvxLevel.Avx2 => "AVX2",
AvxLevel.Avx512 => "AVX512",
_ => "Unknown"
};

string searchDirectoriesString = "{ " + string.Join(", ", SearchDirectories) + " }";

return $"NativeLibraryConfig Description:\n" +
$"- Path: {Path}\n" +
$"- PreferCuda: {UseCuda}\n" +
$"- PreferredAvxLevel: {avxLevelString}\n" +
$"- AllowFallback: {AllowFallback}\n" +
$"- SkipCheck: {SkipCheck}\n" +
$"- Logging: {Logging}\n" +
$"- SearchDirectories and Priorities: {searchDirectoriesString}";
}
}
}
#endif
}
}

0 comments on commit 1f97ad8

Please sign in to comment.