From 20f548537cbcffe8adc9534c5e47e6cee0d41022 Mon Sep 17 00:00:00 2001 From: Martin Evans Date: Fri, 1 Nov 2024 21:47:34 +0000 Subject: [PATCH] - Removed the deprecated sampler parameters - Made the `SamplingPipeline` default to a `DefaultSamplingPipeline` --- .../LLamaExecutorBenchmark/Prefill.cs | 1 - LLama.Examples/Examples/ChatChineseGB2312.cs | 3 +- .../Examples/InteractiveModeExecute.cs | 12 +++- LLama.KernelMemory/LlamaSharpTextGenerator.cs | 31 +++++---- LLama.SemanticKernel/ExtensionMethods.cs | 15 +++-- LLama.WebAPI/Controllers/ChatController.cs | 1 - LLama.WebAPI/Services/StatefulChatService.cs | 29 ++++---- LLama.WebAPI/Services/StatelessChatService.cs | 3 +- LLama/Common/InferenceParams.cs | 66 +------------------ 9 files changed, 62 insertions(+), 99 deletions(-) diff --git a/LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs b/LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs index fca00d3e9..33b399ec9 100644 --- a/LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs +++ b/LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs @@ -83,7 +83,6 @@ private void InitializeParamsAndModel() Prompt = File.ReadAllText(Constants.TextCompletionPromptsFilePath).Substring(0, PromptAndContextLength.Item1); InferenceParams = new InferenceParams() { - Temperature = 0.6f, MaxTokens = 1 // Only prefill, no generation here. }; diff --git a/LLama.Examples/Examples/ChatChineseGB2312.cs b/LLama.Examples/Examples/ChatChineseGB2312.cs index 70c6557d3..49d8dce1c 100644 --- a/LLama.Examples/Examples/ChatChineseGB2312.cs +++ b/LLama.Examples/Examples/ChatChineseGB2312.cs @@ -55,9 +55,8 @@ public static async Task Run() session .WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤")); - InferenceParams inferenceParams = new InferenceParams() + var inferenceParams = new InferenceParams { - Temperature = 0.9f, AntiPrompts = new List { "用户:" } }; diff --git a/LLama.Examples/Examples/InteractiveModeExecute.cs b/LLama.Examples/Examples/InteractiveModeExecute.cs index d04eb987c..73dbbd12c 100644 --- a/LLama.Examples/Examples/InteractiveModeExecute.cs +++ b/LLama.Examples/Examples/InteractiveModeExecute.cs @@ -1,4 +1,5 @@ using LLama.Common; +using LLama.Sampling; namespace LLama.Examples.Examples { @@ -25,7 +26,16 @@ public static async Task Run() Console.Write(prompt); - var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List { "User:" }, MaxTokens = 128 }; + var inferenceParams = new InferenceParams + { + AntiPrompts = new List { "User:" }, + MaxTokens = 128, + + SamplingPipeline = new DefaultSamplingPipeline + { + Temperature = 0.6f + } + }; while (true) { diff --git a/LLama.KernelMemory/LlamaSharpTextGenerator.cs b/LLama.KernelMemory/LlamaSharpTextGenerator.cs index 735d4b4e6..02be0b34b 100644 --- a/LLama.KernelMemory/LlamaSharpTextGenerator.cs +++ b/LLama.KernelMemory/LlamaSharpTextGenerator.cs @@ -1,5 +1,6 @@ using LLama; using LLama.Common; +using LLama.Sampling; using Microsoft.KernelMemory.AI; namespace LLamaSharp.KernelMemory @@ -86,25 +87,31 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In return defaultParams with { AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(), - Temperature = (float)options.Temperature, MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens, - FrequencyPenalty = (float)options.FrequencyPenalty, - PresencePenalty = (float)options.PresencePenalty, - TopP = (float)options.NucleusSampling + + SamplingPipeline = new DefaultSamplingPipeline() + { + Temperature = (float)options.Temperature, + AlphaFrequency = (float)options.FrequencyPenalty, + AlphaPresence = (float)options.PresencePenalty, + TopP = (float)options.NucleusSampling, + } }; } - else + + return new InferenceParams { - return new InferenceParams + AntiPrompts = options.StopSequences.ToList().AsReadOnly(), + MaxTokens = options.MaxTokens ?? 1024, + + SamplingPipeline = new DefaultSamplingPipeline() { - AntiPrompts = options.StopSequences.ToList().AsReadOnly(), Temperature = (float)options.Temperature, - MaxTokens = options.MaxTokens ?? 1024, - FrequencyPenalty = (float)options.FrequencyPenalty, - PresencePenalty = (float)options.PresencePenalty, + AlphaFrequency = (float)options.FrequencyPenalty, + AlphaPresence = (float)options.PresencePenalty, TopP = (float)options.NucleusSampling, - }; - } + } + }; } /// diff --git a/LLama.SemanticKernel/ExtensionMethods.cs b/LLama.SemanticKernel/ExtensionMethods.cs index c63ee42b8..0439533d0 100644 --- a/LLama.SemanticKernel/ExtensionMethods.cs +++ b/LLama.SemanticKernel/ExtensionMethods.cs @@ -1,3 +1,4 @@ +using LLama.Sampling; using Microsoft.SemanticKernel.ChatCompletion; using AuthorRole = LLama.Common.AuthorRole; @@ -45,12 +46,16 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL }; return new LLama.Common.InferenceParams { - Temperature = (float)requestSettings.Temperature, - TopP = (float)requestSettings.TopP, - PresencePenalty = (float)requestSettings.PresencePenalty, - FrequencyPenalty = (float)requestSettings.FrequencyPenalty, AntiPrompts = antiPrompts, - MaxTokens = requestSettings.MaxTokens ?? -1 + MaxTokens = requestSettings.MaxTokens ?? -1, + + SamplingPipeline = new DefaultSamplingPipeline() + { + Temperature = (float)requestSettings.Temperature, + TopP = (float)requestSettings.TopP, + AlphaPresence = (float)requestSettings.PresencePenalty, + AlphaFrequency = (float)requestSettings.FrequencyPenalty, + } }; } } diff --git a/LLama.WebAPI/Controllers/ChatController.cs b/LLama.WebAPI/Controllers/ChatController.cs index 9643ccf80..1a80745a1 100644 --- a/LLama.WebAPI/Controllers/ChatController.cs +++ b/LLama.WebAPI/Controllers/ChatController.cs @@ -2,7 +2,6 @@ using LLama.WebAPI.Models; using LLama.WebAPI.Services; using Microsoft.AspNetCore.Mvc; -using System; namespace LLama.WebAPI.Controllers { diff --git a/LLama.WebAPI/Services/StatefulChatService.cs b/LLama.WebAPI/Services/StatefulChatService.cs index ae2401c90..a1e8513d7 100644 --- a/LLama.WebAPI/Services/StatefulChatService.cs +++ b/LLama.WebAPI/Services/StatefulChatService.cs @@ -1,11 +1,10 @@ - using LLama.WebAPI.Models; -using Microsoft; -using System.Runtime.CompilerServices; +using LLama.Sampling; namespace LLama.WebAPI.Services; -public class StatefulChatService : IDisposable +public sealed class StatefulChatService + : IDisposable { private readonly ChatSession _session; private readonly LLamaContext _context; @@ -47,10 +46,14 @@ public async Task Send(SendMessageInput input) _logger.LogInformation("Input: {text}", input.Text); var outputs = _session.ChatAsync( new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), - new Common.InferenceParams() + new Common.InferenceParams { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, + AntiPrompts = [ "User:" ], + + SamplingPipeline = new DefaultSamplingPipeline + { + RepeatPenalty = 1.0f + } }); var result = ""; @@ -74,11 +77,15 @@ public async IAsyncEnumerable SendStream(SendMessageInput input) _logger.LogInformation(input.Text); var outputs = _session.ChatAsync( - new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!) - , new Common.InferenceParams() + new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text), + new Common.InferenceParams { - RepeatPenalty = 1.0f, - AntiPrompts = new string[] { "User:" }, + AntiPrompts = [ "User:" ], + + SamplingPipeline = new DefaultSamplingPipeline + { + RepeatPenalty = 1.0f + } }); await foreach (var output in outputs) diff --git a/LLama.WebAPI/Services/StatelessChatService.cs b/LLama.WebAPI/Services/StatelessChatService.cs index 3520c29b0..c965b3018 100644 --- a/LLama.WebAPI/Services/StatelessChatService.cs +++ b/LLama.WebAPI/Services/StatelessChatService.cs @@ -1,5 +1,4 @@ -using LLama.Common; -using Microsoft.AspNetCore.Http; +using LLama.Common; using System.Text; using static LLama.LLamaTransforms; diff --git a/LLama/Common/InferenceParams.cs b/LLama/Common/InferenceParams.cs index b56528c4f..8f2a5a14f 100644 --- a/LLama/Common/InferenceParams.cs +++ b/LLama/Common/InferenceParams.cs @@ -13,7 +13,7 @@ public record InferenceParams : IInferenceParams { /// - /// number of tokens to keep from initial prompt + /// number of tokens to keep from initial prompt when applying context shifting /// public int TokensKeep { get; set; } = 0; @@ -23,75 +23,13 @@ public record InferenceParams /// public int MaxTokens { get; set; } = -1; - /// - /// logit bias for specific tokens - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public Dictionary? LogitBias { get; set; } = null; - /// /// Sequences where the model will stop generating further tokens. /// public IReadOnlyList AntiPrompts { get; set; } = []; /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public int TopK { get; set; } = 40; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float TopP { get; set; } = 0.95f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float MinP { get; set; } = 0.05f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float TfsZ { get; set; } = 1.0f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float TypicalP { get; set; } = 1.0f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float Temperature { get; set; } = 0.8f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float RepeatPenalty { get; set; } = 1.1f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public int RepeatLastTokensCount { get; set; } = 64; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float FrequencyPenalty { get; set; } = .0f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public float PresencePenalty { get; set; } = .0f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] - public MirostatType Mirostat { get; set; } = MirostatType.Disable; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] - public float MirostatTau { get; set; } = 5.0f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")] - public float MirostatEta { get; set; } = 0.1f; - - /// - [Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")] - public bool PenalizeNL { get; set; } = true; - - /// - public ISamplingPipeline? SamplingPipeline { get; set; } + public ISamplingPipeline SamplingPipeline { get; set; } = new DefaultSamplingPipeline(); } ///