diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 0b38a04d8..89d94ade3 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -4,93 +4,61 @@ namespace LLama.Web.Common
 {
-    public class InferenceOptions : IInferenceParams
+    public class InferenceOptions
+        : IInferenceParams
     {
-        /// <summary>
-        /// number of tokens to keep from initial prompt
-        /// </summary>
+        /// <inheritdoc />
         public int TokensKeep { get; set; } = 0;
-        /// <summary>
-        /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
-        /// until it complete.
-        /// </summary>
+
+        /// <inheritdoc />
         public int MaxTokens { get; set; } = -1;
-        /// <summary>
-        /// logit bias for specific tokens
-        /// </summary>
+
+        /// <inheritdoc />
         public Dictionary<int, float>? LogitBias { get; set; } = null;
-        /// <summary>
-        /// Sequences where the model will stop generating further tokens.
-        /// </summary>
+        /// <inheritdoc />
         public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
-        /// <summary>
-        /// path to file for saving/loading model eval state
-        /// </summary>
-        public string PathSession { get; set; } = string.Empty;
-        /// <summary>
-        /// string to suffix user inputs with
-        /// </summary>
-        public string InputSuffix { get; set; } = string.Empty;
-        /// <summary>
-        /// string to prefix user inputs with
-        /// </summary>
-        public string InputPrefix { get; set; } = string.Empty;
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+
+        /// <inheritdoc />
        public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
 
         /// <summary>
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index a21c7306b..d87faf0eb 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -25,7 +25,6 @@ public interface IInferenceParams
         /// </summary>
         public Dictionary<llama_token, float>? LogitBias { get; set; }
 
-
         /// <summary>
         /// Sequences where the model will stop generating further tokens.
         /// </summary>
@@ -41,10 +40,15 @@ public interface IInferenceParams
         /// </summary>
         public float TopP { get; set; }
 
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
-        public float TfsZ { get; set; }
+        /// <summary>
+        /// 0.0 = disabled
+        /// </summary>
+        public float MinP { get; set; }
+
+        /// <summary>
+        /// 1.0 = disabled
+        /// </summary>
+        public float TfsZ { get; set; }
 
         /// <summary>
         /// 1.0 = disabled
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d0217e2f8..d7bd19d96 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -6,10 +6,12 @@ namespace LLama.Common
 {
     using llama_token = Int32;
 
+
    /// <summary>
    /// The paramters used for inference.
    /// </summary>
-    public record InferenceParams : IInferenceParams
+    public record InferenceParams
+        : IInferenceParams
    {
        /// <summary>
        /// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ public record InferenceParams : IInferenceParams
        /// </summary>
        public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
 
-        /// <summary>
-        /// 0 or lower to use vocab size
-        /// </summary>
+        /// <inheritdoc />
         public int TopK { get; set; } = 40;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TopP { get; set; } = 0.95f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
+        public float MinP { get; set; } = 0.05f;
+
+        /// <inheritdoc />
         public float TfsZ { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float TypicalP { get; set; } = 1.0f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float Temperature { get; set; } = 0.8f;
-        /// <summary>
-        /// 1.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float RepeatPenalty { get; set; } = 1.1f;
-        /// <summary>
-        /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
-        /// </summary>
+
+        /// <inheritdoc />
         public int RepeatLastTokensCount { get; set; } = 64;
-        /// <summary>
-        /// frequency penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float FrequencyPenalty { get; set; } = .0f;
-        /// <summary>
-        /// presence penalty coefficient
-        /// 0.0 = disabled
-        /// </summary>
+
+        /// <inheritdoc />
         public float PresencePenalty { get; set; } = .0f;
-        /// <summary>
-        /// Mirostat uses tokens instead of words.
-        /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
-        /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
-        /// </summary>
+
+        /// <inheritdoc />
         public MirostatType Mirostat { get; set; } = MirostatType.Disable;
-        /// <summary>
-        /// target entropy
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatTau { get; set; } = 5.0f;
-        /// <summary>
-        /// learning rate
-        /// </summary>
+
+        /// <inheritdoc />
         public float MirostatEta { get; set; } = 0.1f;
-        /// <summary>
-        /// consider newlines as a repeatable token (penalize_nl)
-        /// </summary>
+
+        /// <inheritdoc />
         public bool PenalizeNL { get; set; } = true;
-        /// <summary>
-        /// A grammar to constrain the possible tokens
-        /// </summary>
+        /// <inheritdoc />
         public SafeLLamaGrammarHandle? Grammar { get; set; }
    }
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 0ccae76b4..b64befd8b 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -226,10 +226,11 @@ public void LoadState(State state)
         /// <param name="tfsZ"></param>
         /// <param name="typicalP"></param>
         /// <param name="grammar"></param>
+        /// <param name="minP"></param>
         /// <returns></returns>
-        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
-                                  float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
-                                  SafeLLamaGrammarHandle? grammar = null)
+        public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+                                  float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+                                  SafeLLamaGrammarHandle? grammar, float minP)
         {
             llama_token id;
 
@@ -264,6 +265,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
                 candidates.TailFree(NativeHandle, tfsZ);
                 candidates.LocallyTypical(NativeHandle, typicalP);
                 candidates.TopP(NativeHandle, topP);
+                candidates.MinP(NativeHandle, minP);
                 candidates.Temperature(NativeHandle, temperature);
                 id = candidates.SampleToken(NativeHandle);
             }
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 80c6f5420..33cbd23e9 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -216,8 +216,8 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 var mu = MirostatMu;
                 var id = Context.Sample(
                     tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                    inferenceParams.Grammar
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
                 );
 
                 MirostatMu = mu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 5078648b1..98b45814c 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -194,9 +194,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
 
                 var mu = MirostatMu;
                 var id = Context.Sample(
-                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
-                    inferenceParams.Grammar
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
                 );
 
                 MirostatMu = mu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 4bdeaa3f2..9c41af7c0 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -90,8 +90,11 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                 // Sample a single token
-                var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
-                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
+                var id = Context.Sample(
+                    tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+                    inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+                    inferenceParams.MinP
+                );
 
                 // Decode this token into text
                 decoder.Add(id);
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 8f20a73ac..4bc154f4c 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -91,6 +91,21 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
             }
         }
 
+        /// <summary>
+        /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+        /// </summary>
+        /// <param name="context"></param>
+        /// <param name="p">All tokens with probability greater than this will be kept</param>
+        /// <param name="minKeep"></param>
+        public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+        {
+            using (LLamaTokenDataArrayNative.Create(this, out var st))
+            {
+                NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
+                sorted = st.sorted;
+            }
+        }
+
         /// <summary>
         /// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
         /// </summary>
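
Usage sketch (not part of the patch above): with these changes applied, callers opt in to min-p by setting the new MinP property on InferenceParams; each executor forwards it to LLamaContext.Sample, which applies llama_sample_min_p between the top-p and temperature steps. The example below is illustrative only — the model path, prompt, and parameter values are placeholders, and it assumes the LLamaSharp API surface of this revision (ModelParams, LLamaWeights.LoadFromFile, StatelessExecutor).

using System;
using LLama;
using LLama.Common;

// Illustrative only: replace the path with a real GGUF model file.
var modelParams = new ModelParams("models/example-model.gguf")
{
    ContextSize = 2048
};

using var weights = LLamaWeights.LoadFromFile(modelParams);
var executor = new StatelessExecutor(weights, modelParams);

var inferenceParams = new InferenceParams
{
    MaxTokens = 128,
    Temperature = 0.8f,
    TopP = 1.0f,      // min-p is often combined with top-p left at/near 1.0
    MinP = 0.05f,     // drop tokens whose probability is below 5% of the most likely token's
    AntiPrompts = new[] { "User:" }
};

await foreach (var token in executor.InferAsync("Question: what does min-p sampling do?\nAnswer:", inferenceParams))
{
    Console.Write(token);
}

Leaving MinP at its default (0.05 in InferenceParams/InferenceOptions, per the diff) only removes tokens that are already very unlikely relative to the top candidate, so existing callers should see little change; setting it to 0 disables the min-p step entirely.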