Swapped StatelessExecutor to use llama_decode!
 - Added `logits_i` argument to `Context.ApplyPenalty`
 - Added a new exception type (`LLamaDecodeError`) for non-zero `llama_decode` return codes
martindevans committed Jan 20, 2024
1 parent 892e841 commit a2e29d3
Showing 8 changed files with 90 additions and 38 deletions.
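
For context, a minimal sketch of the chunked prompt evaluation this commit introduces in the StatelessExecutor (not part of the commit itself). The type and member names (LLamaBatch, Tokenize, Params.BatchSize, DecodeAsync, DecodeResult, LLamaDecodeError) come from the diffs below; the helper method, its signature, and the surrounding usings (LLama, LLama.Native, LLama.Exceptions) are illustrative.

// Illustrative helper (not in this commit): evaluate a prompt in chunks no larger
// than the configured batch size, as the new StatelessExecutor inference loop does.
async Task EvaluatePromptAsync(LLamaContext context, string prompt, CancellationToken ct)
{
    var tokens = context.Tokenize(prompt).ToList();
    var batch = new LLamaBatch(1);
    var batchSize = (int)context.Params.BatchSize;

    var n_past = 0;
    for (var i = 0; i < tokens.Count; i += batchSize)
    {
        var n_eval = Math.Min(batchSize, tokens.Count - i);

        batch.Clear();
        for (var j = 0; j < n_eval; j++)
        {
            // Only request logits for the final prompt token.
            var isLast = (i + j) == tokens.Count - 1;
            batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, isLast);
        }

        var returnCode = await context.DecodeAsync(batch, ct);
        if (returnCode != DecodeResult.Ok)
            throw new LLamaDecodeError(returnCode);
    }
}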
7 changes: 1 addition & 6 deletions LLama.Examples/Examples/BatchedDecoding.cs
@@ -105,12 +105,7 @@ public static async Task Run()
if (i_batch[i] < 0)
continue;

var n_vocab = model.VocabCount;
LLamaTokenDataArray candidates;
unsafe
{
candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
}
var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

candidates.TopK(context.NativeHandle, top_k);
candidates.TopP(context.NativeHandle, top_p);
3 changes: 2 additions & 1 deletion LLama.Unittest/StatelessExecutorTest.cs
@@ -19,6 +19,7 @@ public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
{
ContextSize = 60,
Seed = 1754,
BatchSize = 2,
};
_weights = LLamaWeights.LoadFromFile(_params);
}
@@ -60,7 +61,7 @@ public async Task OutOfContext()
{
var executor = new StatelessExecutor(_weights, _params);

const string question = " Question. cats or dogs?\nAnswer: ";
const string question = " Question. cats or dogs?\nAnswer:";

// The context size is set to 60. Generate more than that, forcing it to generate a coherent response
// with a modified context
20 changes: 20 additions & 0 deletions LLama/Exceptions/RuntimeError.cs
@@ -1,4 +1,5 @@
using System;
using LLama.Native;

namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public LoadWeightsFailedException(string modelPath)
{
ModelPath = modelPath;
}
}

/// <summary>
/// `llama_decode` returned a non-zero status code
/// </summary>
public class LLamaDecodeError
: RuntimeError
{
/// <summary>
/// The return status code
/// </summary>
public DecodeResult ReturnCode { get; }

/// <inheritdoc />
public LLamaDecodeError(DecodeResult returnCode)
: base($"llama_decode failed: '{returnCode}'")
{
ReturnCode = returnCode;
}
}
32 changes: 14 additions & 18 deletions LLama/LLamaContext.cs
@@ -293,6 +293,7 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
/// <summary>
/// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
/// </summary>
/// <param name="logits_i"></param>
/// <param name="lastTokens"></param>
/// <param name="logitBias"></param>
/// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
/// <param name="alphaPresence"></param>
/// <param name="penalizeNL"></param>
/// <returns></returns>
public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
bool penalizeNL = true)
{
var logits = NativeHandle.GetLogits();
var logits = NativeHandle.GetLogitsIth(logits_i);

// Apply params.logit_bias map
if (logitBias is not null)
@@ -348,28 +349,23 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict
/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public int Decode(LLamaBatch batch)
public DecodeResult Decode(LLamaBatch batch)
{
return NativeHandle.Decode(batch);
if (batch.TokenCount == 0)
return 0;
if (batch.TokenCount > Params.BatchSize)
throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));

return (DecodeResult)NativeHandle.Decode(batch);
}

/// <summary>
/// </summary>
/// <param name="batch"></param>
/// <param name="cancellationToken"></param>
/// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
/// - 0: success<br />
/// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
/// - &lt; 0: error<br />
/// </returns>
public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
{
return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
return Task.Run(() => Decode(batch), cancellationToken);
}

/// <summary>
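
A short caller-side sketch of the updated ApplyPenalty signature with the new logits_i argument (illustrative, not part of the commit; `batch`, `lastTokens`, and the penalty values are assumptions, while the parameter names come from the diff above):

// Penalize the logits produced for the last token in the batch, then sample as before.
var lastIndex = batch.TokenCount - 1;
var candidates = context.ApplyPenalty(lastIndex, lastTokens,
    logitBias: null, repeatLastTokensCount: 64, repeatPenalty: 1.1f,
    alphaFrequency: 0.0f, alphaPresence: 0.0f, penalizeNL: true);
// `candidates` is then passed to Context.Sample(...), as in the executor diffs below.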
2 changes: 1 addition & 1 deletion LLama/LLamaInstructExecutor.cs
@@ -216,7 +216,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
2 changes: 1 addition & 1 deletion LLama/LLamaInteractExecutor.cs
@@ -195,7 +195,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
}
else
{
var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

var mu = MirostatMu;
40 changes: 29 additions & 11 deletions LLama/LLamaStatelessExecutor.cs
@@ -5,7 +5,7 @@
using System.Linq;
using System.Runtime.CompilerServices;
using System.Threading;
using System.Threading.Tasks;
using LLama.Exceptions;
using LLama.Native;
using LLama.Sampling;
using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
private readonly LLamaWeights _weights;
private readonly IContextParams _params;
private readonly ILogger? _logger;
private readonly LLamaBatch _batch;

/// <summary>
/// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
_weights = weights;
_params = @params;
_logger = logger;
_batch = new LLamaBatch(1);

Context = _weights.CreateContext(_params, logger);
Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
var lastTokens = new List<LLamaToken>(repeat_last_n);
for (var i = 0; i < repeat_last_n; i++)
lastTokens.Add((LLamaToken)0);
lastTokens.Add(0);

// Tokenize the prompt
var tokens = Context.Tokenize(prompt).ToList();
lastTokens.AddRange(tokens);
var n_past = 1 + tokens.Count;

// Evaluate the prompt
await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
.ConfigureAwait(false);
// Evaluate the prompt, in chunks smaller than the max batch size
var n_past = 0;
var batchSize = (int)Context.Params.BatchSize;
for (var i = 0; i < tokens.Count; i += batchSize)
{
var n_eval = tokens.Count - i;
if (n_eval > batchSize)
n_eval = batchSize;

_batch.Clear();
for (var j = 0; j < n_eval; j++)
_batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);

var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}

// Begin loop, evaluating one token at a time
var mu = (float?)null;
@@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
LLamaToken id;
if (inferenceParams.SamplingPipeline is not null)
{
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
}
else
{
// Penalize the generated tokens by various penalties
var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

// Sample a single token
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
n_past -= n_discard;
}

// ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
.ConfigureAwait(false);
// Evaluate with this new token
_batch.Clear();
_batch.Add(id, n_past++, LLamaSeqId.Zero, true);
var returnCode = await context.DecodeAsync(_batch, cancellationToken);
if (returnCode != 0)
throw new LLamaDecodeError(returnCode);
}
}
}
22 changes: 22 additions & 0 deletions LLama/Native/DecodeResult.cs
@@ -0,0 +1,22 @@
namespace LLama.Native;

/// <summary>
/// Return codes from llama_decode
/// </summary>
public enum DecodeResult
{
/// <summary>
/// An unspecified error
/// </summary>
Error = -1,

/// <summary>
/// Ok.
/// </summary>
Ok = 0,

/// <summary>
/// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
/// </summary>
NoKvSlot = 1,
}
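
A minimal sketch of how a caller might interpret these return codes (illustrative, not part of the commit; `context` and `batch` are assumed to exist):

var result = context.Decode(batch);
switch (result)
{
    case DecodeResult.Ok:
        break;                               // success
    case DecodeResult.NoKvSlot:
        // Not fatal: retry with a smaller batch or a larger context.
        break;
    default:
        // DecodeResult.Error, or any other unexpected code, is treated as fatal.
        throw new LLamaDecodeError(result);
}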
