Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

May 2024 Binary Update (Take 2) #712

Merged
merged 12 commits into from
May 12, 2024
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedExecutorGuidance.cs
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ await AnsiConsole
guidance.Prompt(g);

// Early exit if we reach the natural end of the guided sentence
if (g == model.Tokens.EOS)
if (model.Tokens.IsEndOfGeneration(g))
break;

// Update progress bar
Expand Down
9 changes: 7 additions & 2 deletions LLama.Web/Common/ModelOptions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -29,12 +29,14 @@ public class ModelOptions
/// <inheritdoc />
public int GpuLayerCount { get; set; } = 20;

public uint SeqMax { get; }
/// <inheritdoc />
public uint SeqMax { get; set; }

/// <inheritdoc />
public uint Seed { get; set; } = 1686349486;

public bool Embeddings { get; }
/// <inheritdoc />
public bool Embeddings { get; set; }

/// <inheritdoc />
public bool UseMemorymap { get; set; } = true;
Expand Down Expand Up @@ -102,6 +104,9 @@ public class ModelOptions
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public bool FlashAttention { get; set; }

/// <inheritdoc />
public Encoding Encoding { get; set; } = Encoding.UTF8;

Expand Down
5 changes: 5 additions & 0 deletions LLama/Abstractions/IContextParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,11 @@ public interface IContextParams
/// </summary>
bool NoKqvOffload { get; }

/// <summary>
/// Whether to use flash attention
/// </summary>
bool FlashAttention { get; }

/// <summary>
/// defragment the KV cache if holes/size &gt; defrag_threshold, Set to &lt; 0 to disable (default)
/// </summary>
Expand Down
24 changes: 24 additions & 0 deletions LLama/Abstractions/IModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.ComponentModel;
using System.Linq;
using System.Text;
using System.Text.Json;
using System.Text.Json.Serialization;
using LLama.Native;
Expand Down Expand Up @@ -241,6 +242,7 @@ public sealed record MetadataOverride
private readonly int _valueInt;
private readonly float _valueFloat;
private readonly bool _valueBool;
private readonly byte[]? _valueString;

/// <summary>
/// Create a new override for an int key
Expand Down Expand Up @@ -278,6 +280,21 @@ public MetadataOverride(string key, bool value)
Type = LLamaModelKvOverrideType.Bool;
}

/// <summary>
/// Create a new override for a string key
/// </summary>
/// <param name="key">Metadata key to override</param>
/// <param name="value">String value. Must encode to fewer than 128 UTF-8 bytes, because the
/// native <c>llama_model_kv_override</c> stores it in a fixed 128 byte buffer read as a
/// NUL-terminated C string.</param>
/// <exception cref="ArgumentException">Thrown if <paramref name="value"/> encodes to 128 or more UTF-8 bytes</exception>
public MetadataOverride(string key, string value)
{
    var bytes = Encoding.UTF8.GetBytes(value);

    // Validate BEFORE assigning any state, and reject exactly-128-byte values too:
    // WriteValue copies into a 128 byte native buffer, so a 128 byte string would leave
    // no room for the NUL terminator the native side expects.
    if (bytes.Length >= 128)
        throw new ArgumentException("Value string is too long, must be < 128 UTF8 bytes", nameof(value));

    Key = key;
    _valueString = bytes;
    Type = LLamaModelKvOverrideType.String;
}

internal void WriteValue(ref LLamaModelMetadataOverride dest)
{
switch (Type)
Expand All @@ -291,6 +308,13 @@ internal void WriteValue(ref LLamaModelMetadataOverride dest)
case LLamaModelKvOverrideType.Bool:
dest.BoolValue = _valueBool ? -1L : 0;
break;
case LLamaModelKvOverrideType.String:
unsafe
{
fixed (byte* strValPtr = dest.StringValue)
new Span<byte>(_valueString!).CopyTo(new Span<byte>(strValPtr, 128));
}
break;
default:
throw new InvalidEnumArgumentException($"Unknown {nameof(LLamaModelKvOverrideType)} value: {Type}");
}
Expand Down
3 changes: 3 additions & 0 deletions LLama/Common/ModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,9 @@ public record ModelParams
/// <inheritdoc />
public bool NoKqvOffload { get; set; }

/// <inheritdoc />
public bool FlashAttention { get; set; }

/// <inheritdoc />
public float DefragThreshold { get; set; }

Expand Down
1 change: 1 addition & 0 deletions LLama/Extensions/IContextParamsExtensions.cs
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public static void ToLlamaContextParams(this IContextParams @params, out LLamaCo
result.type_k = @params.TypeK ?? GGMLType.GGML_TYPE_F16;
result.type_v = @params.TypeV ?? GGMLType.GGML_TYPE_F16;
result.offload_kqv = [email protected];
result.flash_attention = @params.FlashAttention;
result.llama_pooling_type = @params.PoolingType;

result.n_threads = Threads(@params.Threads);
Expand Down
5 changes: 3 additions & 2 deletions LLama/LLamaStatelessExecutor.cs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
using System.Collections.Generic;
using System.Linq;
using System.Runtime.CompilerServices;
using System.Text;
using System.Threading;
using LLama.Exceptions;
using LLama.Native;
Expand Down Expand Up @@ -123,8 +124,8 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
);
}

