Safer Model Handle Creation #402

Merged
27 changes: 17 additions & 10 deletions LLama/Native/NativeApi.cs
@@ -71,15 +71,13 @@ public static partial class NativeApi
 public static extern bool llama_mlock_supported();

 /// <summary>
-/// Various functions for loading a ggml llama model.
-/// Allocate (almost) all memory needed for the model.
-/// Return NULL on failure
+/// Load all of the weights of a model into memory.
 /// </summary>
 /// <param name="path_model"></param>
 /// <param name="params"></param>
-/// <returns></returns>
+/// <returns>The loaded model, or null on failure.</returns>
 [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-public static extern IntPtr llama_load_model_from_file(string path_model, LLamaModelParams @params);
+public static extern SafeLlamaModelHandle llama_load_model_from_file(string path_model, LLamaModelParams @params);

 /// <summary>
 /// Create a new llama_context with the given model.
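
Returning SafeLlamaModelHandle straight from the P/Invoke signature means the runtime marshaler, not user code, wraps the native pointer. For that to work, the handle type must expose a parameterless constructor the marshaler can invoke via reflection. A minimal sketch of the plumbing this relies on (an assumed shape, not the exact LLamaSharp source):

public sealed class SafeLlamaModelHandle : SafeHandle
{
    // Called by the P/Invoke marshaler; never called from user code.
    private SafeLlamaModelHandle()
        : base(IntPtr.Zero, ownsHandle: true)
    {
    }

    // A NULL pointer from llama_load_model_from_file surfaces here,
    // as an invalid handle rather than an unmanaged crash.
    public override bool IsInvalid => handle == IntPtr.Zero;

    protected override bool ReleaseHandle()
    {
        // Assumed free function; the real code releases the model here.
        NativeApi.llama_free_model(handle);
        return true;
    }
}
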
@@ -92,12 +90,11 @@ public static partial class NativeApi
 public static extern IntPtr llama_new_context_with_model(SafeLlamaModelHandle model, LLamaContextParams @params);

 /// <summary>
-/// not great API - very likely to change.
 /// Initialize the llama + ggml backend
 /// Call once at the start of the program
 /// </summary>
 [DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
-public static extern void llama_backend_init(bool numa);
+private static extern void llama_backend_init(bool numa);

 /// <summary>
 /// Frees all allocated memory in the given llama_context
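
Making the raw llama_backend_init entry point private means external callers can no longer repeat or race backend initialization; the library itself is now responsible for triggering it exactly once. A sketch of that pattern, under the assumption that the type initializer is the trigger:

public static partial class NativeApi
{
    // Runs once, before the first native call made through this class.
    static NativeApi()
    {
        llama_backend_init(false);
    }
}
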
@@ -510,10 +507,20 @@ public static int llama_model_meta_val_str_by_index(SafeLlamaModelHandle model,
/// <param name="model"></param>
/// <param name="llamaToken"></param>
/// <param name="buffer">buffer to write string into</param>
/// <param name="length">size of the buffer</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern unsafe int llama_token_to_piece(SafeLlamaModelHandle model, int llamaToken, byte* buffer, int length);
public static int llama_token_to_piece(SafeLlamaModelHandle model, llama_token llamaToken, Span<byte> buffer)
{
unsafe
{
fixed (byte* bufferPtr = buffer)
{
return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, llama_token llamaToken, byte* buffer, int length);
}

/// <summary>
/// Convert text into tokens
Expand Down
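
With the Span<byte> overload, callers no longer write their own unsafe/fixed blocks. A usage sketch (model and token are placeholder values, not part of this diff); per the doc comment, a negative result reports the buffer length actually required:

Span<byte> buffer = stackalloc byte[64];
int n = NativeApi.llama_token_to_piece(model, token, buffer);
if (n < 0)
{
    // Buffer was too small: |n| is the required length, so retry once.
    buffer = new byte[-n];
    n = NativeApi.llama_token_to_piece(model, token, buffer);
}
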
37 changes: 10 additions & 27 deletions LLama/Native/SafeLlamaModelHandle.cs
@@ -16,44 +16,33 @@ public sealed class SafeLlamaModelHandle
 /// <summary>
 /// Total number of tokens in vocabulary of this model
 /// </summary>
-public int VocabCount { get; }
+public int VocabCount => NativeApi.llama_n_vocab(this);

 /// <summary>
 /// Total number of tokens in the context
 /// </summary>
-public int ContextSize { get; }
+public int ContextSize => NativeApi.llama_n_ctx_train(this);

 /// <summary>
 /// Dimension of embedding vectors
 /// </summary>
-public int EmbeddingSize { get; }
+public int EmbeddingSize => NativeApi.llama_n_embd(this);

 /// <summary>
 /// Get the size of this model in bytes
 /// </summary>
-public ulong SizeInBytes { get; }
+public ulong SizeInBytes => NativeApi.llama_model_size(this);

 /// <summary>
 /// Get the number of parameters in this model
 /// </summary>
-public ulong ParameterCount { get; }
+public ulong ParameterCount => NativeApi.llama_model_n_params(this);

 /// <summary>
 /// Get the number of metadata key/value pairs
 /// </summary>
 /// <returns></returns>
-public int MetadataCount { get; }
-
-internal SafeLlamaModelHandle(IntPtr handle)
-    : base(handle)
-{
-    VocabCount = NativeApi.llama_n_vocab(this);
-    ContextSize = NativeApi.llama_n_ctx_train(this);
-    EmbeddingSize = NativeApi.llama_n_embd(this);
-    SizeInBytes = NativeApi.llama_model_size(this);
-    ParameterCount = NativeApi.llama_model_n_params(this);
-    MetadataCount = NativeApi.llama_model_meta_count(this);
-}
+public int MetadataCount => NativeApi.llama_model_meta_count(this);

 /// <inheritdoc />
 protected override bool ReleaseHandle()
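
Each property now forwards to its native accessor on every read instead of caching the value in a constructor the marshaler no longer calls. These are cheap lookups, but a caller reading one inside a tight loop can still hoist it locally; a hypothetical caller-side sketch, not part of this change:

// `model` is a loaded SafeLlamaModelHandle.
var vocab = model.VocabCount;   // one native call, hoisted out of the loop
for (var i = 0; i < vocab; i++)
{
    // ... per-token work that would otherwise re-query VocabCount ...
}
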
@@ -73,10 +62,10 @@
 public static SafeLlamaModelHandle LoadFromFile(string modelPath, LLamaModelParams lparams)
 {
     var model_ptr = NativeApi.llama_load_model_from_file(modelPath, lparams);
-    if (model_ptr == IntPtr.Zero)
+    if (model_ptr == null)
         throw new RuntimeError($"Failed to load model {modelPath}.");

-    return new SafeLlamaModelHandle(model_ptr);
+    return model_ptr;
 }

 #region LoRA
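
Since SafeHandle implements IDisposable, the loaded model can be scoped with using and the native memory is freed deterministically. A usage sketch (the path and lparams setup are placeholders):

using var model = SafeLlamaModelHandle.LoadFromFile("model.gguf", lparams);
Console.WriteLine($"{model.ParameterCount} parameters, vocab of {model.VocabCount}");
// model is released via ReleaseHandle() when it leaves scope
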
@@ -114,14 +103,8 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
 /// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
 public int TokenToSpan(int llama_token, Span<byte> dest)
 {
-    unsafe
-    {
-        fixed (byte* destPtr = dest)
-        {
-            var length = NativeApi.llama_token_to_piece(this, llama_token, destPtr, dest.Length);
-            return Math.Abs(length);
-        }
-    }
+    var length = NativeApi.llama_token_to_piece(this, llama_token, dest);
+    return Math.Abs(length);
 }

 /// <summary>
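
Because TokenToSpan folds the negative "buffer too small" result through Math.Abs, the caller distinguishes the two outcomes by comparing the returned size against its own buffer length. An end-to-end sketch (placeholder names; assumes using System.Text;):

Span<byte> dest = stackalloc byte[32];
int size = model.TokenToSpan(token, dest);
if (size > dest.Length)
{
    // Nothing was written; retry with a buffer of the reported size.
    dest = new byte[size];
    size = model.TokenToSpan(token, dest);
}
string piece = Encoding.UTF8.GetString(dest.Slice(0, size));
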