Skip to content

Commit

Permalink
Merge pull request #380 from martindevans/LLamaWeights.Metadata_Property
Browse files Browse the repository at this point in the history
Added `LLamaWeights.Metadata` property
  • Loading branch information
martindevans authored Dec 21, 2023
2 parents 4635185 + fb606c2 commit 9b1ff0b
Show file tree
Hide file tree
Showing 4 changed files with 109 additions and 24 deletions.
33 changes: 10 additions & 23 deletions LLama.Unittest/BasicTest.cs
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ public void BasicModelProperties()
[Fact]
public void AdvancedModelProperties()
{
// These are the keys in the llama 7B test model. This will need changing if
// tests are switched to use a new model!
var expected = new Dictionary<string, string>
{
{ "general.name", "LLaMA v2" },
Expand All @@ -60,31 +62,16 @@ public void AdvancedModelProperties()
{ "tokenizer.ggml.unknown_token_id", "0" },
};

var metaCount = NativeApi.llama_model_meta_count(_model.NativeHandle);
Assert.Equal(expected.Count, metaCount);
// Print all keys
foreach (var (key, value) in _model.Metadata)
_testOutputHelper.WriteLine($"{key} = {value}");

Span<byte> buffer = stackalloc byte[128];
for (var i = 0; i < expected.Count; i++)
{
unsafe
{
fixed (byte* ptr = buffer)
{
var length = NativeApi.llama_model_meta_key_by_index(_model.NativeHandle, i, ptr, 128);
Assert.True(length > 0);
var key = Encoding.UTF8.GetString(buffer[..length]);

length = NativeApi.llama_model_meta_val_str_by_index(_model.NativeHandle, i, ptr, 128);
Assert.True(length > 0);
var val = Encoding.UTF8.GetString(buffer[..length]);

_testOutputHelper.WriteLine($"{key} == {val}");
// Check the count is equal
Assert.Equal(expected.Count, _model.Metadata.Count);

Assert.True(expected.ContainsKey(key));
Assert.Equal(expected[key], val);
}
}
}
// Check every key
foreach (var (key, value) in _model.Metadata)
Assert.Equal(expected[key], value);
}
}
}
13 changes: 12 additions & 1 deletion LLama/Extensions/EncodingExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ public static int GetCharCount(this Encoding encoding, ReadOnlySpan<byte> bytes)
return GetCharCountImpl(encoding, bytes);
}
#elif !NET6_0_OR_GREATER && !NETSTANDARD2_1_OR_GREATER
#error Target framework not supported!
#error Target framework not supported!
#endif

internal static int GetCharsImpl(Encoding encoding, ReadOnlySpan<byte> bytes, Span<char> output)
Expand Down Expand Up @@ -47,4 +47,15 @@ internal static int GetCharCountImpl(Encoding encoding, ReadOnlySpan<byte> bytes
}
}
}

/// <summary>
/// Decode a span of bytes into a string using the given encoding
/// </summary>
/// <param name="encoding">The encoding used to interpret the bytes</param>
/// <param name="bytes">The bytes to decode. May be empty.</param>
/// <returns>The decoded string, or an empty string if <paramref name="bytes"/> is empty</returns>
internal static string GetStringFromSpan(this Encoding encoding, ReadOnlySpan<byte> bytes)
{
    // Pinning an empty span produces a null pointer, and Encoding.GetString(byte*, int)
    // throws ArgumentNullException for a null pointer even when the count is zero.
    if (bytes.IsEmpty)
        return string.Empty;

    unsafe
    {
        fixed (byte* bytesPtr = bytes)
        {
            return encoding.GetString(bytesPtr, bytes.Length);
        }
    }
}
}
7 changes: 7 additions & 0 deletions LLama/LLamaWeights.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using System;
using System.Collections.Generic;
using LLama.Abstractions;
using LLama.Extensions;
using LLama.Native;
Expand Down Expand Up @@ -58,9 +59,15 @@ public sealed class LLamaWeights
/// </summary>
public int EmbeddingSize => NativeHandle.EmbeddingSize;

/// <summary>
/// All metadata key/value pairs in this model, as read from the native handle when this object was constructed
/// </summary>
// NOTE(review): the setter is public, so callers can replace the whole dictionary - confirm whether this is intended.
public IReadOnlyDictionary<string, string> Metadata { get; set; }

/// <summary>
/// Wrap a native model handle, reading its metadata once up front
/// </summary>
/// <param name="weights">Handle to the loaded model</param>
internal LLamaWeights(SafeLlamaModelHandle weights)
{
    // Metadata is read from the handle itself, so it does not depend on NativeHandle being assigned first
    Metadata = weights.ReadMetadata();
    NativeHandle = weights;
}

