Swapped StatelessExecutor to use llama_decode! #445

Merged
7 changes: 1 addition & 6 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -105,12 +105,7 @@ public static async Task Run()
if (i_batch[i] < 0)
continue;

var n_vocab = model.VocabCount;
LLamaTokenDataArray candidates;
unsafe
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
3 changes: 2 additions & 1 deletion LLama.Unittest/StatelessExecutorTest.cs
@@ -19,6 +19,7 @@ public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
ContextSize = 60,
Seed = 1754,
BatchSize = 2,
};
_weights = LLamaWeights.LoadFromFile(_params);
}
@@ -60,7 +61,7 @@ public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);

const string question = " Question. cats or dogs?\nAnswer: ";
const string question = " Question. cats or dogs?\nAnswer:";

// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context
20 changes: 20 additions & 0 deletions LLama/Exceptions/RuntimeError.cs
@@ -1,4 +1,5 @@
using System;
using LLama.Native;

namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public LoadWeightsFailedException(string modelPath)
{
ModelPath = modelPath;
}
}

/// <summary>
/// `llama_decode` returned a non-zero status code
/// </summary>
public class LLamaDecodeError
: RuntimeError
{
/// <summary>
/// The return status code
/// </summary>
public DecodeResult ReturnCode { get; }

/// <inheritdoc />
public LLamaDecodeError(DecodeResult returnCode)
: base($"llama_decode failed: '{returnCode}'")
{
ReturnCode = returnCode;
}
}
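
A minimal sketch of catching the new exception from an executor call and inspecting the return code; the executor, prompt and parameters here are illustrative, not part of this diff:

    try
    {
        await foreach (var text in executor.InferAsync(prompt, inferenceParams))
            Console.Write(text);
    }
    catch (LLamaDecodeError ex) when (ex.ReturnCode == DecodeResult.NoKvSlot)
    {
        // Recoverable case: the batch could not be placed in the KV cache.
        // Reduce the batch size or create the context with a larger context size, then retry.
    }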
32 changes: 14 additions & 18 deletions LLama/LLamaContext.cs
@@ -293,6 +293,7 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
/// <summary>
/// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="logits_i"></param>
/// <param name="lastTokens"></param>
/// <param name="logitBias"></param>
/// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
var logits = NativeHandle.GetLogits();
var logits = NativeHandle.GetLogitsIth(logits_i);

// Apply params.logit_bias map
if (logitBias is not null)
@@ -348,28 +349,23 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict
/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatch batch)
public DecodeResult Decode(LLamaBatch batch)
{
return NativeHandle.Decode(batch);
if (batch.TokenCount == 0)
return 0;
if (batch.TokenCount > Params.BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <param name="cancellationToken"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
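
A hedged sketch of the reworked Decode path, assuming `tokens` holds an already-tokenized prompt no longer than the configured BatchSize (larger batches now throw ArgumentException, and an empty batch returns DecodeResult.Ok immediately):

    // Illustrative only: build a batch on sequence zero, requesting logits for the final position
    var batch = new LLamaBatch(1);
    for (var i = 0; i < tokens.Count; i++)
        batch.Add(tokens[i], i, LLamaSeqId.Zero, i == tokens.Count - 1);

    var result = await context.DecodeAsync(batch, cancellationToken);
    if (result != DecodeResult.Ok)
        throw new LLamaDecodeError(result);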
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -105,7 +105,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InstructExecutorState>(fs);
await LoadState(state);

Check warning on line 108 (GitHub Actions / Test: linux-release, osx-release): Possible null reference argument for parameter 'data' in 'Task InstructExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -146,11 +146,11 @@
}

/// <inheritdoc />
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 149 (GitHub Actions / Test: linux-release, osx-release): This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 153 (GitHub Actions / Test: linux-release, osx-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
{
args.WaitForInput = true;
return (true, Array.Empty<string>());
@@ -187,7 +187,7 @@
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

Check warning on line 190 (GitHub Actions / Test: osx-release): 'LLamaContext.Eval(List<LLamaToken>, int)' is obsolete: 'use Decode() instead'

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -206,7 +206,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

Check warning on line 209 (GitHub Actions / Test: osx-release): Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

LLamaToken id;
@@ -216,7 +216,7 @@
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
@@ -265,12 +265,12 @@
/// Instruction prefix tokens.
/// </summary>
[JsonPropertyName("inp_pfx")]
public LLamaToken[] InputPrefixTokens { get; set; }

Check warning on line 268 (GitHub Actions / Test: osx-release): Non-nullable property 'InputPrefixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
/// <summary>
/// Instruction suffix tokens.
/// </summary>
[JsonPropertyName("inp_sfx")]
public LLamaToken[] InputSuffixTokens { get; set; }

Check warning on line 273 (GitHub Actions / Test: osx-release): Non-nullable property 'InputSuffixTokens' must contain a non-null value when exiting constructor. Consider declaring the property as nullable.
}
}
}
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -88,7 +88,7 @@
using (var fs = new FileStream(filename, FileMode.Open, FileAccess.Read))
{
var state = await JsonSerializer.DeserializeAsync<InteractiveExecutorState>(fs);
await LoadState(state);

Check warning on line 91 (GitHub Actions / Test: linux-release, windows-release, osx-release): Possible null reference argument for parameter 'data' in 'Task InteractiveExecutor.LoadState(ExecutorBaseState data)'.
}
}

@@ -129,11 +129,11 @@
/// <param name="inferenceParams"></param>
/// <param name="args"></param>
/// <returns></returns>
protected override async Task<(bool, IReadOnlyList<string>)> PostProcess(IInferenceParams inferenceParams, InferStateArgs args)

Check warning on line 132 (GitHub Actions / Test: linux-release, windows-release): This async method lacks 'await' operators and will run synchronously. Consider using the 'await' operator to await non-blocking API calls, or 'await Task.Run(...)' to do CPU-bound work on a background thread.
{
if (_embed_inps.Count <= _consumedTokensCount)
{
if (_last_n_tokens.TokensEndsWithAnyString(args.Antiprompts, Context.NativeHandle.ModelHandle, Context.Encoding))

Check warning on line 136 (GitHub Actions / Test: linux-release, windows-release, osx-release): 'IReadOnlyListExtensions.TokensEndsWithAnyString<TTokens>(TTokens, IList<string>?, SafeLlamaModelHandle, Encoding)' is obsolete: 'Use an Antiprompt processor instead'
args.WaitForInput = true;

if (_pastTokensCount > 0 && args.WaitForInput)
@@ -166,7 +166,7 @@
}

TryReuseMathingPrefix();
_pastTokensCount = Context.Eval(_embeds, _pastTokensCount);

Check warning on line 169 (GitHub Actions / Test: windows-release): 'LLamaContext.Eval(List<LLamaToken>, int)' is obsolete: 'use Decode() instead'

if (_embeds.Count > 0 && !string.IsNullOrEmpty(_pathSession))
{
@@ -185,7 +185,7 @@
if (!string.IsNullOrEmpty(_pathSession) && args.NeedToSaveSession)
{
args.NeedToSaveSession = false;
SaveSessionFile(_pathSession);

Check warning on line 188 (GitHub Actions / Test: windows-release): Possible null reference argument for parameter 'filename' in 'void StatefulExecutorBase.SaveSessionFile(string filename)'.
}

LLamaToken id;
@@ -195,7 +195,7 @@
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
40 changes: 29 additions & 11 deletions LLama/LLamaStatelessExecutor.cs
@@ -5,7 +5,7 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
private readonly LLamaWeights _weights;
private readonly IContextParams _params;
private readonly ILogger? _logger;
private readonly LLamaBatch _batch;

/// <summary>
/// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
_weights = weights;
_params = @params;
_logger = logger;
_batch = new LLamaBatch(1);

Context = _weights.CreateContext(_params, logger);
Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<LLamaToken>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add((LLamaToken)0);
lastTokens.Add(0);

// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

_batch.Clear();
for (var j = 0; j < n_eval; j++)
_batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);

var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}

// Begin loop, evaluating one token at a time
var mu = (float?)null;
Expand All @@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
n_past -= n_discard;
}

// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
// Evaluate with this new token
_batch.Clear();
_batch.Add(id, n_past++, LLamaSeqId.Zero, true);
var returnCode = await context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
}
}
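
With the prompt now decoded in BatchSize-sized chunks, a small BatchSize no longer breaks long prompts, which is presumably what the BatchSize = 2 added to the unit test exercises. An end-to-end usage sketch; the model path and parameter values are placeholders:

    // Illustrative usage of the updated StatelessExecutor
    var parameters = new ModelParams("models/model.gguf")   // placeholder path
    {
        ContextSize = 1024,
        BatchSize = 64,   // prompts longer than 64 tokens are split across multiple llama_decode calls
    };
    using var weights = LLamaWeights.LoadFromFile(parameters);
    var executor = new StatelessExecutor(weights, parameters);

    await foreach (var piece in executor.InferAsync(" Question. cats or dogs?\nAnswer:", new InferenceParams()))
        Console.Write(piece);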
22 changes: 22 additions & 0 deletions LLama/Native/DecodeResult.cs
@@ -0,0 +1,22 @@
namespace LLama.Native;

/// <summary>
/// Return codes from llama_decode
/// </summary>
public enum DecodeResult
{
/// <summary>
/// An unspecified error
/// </summary>
Error = -1,

/// <summary>
/// Ok.
/// </summary>
Ok = 0,

/// <summary>
/// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
/// </summary>
NoKvSlot = 1,
}
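
Restating the XML comments above as a handling sketch (the retry advice mirrors the comment on NoKvSlot; `context` and `batch` are illustrative):

    var result = context.Decode(batch);
    switch (result)
    {
        case DecodeResult.Ok:
            break;   // logits for the requested positions are ready to sample
        case DecodeResult.NoKvSlot:
            // Not fatal: shrink the batch or use a larger context, then retry
            throw new LLamaDecodeError(result);
        default:
            throw new LLamaDecodeError(result);   // DecodeResult.Error: unspecified failure
    }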