Non-Null Default SamplingPipeline #973

Merged
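
This PR makes InferenceParams.SamplingPipeline non-null, defaulting to a new DefaultSamplingPipeline, and marks the individual sampling properties on InferenceParams (Temperature, TopP, MinP, TfsZ, TypicalP, RepeatPenalty, FrequencyPenalty, PresencePenalty, Mirostat, PenalizeNL, LogitBias) as [Obsolete]. Callers move their sampling settings onto the pipeline; FrequencyPenalty and PresencePenalty map to AlphaFrequency and AlphaPresence on DefaultSamplingPipeline. Below is a minimal before/after migration sketch inferred from the diffs in this PR; the variable names and the specific values are illustrative only.

using LLama.Common;
using LLama.Sampling;

// Old style: sampling settings set directly on InferenceParams (now [Obsolete]).
var before = new InferenceParams
{
    MaxTokens = 128,
    AntiPrompts = new List<string> { "User:" },
    Temperature = 0.6f,
    TopP = 0.95f
};

// New style: SamplingPipeline is non-null by default (DefaultSamplingPipeline),
// and sampling settings are configured on the pipeline instead.
var after = new InferenceParams
{
    MaxTokens = 128,
    AntiPrompts = new List<string> { "User:" },
    SamplingPipeline = new DefaultSamplingPipeline
    {
        Temperature = 0.6f,
        TopP = 0.95f
    }
};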
1 change: 0 additions & 1 deletion LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs
@@ -83,7 +83,6 @@ private void InitializeParamsAndModel()
Prompt = File.ReadAllText(Constants.TextCompletionPromptsFilePath).Substring(0, PromptAndContextLength.Item1);
InferenceParams = new InferenceParams()
{
Temperature = 0.6f,
MaxTokens = 1 // Only prefill, no generation here.
};

3 changes: 1 addition & 2 deletions LLama.Examples/Examples/ChatChineseGB2312.cs
@@ -55,9 +55,8 @@ public static async Task Run()
session
.WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

InferenceParams inferenceParams = new InferenceParams()
var inferenceParams = new InferenceParams
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "用户:" }
};

12 changes: 11 additions & 1 deletion LLama.Examples/Examples/InteractiveModeExecute.cs
@@ -1,4 +1,5 @@
using LLama.Common;
using LLama.Sampling;

namespace LLama.Examples.Examples
{
@@ -25,7 +26,16 @@ public static async Task Run()

Console.Write(prompt);

var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" }, MaxTokens = 128 };
var inferenceParams = new InferenceParams
{
AntiPrompts = new List<string> { "User:" },
MaxTokens = 128,

SamplingPipeline = new DefaultSamplingPipeline
{
Temperature = 0.6f
}
};

while (true)
{
31 changes: 19 additions & 12 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
@@ -1,5 +1,6 @@
using LLama;
using LLama.Common;
using LLama.Sampling;
using Microsoft.KernelMemory.AI;

namespace LLamaSharp.KernelMemory
@@ -86,25 +87,31 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
return defaultParams with
{
AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
Temperature = (float)options.Temperature,
MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling

SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
}
else

return new InferenceParams
{
return new InferenceParams
AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
MaxTokens = options.MaxTokens ?? 1024,

SamplingPipeline = new DefaultSamplingPipeline()
{
AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
Temperature = (float)options.Temperature,
MaxTokens = options.MaxTokens ?? 1024,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
};
}
}
};
}

/// <inheritdoc/>
15 changes: 10 additions & 5 deletions LLama.SemanticKernel/ExtensionMethods.cs
@@ -1,3 +1,4 @@
using LLama.Sampling;
using Microsoft.SemanticKernel.ChatCompletion;
using AuthorRole = LLama.Common.AuthorRole;

@@ -45,12 +46,16 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
};
return new LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
PresencePenalty = (float)requestSettings.PresencePenalty,
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
AntiPrompts = antiPrompts,
MaxTokens = requestSettings.MaxTokens ?? -1
MaxTokens = requestSettings.MaxTokens ?? -1,

SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
AlphaPresence = (float)requestSettings.PresencePenalty,
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
}
};
}
}
1 change: 0 additions & 1 deletion LLama.WebAPI/Controllers/ChatController.cs
@@ -2,7 +2,6 @@
using LLama.WebAPI.Models;
using LLama.WebAPI.Services;
using Microsoft.AspNetCore.Mvc;
using System;

namespace LLama.WebAPI.Controllers
{
29 changes: 18 additions & 11 deletions LLama.WebAPI/Services/StatefulChatService.cs
@@ -1,11 +1,10 @@

using LLama.WebAPI.Models;
using Microsoft;
using System.Runtime.CompilerServices;
using LLama.Sampling;

namespace LLama.WebAPI.Services;

public class StatefulChatService : IDisposable
public sealed class StatefulChatService
: IDisposable
{
private readonly ChatSession _session;
private readonly LLamaContext _context;
@@ -47,10 +46,14 @@ public async Task<string> Send(SendMessageInput input)
_logger.LogInformation("Input: {text}", input.Text);
var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
new Common.InferenceParams
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
AntiPrompts = [ "User:" ],

SamplingPipeline = new DefaultSamplingPipeline
{
RepeatPenalty = 1.0f
}
});

var result = "";
@@ -74,11 +77,15 @@ public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
_logger.LogInformation(input.Text);

var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!)
, new Common.InferenceParams()
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
AntiPrompts = [ "User:" ],

SamplingPipeline = new DefaultSamplingPipeline
{
RepeatPenalty = 1.0f
}
});

await foreach (var output in outputs)
3 changes: 1 addition & 2 deletions LLama.WebAPI/Services/StatelessChatService.cs
@@ -1,5 +1,4 @@
using LLama.Common;
using Microsoft.AspNetCore.Http;
using LLama.Common;
using System.Text;
using static LLama.LLamaTransforms;

66 changes: 2 additions & 64 deletions LLama/Common/InferenceParams.cs
@@ -13,7 +13,7 @@ public record InferenceParams
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
/// number of tokens to keep from initial prompt when applying context shifting
/// </summary>
public int TokensKeep { get; set; } = 0;

@@ -23,75 +23,13 @@ public record InferenceParams
/// </summary>
public int MaxTokens { get; set; } = -1;

/// <summary>
/// logit bias for specific tokens
/// </summary>
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = [];

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public int TopK { get; set; } = 40;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TopP { get; set; } = 0.95f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TfsZ { get; set; } = 1.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TypicalP { get; set; } = 1.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float Temperature { get; set; } = 0.8f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float RepeatPenalty { get; set; } = 1.1f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public int RepeatLastTokensCount { get; set; } = 64;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float FrequencyPenalty { get; set; } = .0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float PresencePenalty { get; set; } = .0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public MirostatType Mirostat { get; set; } = MirostatType.Disable;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public float MirostatTau { get; set; } = 5.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public float MirostatEta { get; set; } = 0.1f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public bool PenalizeNL { get; set; } = true;

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
public ISamplingPipeline SamplingPipeline { get; set; } = new DefaultSamplingPipeline();
}

/// <summary>