MinP Sampler #277

Merged: 1 commit, Nov 13, 2023
106 changes: 37 additions & 69 deletions LLama.Web/Common/InferenceOptions.cs
@@ -4,93 +4,61 @@

namespace LLama.Web.Common
{
public class InferenceOptions : IInferenceParams
public class InferenceOptions
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
/// </summary>
/// <inheritdoc />
public int TokensKeep { get; set; } = 0;
/// <summary>
/// how many new tokens to predict (n_predict), set to -1 to inifinitely generate response
/// until it complete.
/// </summary>

/// <inheritdoc />
public int MaxTokens { get; set; } = -1;
/// <summary>
/// logit bias for specific tokens
/// </summary>

/// <inheritdoc />
public Dictionary<int, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
/// <inheritdoc />
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();
/// <summary>
/// path to file for saving/loading model eval state
/// </summary>
public string PathSession { get; set; } = string.Empty;
/// <summary>
/// string to suffix user inputs with
/// </summary>
public string InputSuffix { get; set; } = string.Empty;
/// <summary>
/// string to prefix user inputs with
/// </summary>
public string InputPrefix { get; set; } = string.Empty;
/// <summary>
/// 0 or lower to use vocab size
/// </summary>

/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>

/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>

/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>

/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>

/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>

/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;

/// <summary>
14 changes: 9 additions & 5 deletions LLama/Abstractions/IInferenceParams.cs
@@ -25,7 +25,6 @@ public interface IInferenceParams
/// </summary>
public Dictionary<int, float>? LogitBias { get; set; }


/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
@@ -41,10 +40,15 @@ public interface IInferenceParams
/// </summary>
public float TopP { get; set; }

/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }
/// <summary>
/// 0.0 = disabled
/// </summary>
public float MinP { get; set; }

/// <summary>
/// 1.0 = disabled
/// </summary>
public float TfsZ { get; set; }

/// <summary>
/// 1.0 = disabled
79 changes: 32 additions & 47 deletions LLama/Common/InferenceParams.cs
@@ -6,10 +6,12 @@
namespace LLama.Common
{
using llama_token = Int32;

/// <summary>
/// The paramters used for inference.
/// </summary>
public record InferenceParams : IInferenceParams
public record InferenceParams
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
@@ -30,66 +32,49 @@ public record InferenceParams : IInferenceParams
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = Array.Empty<string>();

/// <summary>
/// 0 or lower to use vocab size
/// </summary>
/// <inheritdoc />
public int TopK { get; set; } = 40;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TopP { get; set; } = 0.95f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
public float TfsZ { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float TypicalP { get; set; } = 1.0f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float Temperature { get; set; } = 0.8f;
/// <summary>
/// 1.0 = disabled
/// </summary>

/// <inheritdoc />
public float RepeatPenalty { get; set; } = 1.1f;
/// <summary>
/// last n tokens to penalize (0 = disable penalty, -1 = context size) (repeat_last_n)
/// </summary>

/// <inheritdoc />
public int RepeatLastTokensCount { get; set; } = 64;
/// <summary>
/// frequency penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float FrequencyPenalty { get; set; } = .0f;
/// <summary>
/// presence penalty coefficient
/// 0.0 = disabled
/// </summary>

/// <inheritdoc />
public float PresencePenalty { get; set; } = .0f;
/// <summary>
/// Mirostat uses tokens instead of words.
/// algorithm described in the paper https://arxiv.org/abs/2007.14966.
/// 0 = disabled, 1 = mirostat, 2 = mirostat 2.0
/// </summary>

/// <inheritdoc />
public MirostatType Mirostat { get; set; } = MirostatType.Disable;
/// <summary>
/// target entropy
/// </summary>

/// <inheritdoc />
public float MirostatTau { get; set; } = 5.0f;
/// <summary>
/// learning rate
/// </summary>

/// <inheritdoc />
public float MirostatEta { get; set; } = 0.1f;
/// <summary>
/// consider newlines as a repeatable token (penalize_nl)
/// </summary>

/// <inheritdoc />
public bool PenalizeNL { get; set; } = true;

/// <summary>
/// A grammar to constrain the possible tokens
/// </summary>
/// <inheritdoc />
public SafeLLamaGrammarHandle? Grammar { get; set; }
}

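With the new property in place, enabling min-p from user code is just another field on `InferenceParams`. The sketch below is illustrative only: the model path and prompt are placeholders, and the loading/executor calls around it are assumed from the repository's existing examples rather than taken from this diff.

```csharp
using System;
using LLama;
using LLama.Common;

// Illustrative sketch: run the stateless executor with min-p enabled.
// "path/to/model.gguf" and the prompt are placeholders.
var modelParams = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(modelParams);
var executor = new StatelessExecutor(weights, modelParams);

var inferenceParams = new InferenceParams
{
    Temperature = 0.8f,
    TopP = 0.95f,
    MinP = 0.05f,   // 0.0 disables the new min-p step entirely
    MaxTokens = 64
};

await foreach (var text in executor.InferAsync("Question: what does min-p sampling do?\nAnswer:", inferenceParams))
    Console.Write(text);
```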
8 changes: 5 additions & 3 deletions LLama/LLamaContext.cs
@@ -226,10 +226,11 @@ public void LoadState(State state)
/// <param name="tfsZ"></param>
/// <param name="typicalP"></param>
/// <param name="grammar"></param>
/// <param name="minP"></param>
/// <returns></returns>
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature = 0.8f, MirostatType mirostat = MirostatType.Disable,
float mirostatTau = 5.0f, float mirostatEta = 0.1f, int topK = 40, float topP = 0.95f, float tfsZ = 1.0f, float typicalP = 1.0f,
SafeLLamaGrammarHandle? grammar = null)
public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu, float temperature, MirostatType mirostat,
float mirostatTau, float mirostatEta, int topK, float topP, float tfsZ, float typicalP,
SafeLLamaGrammarHandle? grammar, float minP)
{
llama_token id;

@@ -264,6 +265,7 @@ public llama_token Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu
candidates.TailFree(NativeHandle, tfsZ);
candidates.LocallyTypical(NativeHandle, typicalP);
candidates.TopP(NativeHandle, topP);
candidates.MinP(NativeHandle, minP);
candidates.Temperature(NativeHandle, temperature);
id = candidates.SampleToken(NativeHandle);
}
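To make the ordering explicit: when mirostat is disabled, `Sample` now applies min-p after top-p and before temperature scaling. Written out as direct calls on the candidate array, the path looks roughly like the sketch below. This is not runnable on its own: `ctx` and `candidates` are assumed to already exist with penalties applied, the values shown are the `InferenceParams` defaults, and the `TopK` call sits in the unchanged lines just above this hunk.

```csharp
// Sketch of the non-mirostat sampling path in Sample, with default parameter values.
// 'ctx' is a SafeLLamaContextHandle; 'candidates' is a LLamaTokenDataArray whose logits
// have already had repetition/frequency/presence penalties applied.
candidates.TopK(ctx, 40);
candidates.TailFree(ctx, 1.0f);
candidates.LocallyTypical(ctx, 1.0f);
candidates.TopP(ctx, 0.95f);
candidates.MinP(ctx, 0.05f);           // new step added by this PR
candidates.Temperature(ctx, 0.8f);
var id = candidates.SampleToken(ctx);
```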
4 changes: 2 additions & 2 deletions LLama/LLamaInstructExecutor.cs
@@ -106,7 +106,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

[Build warning on line 109 (Test: linux-release, windows-release)] Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -147,11 +147,11 @@
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

[Build warning on line 150 (Test: linux-release, windows-release)] This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

[Build warning on line 154 (Test: linux-release, windows-release)] 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'.
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -207,7 +207,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

[Build warning on line 210 (Test: linux-release)] Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
@@ -216,8 +216,8 @@
var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

@@ -258,12 +258,12 @@
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public llama_token[] InputPrefixTokens { get; set; }

[Build warning on line 261 (Test: windows-release)] Non-nullable property 'InputPrefixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public llama_token[] InputSuffixTokens { get; set; }

[Build warning on line 266 (Test: windows-release)] Non-nullable property 'InputSuffixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
}
6 changes: 3 additions & 3 deletions LLama/LLamaInteractExecutor.cs
@@ -89,7 +89,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

[Build warning on line 92 (Test: linux-release, windows-release)] Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -130,11 +130,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

[Build warning on line 133 (Test: linux-release, windows-release)] This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.Items.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

[Build warning on line 137 (Test: linux-release, windows-release)] 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'.
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
@@ -156,7 +156,7 @@
}

/// <inheritdoc />
protected override async Task InferInternal(IInferenceParams inferenceParams, InferStateArgs args)

[Build warning on line 159 (Test: linux-release, windows-release)] This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embeds.Count > 0)
{
@@ -186,7 +186,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

[Build warning on line 189 (Test: linux-release)] Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
@@ -194,9 +194,9 @@

var mu = MirostatMu;
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP,
inferenceParams.Grammar
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);
MirostatMu = mu;

7 changes: 5 additions & 2 deletions LLama/LLamaStatelessExecutor.cs
@@ -90,8 +90,11 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
var id = Context.Sample(tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar);
var id = Context.Sample(
tokenDataArray, ref mu, inferenceParams.Temperature, inferenceParams.Mirostat, inferenceParams.MirostatTau,
inferenceParams.MirostatEta, inferenceParams.TopK, inferenceParams.TopP, inferenceParams.TfsZ, inferenceParams.TypicalP, inferenceParams.Grammar,
inferenceParams.MinP
);

// Decode this token into text
decoder.Add(id);
15 changes: 15 additions & 0 deletions LLama/Native/LLamaTokenDataArray.cs
@@ -91,6 +91,21 @@ public void TopP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
}
}

/// <summary>
/// Minimum P sampling as described in https://github.com/ggerganov/llama.cpp/pull/3841
/// </summary>
/// <param name="context"></param>
/// <param name="p">All tokens with probability greater than this will be kept</param>
/// <param name="minKeep"></param>
public void MinP(SafeLLamaContextHandle context, float p, ulong minKeep = 1)
{
using (LLamaTokenDataArrayNative.Create(this, out var st))
{
NativeApi.llama_sample_min_p(context, ref st, p, minKeep);
sorted = st.sorted;
}
}
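
For intuition about what the native `llama_sample_min_p` call keeps (per ggerganov/llama.cpp#3841 the cutoff is relative to the probability of the most likely token, not an absolute threshold), here is a small managed illustration. It is not the code path used above, just a sketch of the rule with made-up probabilities:

```csharp
using System;
using System.Linq;

// Illustrative only: min-p keeps tokens whose probability is at least
// p times the probability of the single most likely token.
float[] probs = { 0.60f, 0.25f, 0.10f, 0.04f, 0.01f };   // made-up, already normalised
float p = 0.10f;

float cutoff = p * probs.Max();                           // 0.10 * 0.60 = 0.06
var kept = probs.Where(x => x >= cutoff).ToArray();       // keeps 0.60, 0.25, 0.10

Console.WriteLine($"cutoff = {cutoff}, kept = [{string.Join(", ", kept)}]");
```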

/// <summary>
/// Tail Free Sampling described in https://www.trentonbricken.com/Tail-Free-Sampling/.
/// </summary>