
Merge pull request #417 from martindevans/safe_handle_initialisation
Safer Handle Initialisation
martindevans authored Jan 7, 2024
2 parents c696cda + 1e69e26 commit 9573e2c
Showing 5 changed files with 72 additions and 62 deletions.
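
The pattern this commit adopts throughout: instead of P/Invoke declarations that return a raw IntPtr which is then wrapped manually (a pattern where an exception between the native call and the wrapper construction can leak the handle), the DllImport declarations now return the SafeHandle subclasses directly, so the .NET marshaller constructs the wrapper atomically. A minimal sketch of that pattern, with illustrative names rather than LLamaSharp's actual API:

    using System;
    using System.Runtime.InteropServices;

    // Sketch only: "demo" and the demo_* entry points are hypothetical.
    sealed class DemoHandle : SafeHandle
    {
        // The marshaller instantiates the handle through this parameterless
        // constructor and then stores the native pointer into it.
        private DemoHandle() : base(IntPtr.Zero, ownsHandle: true) { }

        public override bool IsInvalid => handle == IntPtr.Zero;

        protected override bool ReleaseHandle()
        {
            demo_free(handle);
            return true;
        }

        // Returning the SafeHandle type directly means user code never
        // touches an unwrapped IntPtr, so the handle cannot leak between
        // the native call and the construction of the wrapper.
        [DllImport("demo", CallingConvention = CallingConvention.Cdecl)]
        private static extern DemoHandle demo_create();

        [DllImport("demo", CallingConvention = CallingConvention.Cdecl)]
        private static extern void demo_free(IntPtr ptr);

        public static DemoHandle Create()
        {
            var h = demo_create();
            if (h.IsInvalid)
                throw new InvalidOperationException("demo_create failed");
            return h;
        }
    }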
LLama/Native/NativeApi.Load.cs (2 changes: 1 addition & 1 deletion)

@@ -329,7 +329,7 @@ string TryFindPath(string filename)
#endif
}

- private const string libraryName = "libllama";
+ internal const string libraryName = "libllama";
private const string cudaVersionFile = "version.json";
private const string loggingPrefix = "[LLamaSharp Native]";
private static bool enableLogging = false;
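(libraryName is widened from private to internal so that the DllImport declarations moved into the SafeHandle classes below can still reference it.)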
LLama/Native/NativeApi.cs (39 changes: 4 additions & 35 deletions)

@@ -25,8 +25,10 @@ public static partial class NativeApi
/// A method that does nothing. This is a native method, calling it will force the llama native dependencies to be loaded.
/// </summary>
/// <returns></returns>
- [DllImport(libraryName, EntryPoint = "llama_mmap_supported", CallingConvention = CallingConvention.Cdecl)]
- public static extern bool llama_empty_call();
+ public static void llama_empty_call()
+ {
+     llama_mmap_supported();
+ }

/// <summary>
/// Get the maximum number of devices supported by llama.cpp
@@ -70,46 +72,13 @@ public static partial class NativeApi
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern bool llama_mlock_supported();

- /// <summary>
- /// Load all of the weights of a model into memory.
- /// </summary>
- /// <param name="path_model"></param>
- /// <param name="params"></param>
- /// <returns>The loaded model, or null on failure.</returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
-
- /// <summary>
- /// Create a new llama_context with the given model.
- /// Return value should always be wrapped in SafeLLamaContextHandle!
- /// </summary>
- /// <param name="model"></param>
- /// <param name="params"></param>
- /// <returns></returns>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
-
/// <summary>
/// Initialize the llama + ggml backend
/// Call once at the start of the program
/// </summary>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
private static extern void llama_backend_init(bool numa);

- /// <summary>
- /// Frees all allocated memory in the given llama_context
- /// </summary>
- /// <param name="ctx"></param>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_free(IntPtr ctx);
-
- /// <summary>
- /// Frees all allocated memory associated with a model
- /// </summary>
- /// <param name="model"></param>
- [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
- public static extern void llama_free_model(IntPtr model);
-
/// <summary>
/// Apply a LoRA adapter to a loaded model
/// path_base_model is the path to a higher quality model to use as a base for
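llama_empty_call keeps its documented purpose (calling it still forces the native llama dependencies to load), but it is now a managed wrapper over the existing llama_mmap_supported import rather than a second DllImport bound to the same entry point. The static constructors added to the SafeHandle classes below call it for exactly that reason.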
LLama/Native/SafeLLamaContextHandle.cs (56 changes: 36 additions & 20 deletions)

@@ -1,5 +1,6 @@
using System;
using System.Buffers;
+ using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;

@@ -8,6 +9,7 @@ namespace LLama.Native
/// <summary>
/// A safe wrapper around a llama_context
/// </summary>
+ // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
public sealed class SafeLLamaContextHandle
: SafeLLamaHandleBase
{
@@ -36,26 +38,10 @@ public sealed class SafeLLamaContextHandle
#endregion

#region construction/destruction
- /// <summary>
- /// Create a new SafeLLamaContextHandle
- /// </summary>
- /// <param name="handle">pointer to an allocated llama_context</param>
- /// <param name="model">the model which this context was created from</param>
- public SafeLLamaContextHandle(IntPtr handle, SafeLlamaModelHandle model)
-     : base(handle)
- {
-     // Increment the model reference count while this context exists
-     _model = model;
-     var success = false;
-     _model.DangerousAddRef(ref success);
-     if (!success)
-         throw new RuntimeError("Failed to increment model refcount");
- }
-
/// <inheritdoc />
protected override bool ReleaseHandle()
{
- NativeApi.llama_free(DangerousGetHandle());
+ llama_free(handle);
SetHandle(IntPtr.Zero);

// Decrement refcount on model
@@ -84,12 +70,42 @@ private SafeLlamaModelHandle ThrowIfDisposed()
/// <exception cref="RuntimeError"></exception>
public static SafeLLamaContextHandle Create(SafeLlamaModelHandle model, LLamaContextParams lparams)
{
- var ctx_ptr = NativeApi.llama_new_context_with_model(model, lparams);
- if (ctx_ptr == IntPtr.Zero)
+ var ctx = llama_new_context_with_model(model, lparams);
+ if (ctx == null)
throw new RuntimeError("Failed to create context from model");

- return new(ctx_ptr, model);
+ // Increment the model reference count while this context exists.
+ // DangerousAddRef throws if it fails, so there is no need to check "success"
+ ctx._model = model;
+ var success = false;
+ ctx._model.DangerousAddRef(ref success);
+
+ return ctx;
}
#endregion

+ #region Native API
+ static SafeLLamaContextHandle()
+ {
+     // This ensures that `NativeApi` has been loaded before calling the two native methods below
+     NativeApi.llama_empty_call();
+ }
+
+ /// <summary>
+ /// Create a new llama_context with the given model. **This should never be called directly! Always use SafeLLamaContextHandle.Create**!
+ /// </summary>
+ /// <param name="model"></param>
+ /// <param name="params"></param>
+ /// <returns></returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern SafeLLamaContextHandle llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);
+
+ /// <summary>
+ /// Frees all allocated memory in the given llama_context
+ /// </summary>
+ /// <param name="ctx"></param>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_free(IntPtr ctx);
+ #endregion

/// <summary>
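Two details in the new creation path are worth noting: the static constructor guarantees NativeApi (and with it the native-library loading logic in NativeApi.Load.cs) is initialized before the private imports here are first bound, and SafeHandle.DangerousAddRef throws on failure rather than returning false, so Create no longer needs to inspect the success flag.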
LLama/Native/SafeLLamaHandleBase.cs (2 changes: 1 addition & 1 deletion)

@@ -31,6 +31,6 @@ private protected SafeLLamaHandleBase(IntPtr handle, bool ownsHandle)

/// <inheritdoc />
public override string ToString()
=> $"0x{handle.ToString("x16")}";
=> $"0x{handle:x16}";
}
}
LLama/Native/SafeLlamaModelHandle.cs (35 changes: 30 additions & 5 deletions)

@@ -1,6 +1,7 @@
using System;
using System.Collections.Generic;
using System.Diagnostics;
+ using System.Runtime.InteropServices;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
@@ -10,6 +11,7 @@ namespace LLama.Native
/// <summary>
/// A reference to a set of llama model weights
/// </summary>
+ // ReSharper disable once ClassNeverInstantiated.Global (used implicitly in native API)
public sealed class SafeLlamaModelHandle
: SafeLLamaHandleBase
{
@@ -47,8 +49,7 @@ public sealed class SafeLlamaModelHandle
/// <inheritdoc />
protected override bool ReleaseHandle()
{
- NativeApi.llama_free_model(DangerousGetHandle());
- SetHandle(IntPtr.Zero);
+ llama_free_model(handle);
return true;
}

@@ -61,13 +62,37 @@ protected override bool ReleaseHandle()
/// <exception cref="RuntimeError"></exception>
public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
{
- var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
- if (model_ptr == null)
+ var model = llama_load_model_from_file(modelPath, lparams);
+ if (model == null)
throw new RuntimeError($"Failed to load model {modelPath}.");

- return model_ptr;
+ return model;
}

+ #region native API
+ static SafeLlamaModelHandle()
+ {
+     // This ensures that `NativeApi` has been loaded before calling the two native methods below
+     NativeApi.llama_empty_call();
+ }
+
+ /// <summary>
+ /// Load all of the weights of a model into memory.
+ /// </summary>
+ /// <param name="path_model"></param>
+ /// <param name="params"></param>
+ /// <returns>The loaded model, or null on failure.</returns>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);
+
+ /// <summary>
+ /// Frees all allocated memory associated with a model
+ /// </summary>
+ /// <param name="model"></param>
+ [DllImport(NativeApi.libraryName, CallingConvention = CallingConvention.Cdecl)]
+ private static extern void llama_free_model(IntPtr model);
+ #endregion

#region LoRA

/// <summary>
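Taken together, the two factories now compose without any raw pointer handling. A hedged usage sketch (the model path and the construction of the two params structs are assumptions, not shown in this diff):

    // Usage sketch: parameter construction is assumed to happen elsewhere.
    static void Example(LLamaModelParams modelParams, LLamaContextParams ctxParams)
    {
        // "model.gguf" is an illustrative path.
        using var model = SafeLlamaModelHandle.LoadFromFile("model.gguf", modelParams);
        using var ctx = SafeLLamaContextHandle.Create(model, ctxParams);

        // Create incremented the model's refcount, so the model's native
        // memory remains valid while ctx is alive, regardless of disposal order.
    }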
