diff --git a/LLama.Web/Common/InferenceOptions.cs b/LLama.Web/Common/InferenceOptions.cs
index 0b38a04d8..89d94ade3 100644
--- a/LLama.Web/Common/InferenceOptions.cs
+++ b/LLama.Web/Common/InferenceOptions.cs
@@ -4,93 +4,61 @@
namespace LLama.Web.Common
{
- public class InferenceOptions : IInferenceParams
+ public class InferenceOptions
+ : IInferenceParams
{
- /// <summary>
- /// number of tokens to keep from initial prompt
- /// </summary>
+ /// <inheritdoc />
public int TokensKeep { get; set; } = 0;
- /// <summary>
- /// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
- /// until it complete.
- /// </summary>
+
+ /// <inheritdoc />
public int MaxTokens { get; set; } = -1;
- /// <summary>
- /// logit bias for specific tokens
- /// </summary>
+
+ /// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;
- /// <summary>
- /// Sequences where the model will stop generating further tokens.
- /// </summary>
+ /// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
- /// <summary>
- /// path to file for saving/loading model eval state
- /// </summary>
- public string PathSession { get; set; } = string.Empty;
- /// <summary>
- /// string to suffix user inputs with
- /// </summary>
- public string InputSuffix { get; set; } = string.Empty;
- /// <summary>
- /// string to prefix user inputs with
- /// </summary>
- public string InputPrefix { get; set; } = string.Empty;
- /// <summary>
- /// 0 or lower to use vocab size
- /// </summary>
+
+ /// <inheritdoc />
public int TopK { get; set; } = 40;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float TopP { get; set; } = 0.95f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
+ public float MinP { get; set; } = 0.05f;
+
+ /// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
- /// <summary>
- /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
- /// </summary>
+
+ /// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
- /// <summary>
- /// frequency penalty coefficient
- /// 0.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
- /// <summary>
- /// presence penalty coefficient
- /// 0.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
- /// <summary>
- /// Mirostat uses tokens instead of words.
- /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
- /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- /// </summary>
+
+ /// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
- /// <summary>
- /// target entropy
- /// </summary>
+
+ /// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
- /// <summary>
- /// learning rate
- /// </summary>
+
+ /// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
- /// <summary>
- /// consider newlines as a repeatable token (penalize_nl)
- /// </summary>
+
+ /// <inheritdoc />
public bool PenalizeNL { get; set; } = true;
/// <summary>
diff --git a/LLama/Abstractions/IInferenceParams.cs b/LLama/Abstractions/IInferenceParams.cs
index a21c7306b..d87faf0eb 100644
--- a/LLama/Abstractions/IInferenceParams.cs
+++ b/LLama/Abstractions/IInferenceParams.cs
@@ -25,7 +25,6 @@ public interface IInferenceParams
/// </summary>
public Dictionary<llama_token, float>? LogitBias { get; set; }
-
/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
@@ -41,10 +40,15 @@ public interface IInferenceParams
/// </summary>
public float TopP { get; set; }
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
- public float TfsZ { get; set; }
+ /// <summary>
+ /// 0.0 = disabled
+ /// </summary>
+ public float MinP { get; set; }
+
+ /// <summary>
+ /// 1.0 = disabled
+ /// </summary>
+ public float TfsZ { get; set; }
/// <summary>
/// 1.0 = disabled
diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs
index d0217e2f8..d7bd19d96 100644
--- a/LLama/Common/InferenceParams.cs
+++ b/LLama/Common/InferenceParams.cs
@@ -6,10 +6,12 @@
namespace LLama.Common
{
using llama_token = Int32;
+
/// <summary>
/// The paramters used for inference.
/// </summary>
- public record InferenceParams : IInferenceParams
+ public record InferenceParams
+ : IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ public record InferenceParams : IInferenceParams
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
- /// <summary>
- /// 0 or lower to use vocab size
- /// </summary>
+ /// <inheritdoc />
public int TopK { get; set; } = 40;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float TopP { get; set; } = 0.95f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
+ public float MinP { get; set; } = 0.05f;
+
+ /// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
- /// <summary>
- /// 1.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
- /// <summary>
- /// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
- /// </summary>
+
+ /// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
- /// <summary>
- /// frequency penalty coefficient
- /// 0.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
- /// <summary>
- /// presence penalty coefficient
- /// 0.0 = disabled
- /// </summary>
+
+ /// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
- /// <summary>
- /// Mirostat uses tokens instead of words.
- /// algorithm described in the paper https://arxiv.org/abs/2007.14966.
- /// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
- /// </summary>
+
+ /// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
- /// <summary>
- /// target entropy
- /// </summary>
+
+ /// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
- /// <summary>
- /// learning rate
- /// </summary>
+
+ /// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
- /// <summary>
- /// consider newlines as a repeatable token (penalize_nl)
- /// </summary>
+
+ /// <inheritdoc />
public bool PenalizeNL { get; set; } = true;
- /// <summary>
- /// A grammar to constrain the possible tokens
- /// </summary>
+ /// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }
}
diff --git a/LLama/LLamaContext.cs b/LLama/LLamaContext.cs
index 0ccae76b4..b64befd8b 100644
--- a/LLama/LLamaContext.cs
+++ b/LLama/LLamaContext.cs
@@ -226,10 +226,11 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
+ /// <param name="minP"></param>
/// <returns></returns>
- public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
- float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
- SafeLLamaGrammarHandle? grammar = null)
+ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
+ float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
+ SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;
@@ -264,6 +265,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
+ candidates.MinP(NativeHandle, minP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
diff --git a/LLama/LLamaInstructExecutor.cs b/LLama/LLamaInstructExecutor.cs
index 80c6f5420..33cbd23e9 100644
--- a/LLama/LLamaInstructExecutor.cs
+++ b/LLama/LLamaInstructExecutor.cs
@@ -216,8 +216,8 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
- inferenceParams.Grammar
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
);
MirostatMu = mu;
diff --git a/LLama/LLamaInteractExecutor.cs b/LLama/LLamaInteractExecutor.cs
index 5078648b1..98b45814c 100644
--- a/LLama/LLamaInteractExecutor.cs
+++ b/LLama/LLamaInteractExecutor.cs
@@ -194,9 +194,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
var mu = MirostatMu;
var id = Context.Sample(
- tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
- inferenceParams.Grammar
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
);
MirostatMu = mu;
diff --git a/LLama/LLamaStatelessExecutor.cs b/LLama/LLamaStatelessExecutor.cs
index 4bdeaa3f2..9c41af7c0 100644
--- a/LLama/LLamaStatelessExecutor.cs
+++ b/LLama/LLamaStatelessExecutor.cs
@@ -90,8 +90,11 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
// Sample a single token
- var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
- inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
+ var id = Context.Sample(
+ tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
+ inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
+ inferenceParams.MinP
+ );
// Decode this token into text
decoder.Add(id);
diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs
index 8f20a73ac..4bc154f4c 100644
--- a/LLama/Native/LLamaTokenDataArray.cs
+++ b/LLama/Native/LLamaTokenDataArray.cs
@@ -91,6 +91,21 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
}
}
+ /// <summary>
+ /// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
+ /// </summary>
+ /// <param name="context"></param>
+ /// <param name="p">All tokens with probability greater than this will be kept</param>
+ /// <param name="minKeep"></param>
+ public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
+ {
+ using (LLamaTokenDataArrayNative.Create(this, out var st))
+ {
+ NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
+ sorted = st.sorted;
+ }
+ }
+
/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>
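
For reference: the MinP value threaded through the executors above ends up in NativeApi.llama_sample_min_p, which applies the min-p rule from llama.cpp PR #3841: after sorting the candidates, only tokens whose probability is at least MinP times the probability of the single most likely token survive (never fewer than minKeep). The snippet below is a minimal managed sketch of that rule under those assumptions, not the native implementation; the Candidate record and ApplyMinP helper are hypothetical names used purely for illustration.

// Minimal sketch (assumed names, toy data) of the min-p filter that
// NativeApi.llama_sample_min_p performs natively when candidates.MinP(...) is called.
using System;
using System.Collections.Generic;
using System.Linq;

internal record Candidate(int TokenId, float Probability);

internal static class MinPSketch
{
    // Keep every candidate whose probability is at least `p` times the
    // probability of the most likely candidate, but never fewer than `minKeep`.
    private static List<Candidate> ApplyMinP(IReadOnlyList<Candidate> candidates, float p, int minKeep = 1)
    {
        var sorted = candidates.OrderByDescending(c => c.Probability).ToList();
        var threshold = sorted[0].Probability * p;
        var kept = sorted.Where(c => c.Probability >= threshold).ToList();
        return kept.Count >= minKeep ? kept : sorted.Take(minKeep).ToList();
    }

    private static void Main()
    {
        var candidates = new[]
        {
            new Candidate(1, 0.50f),
            new Candidate(2, 0.30f),
            new Candidate(3, 0.15f),
            new Candidate(4, 0.05f),
        };

        // With p = 0.5 the cutoff is 0.50 * 0.5 = 0.25, so only tokens 1 and 2 survive;
        // with the library default p = 0.05 the cutoff is 0.025 and all four remain.
        foreach (var c in ApplyMinP(candidates, p: 0.5f))
            Console.WriteLine($"token {c.TokenId}: p = {c.Probability}");
    }
}

From calling code the new knob is just another property on InferenceParams, e.g. new InferenceParams { Temperature = 0.8f, MinP = 0.05f }; the 0.05f default only drops tokens whose probability falls below 5% of the most likely token's.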