/// <summary>
Expand Down
80 changes: 80 additions & 0 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Text;
using LLama.Exceptions;
using LLama.Extensions;
using EncodingExtensions = LLama.Extensions.EncodingExtensions;

namespace LLama.Native
{
Expand Down Expand Up @@ -36,6 +39,12 @@ public sealed class SafeLlamaModelHandle
/// </summary>
public ulong ParameterCount { get; }

/// <summary>
/// The number of metadata key/value pairs in this model
/// </summary>
public int MetadataCount { get; }

internal SafeLlamaModelHandle(IntPtr handle)
: base(handle)
{
Expand All @@ -44,6 +53,7 @@ internal SafeLlamaModelHandle(IntPtr handle)
EmbeddingSize = NativeApi.llama_n_embd(this);
SizeInBytes = NativeApi.llama_model_size(this);
ParameterCount = NativeApi.llama_model_n_params(this);
MetadataCount = NativeApi.llama_model_meta_count(this);
}

/// <inheritdoc />
Expand Down Expand Up @@ -199,5 +209,75 @@ public SafeLLamaContextHandle CreateContext(LLamaContextParams @params)
return SafeLLamaContextHandle.Create(this, @params);
}
#endregion

#region metadata
/// <summary>
/// Get the metadata key for the given index
/// </summary>
/// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store key characters in. Must be large enough to contain the key.</param>
/// <returns>A slice of <paramref name="buffer"/> holding the key bytes, null if there is no such key or if the buffer was too small</returns>
public Memory<byte>? MetadataKeyByIndex(int index, Memory<byte> buffer)
{
    unsafe
    {
        using var pin = buffer.Pin();
        var keyLength = NativeApi.llama_model_meta_key_by_index(this, index, (byte*)pin.Pointer, buffer.Length);

        // A negative length means the native call failed.
        if (keyLength < 0)
            return null;

        // A length that does not fit in the buffer would make Slice throw, instead of returning null as documented.
        // NOTE(review): llama.h describes the result as the full string length with the output always
        // null-terminated, so a length equal to buffer.Length is treated as truncated too - confirm
        // against the bundled llama.cpp version.
        if (keyLength >= buffer.Length)
            return null;

        return buffer.Slice(0, keyLength);
    }
}

/// <summary>
/// Get the metadata value for the given index
/// </summary>
/// <param name="index">The index to get</param>
/// <param name="buffer">A temporary buffer to store value characters in. Must be large enough to contain the value.</param>
/// <returns>A slice of <paramref name="buffer"/> holding the value bytes, null if there is no such value or if the buffer was too small</returns>
public Memory<byte>? MetadataValueByIndex(int index, Memory<byte> buffer)
{
    unsafe
    {
        using var pin = buffer.Pin();
        var valueLength = NativeApi.llama_model_meta_val_str_by_index(this, index, (byte*)pin.Pointer, buffer.Length);

        // A negative length means the native call failed.
        if (valueLength < 0)
            return null;

        // A length that does not fit in the buffer would make Slice throw, instead of returning null as documented.
        // NOTE(review): llama.h describes the result as the full string length with the output always
        // null-terminated, so a length equal to buffer.Length is treated as truncated too - confirm
        // against the bundled llama.cpp version.
        if (valueLength >= buffer.Length)
            return null;

        return buffer.Slice(0, valueLength);
    }
}

/// <summary>
/// Read every metadata key/value pair from this model into a dictionary
/// </summary>
/// <returns>A dictionary of the pairs that could be read; an index whose key or value cannot be read is skipped</returns>
internal IReadOnlyDictionary<string, string> ReadMetadata()
{
    var metadata = new Dictionary<string, string>();

    // A single pooled scratch buffer is reused for every key and every value
    var scratch = ArrayPool<byte>.Shared.Rent(1024);
    try
    {
        Memory<byte> scratchMemory = scratch;

        for (var index = 0; index < MetadataCount; index++)
        {
            // Wipe whatever the previous iteration left behind
            scratch.AsSpan().Clear();

            var rawKey = MetadataKeyByIndex(index, scratchMemory);
            if (!rawKey.HasValue)
                continue;

            // The key must be decoded before the value is read, because both share the same buffer
            var name = Encoding.UTF8.GetStringFromSpan(rawKey.Value.Span);

            var rawValue = MetadataValueByIndex(index, scratchMemory);
            if (!rawValue.HasValue)
                continue;

            metadata[name] = Encoding.UTF8.GetStringFromSpan(rawValue.Value.Span);
        }
    }
    finally
    {
        ArrayPool<byte>.Shared.Return(scratch);
    }

    return metadata;
}
#endregion
}
}

0 comments on commit 9b1ff0b

Please sign in to comment.