
MinP Sampler #277


Merged: 1 commit, Nov 13, 2023
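
For context: min-p sampling, as described in the llama.cpp pull request linked from the new MinP method below (ggerganov/llama.cpp#3841), keeps every token whose probability is at least p times the probability of the most likely token, so the cutoff tightens when the model is confident and relaxes when the distribution is flat. Below is a minimal sketch of the idea, separate from the native implementation this PR binds to (all names in it are illustrative):

using System;
using System.Linq;

static class MinPSketch
{
    // Illustrative min-p filter over an already-softmaxed distribution.
    // Keeps indices with probs[i] >= p * max(probs), retaining at least minKeep tokens.
    public static int[] Filter(float[] probs, float p, int minKeep = 1)
    {
        float cutoff = p * probs.Max();

        int[] kept = Enumerable.Range(0, probs.Length)
            .Where(i => probs[i] >= cutoff)
            .ToArray();

        if (kept.Length >= minKeep)
            return kept;

        // The cutoff was too aggressive: fall back to the minKeep most probable tokens.
        return Enumerable.Range(0, probs.Length)
            .OrderByDescending(i => probs[i])
            .Take(minKeep)
            .ToArray();
    }
}

With the 0.05 default this PR uses, a top token at probability 0.6 sets the cutoff at 0.03; a flatter distribution whose top token sits at 0.1 lowers it to 0.005, so many more candidates survive.
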
106 changes: 37 additions & 69 deletions LLama.Web/Common/InferenceOptions.cs
@@ -4,93 +4,61 @@

namespace LLama.Web.Common
{
public class InferenceOptions : IInferenceParams
public class InferenceOptions
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
/// </summary>
/// <inheritdoc />
public int TokensKeep { get; set; } = 0;
/// <summary>
/// how many new tokens to predict (n_predict); set to -1 to generate
/// indefinitely until the response completes.
/// </summary>

/// <inheritdoc />
public int MaxTokens { get; set; } = -1;
/// <summary>
/// logit bias for specific tokens
/// </summary>

/// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// path to file for saving/loading model eval state
/// </summary>
public string PathSession { get; set; } = string.Empty;
/// <summary>
/// string to suffix user inputs with
/// </summary>
public string InputSuffix { get; set; } = string.Empty;
/// <summary>
/// string to prefix user inputs with
/// </summary>
public string InputPrefix { get; set; } = string.Empty;
/// <summary>
/// 0 or lower to use vocab size
/// </summary>

/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>

/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>

/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>

/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>

/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>

/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;

14 changes: 9 additions & 5 deletions LLama/Abstractions/IInferenceParams.cs
@@ -25,7 +25,6 @@ public interface IInferenceParams
/// </summary>
public Dictionary<int, float>? LogitBias { get; set; }


/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
@@ -41,10 +40,15 @@ public interface IInferenceParams
/// </summary>
public float TopP { get; set; }

/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }
/// <summary>
/// 0.0 = disabled
/// </summary>
public float MinP { get; set; }

/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }

79 changes: 32 additions & 47 deletions LLama/Common/InferenceParams.cs
@@ -6,10 +6,12 @@
namespace LLama.Common
{
using llama_token = Int32;

/// <summary>
/// The parameters used for inference.
/// </summary>
public record InferenceParams : IInferenceParams
public record InferenceParams
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ public record InferenceParams : IInferenceParams
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();

/// <summary>
/// 0 or lower to use vocab size
/// </summary>
/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>

/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>

/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>

/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>

/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>

/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;

/// <summary>
/// A grammar to constrain the possible tokens
/// </summary>
/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }
}

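For reference, a hedged sketch of how a caller might opt into the new parameter (the property names come from the record above; model and executor setup are omitted, and the values shown are just the defaults):

var inferenceParams = new LLama.Common.InferenceParams
{
    Temperature = 0.8f,
    TopK = 40,
    TopP = 0.95f,
    MinP = 0.05f,                      // new in this PR; 0.0 disables min-p
    AntiPrompts = new[] { "User:" },   // stop generation at this sequence
};

Setting MinP to 0.0 switches the filter off, per the "0.0 = disabled" note in IInferenceParams.
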
8 changes: 5 additions & 3 deletions LLama/LLamaContext.cs
@@ -226,10 +226,11 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
SafeLLamaGrammarHandle? grammar = null)
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;

@@ -264,6 +265,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.MinP(NativeHandle, minP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
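
Note where the new call sits in the chain above: MinP is applied after TopP and before the Temperature step, so min-p prunes the candidate set that survives top-p before the final temperature-scaled draw.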
4 changes: 2 additions & 2 deletions LLama/LLamaInstructExecutor.cs
@@ -216,8 +216,8 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

6 changes: 3 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -194,9 +194,9 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

7 changes: 5 additions & 2 deletions LLama/LLamaStatelessExecutor.cs
@@ -90,8 +90,11 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);

// Decode this token into text
decoder.Add(id);
15 changes: 15 additions & 0 deletions LLama/Native/LLamaTokenDataArray.cs
@@ -91,6 +91,21 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
}
}

/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
/// <param name="context"></param>
/// <param name="p">All tokens with probability greater than this will be kept</param>
/// <param name="minKeep"></param>
public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}
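
A hedged call-site sketch for the new method (assuming a valid SafeLLamaContextHandle named handle and a candidates array already filled from the current logits, as in LLamaContext.Sample above):

// Keep only tokens whose probability is at least 5% of the top token's,
// never dropping below one surviving candidate, then draw a token.
candidates.MinP(handle, 0.05f, minKeep: 1);
var token = candidates.SampleToken(handle);
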

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>