// Check if this is the EOS token
if (id == _weights.Tokens.EOS)
// Check if this token should end generation
if (_weights.Tokens.IsEndOfGeneration(id))
break;

// Decode this token into text
Expand Down
10 changes: 10 additions & 0 deletions LLama/Native/LLamaContextParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,16 @@ public bool offload_kqv
}
private sbyte _offload_kqv;

/// <summary>
/// whether to use flash attention
/// </summary>
public bool flash_attention
{
    readonly get => Convert.ToBoolean(_flash_attention);
    set => _flash_attention = Convert.ToSByte(value);
}
// Backing field: kept as an sbyte (single byte) so the managed struct layout matches the
// native C bool in llama_context_params; converted manually rather than relying on
// default bool marshalling.
private sbyte _flash_attention;

//todo: implement abort callback support
/// <summary>
/// ggml_abort_callback
Expand Down
5 changes: 5 additions & 0 deletions LLama/Native/LLamaFtype.cs
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,11 @@ public enum LLamaFtype
/// </summary>
LLAMA_FTYPE_MOSTLY_IQ1_M = 31,

/// <summary>
/// except 1d tensors
/// </summary>
LLAMA_FTYPE_MOSTLY_BF16 = 32,

/// <summary>
/// File type was not specified
/// </summary>
Expand Down
11 changes: 11 additions & 0 deletions LLama/Native/LLamaModelMetadataOverride.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,12 @@ public unsafe struct LLamaModelMetadataOverride
/// </summary>
[FieldOffset(136)]
public long BoolValue;

/// <summary>
/// Value, **must** only be used if Tag == String
/// </summary>
[FieldOffset(136)]
public fixed byte StringValue[128];
}

/// <summary>
Expand All @@ -65,4 +71,9 @@ public enum LLamaModelKvOverrideType
/// Overriding a bool value
/// </summary>
Bool = 2,

/// <summary>
/// Overriding a string value
/// </summary>
String = 3,
}
10 changes: 10 additions & 0 deletions LLama/Native/LLamaModelParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,16 @@ public bool use_mlock
}
private sbyte _use_mlock;

/// <summary>
/// validate model tensor data
/// </summary>
public bool check_tensors
{
    readonly get => Convert.ToBoolean(_check_tensors);
    set => _check_tensors = Convert.ToSByte(value);
}
// Backing field: single signed byte to mirror the native C bool field's size/layout;
// converted manually via Convert instead of default bool marshalling.
private sbyte _check_tensors;

/// <summary>
/// Create a LLamaModelParams with default values
/// </summary>
Expand Down
10 changes: 10 additions & 0 deletions LLama/Native/LLamaModelQuantizeParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,16 @@ public bool pure
}
private sbyte _pure;

/// <summary>
/// quantize to the same number of shards
/// </summary>
public bool keep_split
{
    // 'readonly' getter added for consistency with the other bool-as-sbyte properties in
    // these native mirror structs (e.g. check_tensors, flash_attention): the getter does
    // not mutate the struct, so marking it readonly avoids defensive copies.
    readonly get => Convert.ToBoolean(_keep_split);
    set => _keep_split = Convert.ToSByte(value);
}
// Backing field: single signed byte mirroring the native C bool; converted manually.
private sbyte _keep_split;

/// <summary>
/// pointer to importance matrix data
/// </summary>
Expand Down
17 changes: 17 additions & 0 deletions LLama/Native/LLamaVocabPreType.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
namespace LLama.Native;

