Skip to content

Commit

Permalink
Merge pull request #973 from martindevans/fix_null_sampler_pipeline
Browse files Browse the repository at this point in the history
Non-Null Default `SamplingPipeline`
  • Loading branch information
martindevans authored Nov 7, 2024
2 parents 3f176be + 20f5485 commit 079410c
Show file tree
Hide file tree
Showing 9 changed files with 62 additions and 99 deletions.
1 change: 0 additions & 1 deletion LLama.Benchmark/LLamaExecutorBenchmark/Prefill.cs
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,6 @@ private void InitializeParamsAndModel()
Prompt = File.ReadAllText(Constants.TextCompletionPromptsFilePath).Substring(0, PromptAndContextLength.Item1);
InferenceParams = new InferenceParams()
{
Temperature = 0.6f,
MaxTokens = 1 // Only prefill, no generation here.
};

Expand Down
3 changes: 1 addition & 2 deletions LLama.Examples/Examples/ChatChineseGB2312.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,8 @@ public static async Task Run()
session
.WithHistoryTransform(new LLamaTransforms.DefaultHistoryTransform("用户", "坤坤"));

InferenceParams inferenceParams = new InferenceParams()
var inferenceParams = new InferenceParams
{
Temperature = 0.9f,
AntiPrompts = new List<string> { "用户:" }
};

Expand Down
12 changes: 11 additions & 1 deletion LLama.Examples/Examples/InteractiveModeExecute.cs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
using LLama.Common;
using LLama.Sampling;

namespace LLama.Examples.Examples
{
Expand All @@ -25,7 +26,16 @@ public static async Task Run()

Console.Write(prompt);

var inferenceParams = new InferenceParams() { Temperature = 0.6f, AntiPrompts = new List<string> { "User:" }, MaxTokens = 128 };
var inferenceParams = new InferenceParams
{
AntiPrompts = new List<string> { "User:" },
MaxTokens = 128,

SamplingPipeline = new DefaultSamplingPipeline
{
Temperature = 0.6f
}
};

while (true)
{
Expand Down
31 changes: 19 additions & 12 deletions LLama.KernelMemory/LlamaSharpTextGenerator.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
using LLama;
using LLama.Common;
using LLama.Sampling;
using Microsoft.KernelMemory.AI;

namespace LLamaSharp.KernelMemory
Expand Down Expand Up @@ -86,25 +87,31 @@ private static InferenceParams OptionsToParams(TextGenerationOptions options, In
return defaultParams with
{
AntiPrompts = defaultParams.AntiPrompts.Concat(options.StopSequences).ToList().AsReadOnly(),
Temperature = (float)options.Temperature,
MaxTokens = options.MaxTokens ?? defaultParams.MaxTokens,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling

SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)options.Temperature,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
}
};
}
else

return new InferenceParams
{
return new InferenceParams
AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
MaxTokens = options.MaxTokens ?? 1024,

SamplingPipeline = new DefaultSamplingPipeline()
{
AntiPrompts = options.StopSequences.ToList().AsReadOnly(),
Temperature = (float)options.Temperature,
MaxTokens = options.MaxTokens ?? 1024,
FrequencyPenalty = (float)options.FrequencyPenalty,
PresencePenalty = (float)options.PresencePenalty,
AlphaFrequency = (float)options.FrequencyPenalty,
AlphaPresence = (float)options.PresencePenalty,
TopP = (float)options.NucleusSampling,
};
}
}
};
}

/// <inheritdoc/>
Expand Down
15 changes: 10 additions & 5 deletions LLama.SemanticKernel/ExtensionMethods.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using LLama.Sampling;
using Microsoft.SemanticKernel.ChatCompletion;
using AuthorRole = LLama.Common.AuthorRole;

Expand Down Expand Up @@ -45,12 +46,16 @@ internal static LLama.Common.InferenceParams ToLLamaSharpInferenceParams(this LL
};
return new LLama.Common.InferenceParams
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
PresencePenalty = (float)requestSettings.PresencePenalty,
FrequencyPenalty = (float)requestSettings.FrequencyPenalty,
AntiPrompts = antiPrompts,
MaxTokens = requestSettings.MaxTokens ?? -1
MaxTokens = requestSettings.MaxTokens ?? -1,