/// <summary>
/// Identifies the pre-tokenization scheme of a model's vocabulary.
/// Values must stay in sync with the native enum named in the remarks.
/// </summary>
/// <remarks>llama_vocab_pre_type</remarks>
internal enum LLamaVocabPreType
{
/// <summary>Default pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0,
/// <summary>Llama 3 pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1,
/// <summary>DeepSeek LLM pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2,
/// <summary>DeepSeek Coder pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3,
/// <summary>Falcon pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_FALCON = 4,
/// <summary>MPT pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_MPT = 5,
/// <summary>StarCoder pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_STARCODER = 6,
/// <summary>GPT-2 pre-tokenizer</summary>
LLAMA_VOCAB_PRE_TYPE_GPT2 = 7,
}
5 changes: 3 additions & 2 deletions LLama/Native/NativeApi.LLava.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ public static unsafe partial class NativeApi
/// <param name="ctxClip">Llava Model</param>
/// <returns>True if validate successfully</returns>
[DllImport(llavaLibraryName, EntryPoint = "llava_validate_embed_size", CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llava_validate_embed_size( SafeLLamaContextHandle ctxLlama, SafeLlavaModelHandle ctxClip);

/// <summary>
Expand Down Expand Up @@ -56,7 +57,7 @@ SafeLlavaImageEmbedHandle llava_image_embed_make_with_filename(SafeLlavaModelHan
/// <param name="embed">Embedding handle</param>
/// <returns>True on success</returns>
[DllImport(llavaLibraryName, EntryPoint = "llava_eval_image_embed", CallingConvention = CallingConvention.Cdecl)]
public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed,
int n_batch, ref int n_past);
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llava_eval_image_embed(SafeLLamaContextHandle ctx_llama, SafeLlavaImageEmbedHandle embed, int n_batch, ref int n_past);

}
2 changes: 1 addition & 1 deletion LLama/Native/NativeApi.Sampling.cs
Original file line number Diff line number Diff line change
Expand Up @@ -176,7 +176,7 @@ public static void llama_sample_apply_guidance(SafeLLamaContextHandle ctx, Span<
public static extern LLamaToken llama_sample_token_greedy(SafeLLamaContextHandle ctx, ref LLamaTokenDataArrayNative candidates);

/// <summary>
/// Randomly selects a token from the candidates based on their probabilities.
/// Randomly selects a token from the candidates based on their probabilities using the RNG of ctx.
/// </summary>
/// <param name="ctx"></param>
/// <param name="candidates">Pointer to LLamaTokenDataArray</param>
Expand Down
23 changes: 19 additions & 4 deletions LLama/Native/NativeApi.cs
Original file line number Diff line number Diff line change
Expand Up @@ -34,20 +34,23 @@ public static void llama_empty_call()
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mmap();

/// <summary>
/// Check if memory locking is supported
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_mlock();

/// <summary>
/// Check if GPU offload is supported
/// </summary>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_supports_gpu_offload();

/// <summary>
Expand Down Expand Up @@ -77,6 +80,7 @@ public static void llama_empty_call()
/// <param name="n_token_count_out"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_load_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens_out, ulong n_token_capacity, out ulong n_token_count_out);

/// <summary>
Expand All @@ -88,6 +92,7 @@ public static void llama_empty_call()
/// <param name="n_token_count"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_state_save_file(SafeLLamaContextHandle ctx, string path_session, LLamaToken[] tokens, ulong n_token_count);

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
Expand Down Expand Up @@ -133,6 +138,14 @@ public static void llama_empty_call()
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern uint llama_n_seq_max(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the pooling type for this context
/// </summary>
/// <param name="ctx"></param>
/// <returns></returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
public static extern LLamaPoolingType llama_pooling_type(SafeLLamaContextHandle ctx);

/// <summary>
/// Get the embeddings for a specific sequence.
/// Equivalent to: llama_get_embeddings(ctx) + ctx->output_ids[i]*n_embd
Expand Down Expand Up @@ -218,19 +231,20 @@ public static void llama_empty_call()
/// <param name="model"></param>
/// <param name="llamaToken"></param>
/// <param name="buffer">buffer to write string into</param>
/// <param name="special">If true, special tokens are rendered in the output</param>
/// <returns>The length written, or if the buffer is too small a negative that indicates the length required</returns>
public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer)
public static int llama_token_to_piece(SafeLlamaModelHandle model, LLamaToken llamaToken, Span<byte> buffer, bool special)
{
unsafe
{
fixed (byte* bufferPtr = buffer)
{
return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length);
return llama_token_to_piece_native(model, llamaToken, bufferPtr, buffer.Length, special);
}
}

[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl, EntryPoint = "llama_token_to_piece")]
static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length);
static extern unsafe int llama_token_to_piece_native(SafeLlamaModelHandle model, LLamaToken llamaToken, byte* buffer, int length, bool special);
}

/// <summary>
Expand Down Expand Up @@ -260,7 +274,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
}

/// <summary>
/// Clear the KV cache
/// Clear the KV cache. Both cell info is erased and KV data is zeroed
/// </summary>
/// <param name="ctx"></param>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
Expand All @@ -275,6 +289,7 @@ public static void llama_log_set(NativeLogConfig.LLamaLogCallback logCallback)
/// <param name="p1"></param>
/// <returns>Returns false if a partial sequence cannot be removed. Removing a whole sequence never fails</returns>
[DllImport(libraryName, CallingConvention = CallingConvention.Cdecl)]
[return: MarshalAs(UnmanagedType.U1)]
public static extern bool llama_kv_cache_seq_rm(SafeLLamaContextHandle ctx, LLamaSeqId seq, LLamaPos p0, LLamaPos p1);

/// <summary>
Expand Down
Loading
Loading