SamplingPipeline = new DefaultSamplingPipeline()
{
Temperature = (float)requestSettings.Temperature,
TopP = (float)requestSettings.TopP,
AlphaPresence = (float)requestSettings.PresencePenalty,
AlphaFrequency = (float)requestSettings.FrequencyPenalty,
}
};
}
}
1 change: 0 additions & 1 deletion LLama.WebAPI/Controllers/ChatController.cs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
using LLama.WebAPI.Models;
using LLama.WebAPI.Services;
using Microsoft.AspNetCore.Mvc;
using System;

namespace LLama.WebAPI.Controllers
{
Expand Down
29 changes: 18 additions & 11 deletions LLama.WebAPI/Services/StatefulChatService.cs
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@

using LLama.WebAPI.Models;
using Microsoft;
using System.Runtime.CompilerServices;
using LLama.Sampling;

namespace LLama.WebAPI.Services;

public class StatefulChatService : IDisposable
public sealed class StatefulChatService
: IDisposable
{
private readonly ChatSession _session;
private readonly LLamaContext _context;
Expand Down Expand Up @@ -47,10 +46,14 @@ public async Task<string> Send(SendMessageInput input)
_logger.LogInformation("Input: {text}", input.Text);
var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams()
new Common.InferenceParams
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
AntiPrompts = [ "User:" ],

SamplingPipeline = new DefaultSamplingPipeline
{
RepeatPenalty = 1.0f
}
});

var result = "";
Expand All @@ -74,11 +77,15 @@ public async IAsyncEnumerable<string> SendStream(SendMessageInput input)
_logger.LogInformation(input.Text);

var outputs = _session.ChatAsync(
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text!)
, new Common.InferenceParams()
new Common.ChatHistory.Message(Common.AuthorRole.User, input.Text),
new Common.InferenceParams
{
RepeatPenalty = 1.0f,
AntiPrompts = new string[] { "User:" },
AntiPrompts = [ "User:" ],

SamplingPipeline = new DefaultSamplingPipeline
{
RepeatPenalty = 1.0f
}
});

await foreach (var output in outputs)
Expand Down
3 changes: 1 addition & 2 deletions LLama.WebAPI/Services/StatelessChatService.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using LLama.Common;
using Microsoft.AspNetCore.Http;
using LLama.Common;
using System.Text;
using static LLama.LLamaTransforms;

Expand Down
66 changes: 2 additions & 64 deletions LLama/Common/InferenceParams.cs
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public record InferenceParams
: IInferenceParams
{
/// <summary>
/// number of tokens to keep from initial prompt
/// number of tokens to keep from initial prompt when applying context shifting
/// </summary>
public int TokensKeep { get; set; } = 0;

Expand All @@ -23,75 +23,13 @@ public record InferenceParams
/// </summary>
public int MaxTokens { get; set; } = -1;

/// <summary>
/// logit bias for specific tokens
/// </summary>
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public Dictionary<LLamaToken, float>? LogitBias { get; set; } = null;

/// <summary>
/// Sequences where the model will stop generating further tokens.
/// </summary>
public IReadOnlyList<string> AntiPrompts { get; set; } = [];

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public int TopK { get; set; } = 40;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TopP { get; set; } = 0.95f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float MinP { get; set; } = 0.05f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TfsZ { get; set; } = 1.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float TypicalP { get; set; } = 1.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float Temperature { get; set; } = 0.8f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float RepeatPenalty { get; set; } = 1.1f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public int RepeatLastTokensCount { get; set; } = 64;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float FrequencyPenalty { get; set; } = .0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public float PresencePenalty { get; set; } = .0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public MirostatType Mirostat { get; set; } = MirostatType.Disable;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public float MirostatTau { get; set; } = 5.0f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. MirostatSamplingPipeline or Mirostat2SamplingPipeline")]
public float MirostatEta { get; set; } = 0.1f;

/// <inheritdoc />
[Obsolete("Use the SamplingPipeline property instead with a configured pipeline e.g. DefaultSamplingPipeline")]
public bool PenalizeNL { get; set; } = true;

/// <inheritdoc />
public ISamplingPipeline? SamplingPipeline { get; set; }
public ISamplingPipeline SamplingPipeline { get; set; } = new DefaultSamplingPipeline();
}

/// <summary>
Expand Down

0 comments on commit 079410c

Please sign in to comment.