Commit 96c26c2

Merge pull request #445 from martindevans/stateless_executor_llama_decode
Swapped `StatelessExecutor` to use `llama_decode`!
2 parents 1bc6147 + a2e29d3 commit 96c26c2

File tree

8 files changed: +90 -38 lines changed


LLama.Examples/Examples/BatchedDecoding.cs
Lines changed: 1 addition & 6 deletions

@@ -105,12 +105,7 @@ public static async Task Run()
                 if (i_batch[i] < 0)
                     continue;

-                var n_vocab = model.VocabCount;
-                LLamaTokenDataArray candidates;
-                unsafe
-                {
-                    candidates = LLamaTokenDataArray.Create(new Span<float>(NativeApi.llama_get_logits_ith(context.NativeHandle, i_batch[i]), n_vocab));
-                }
+                var candidates = LLamaTokenDataArray.Create(context.NativeHandle.GetLogitsIth(i_batch[i]));

                 candidates.TopK(context.NativeHandle, top_k);
                 candidates.TopP(context.NativeHandle, top_p);
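
The new GetLogitsIth call removes the need for an unsafe block and an explicit vocab size when reading the logits of one batch position. A minimal sketch of the same pattern wrapped as a reusable helper, assuming (as the diff shows) that GetLogitsIth returns the logits span for the given batch index and that LLamaTokenDataArray.Create accepts it directly; the helper name and the topK/topP parameters are illustrative, not part of this commit:

    using LLama.Native;

    public static class SamplingHelpers
    {
        // Hypothetical helper: build a candidates array from the logits of one
        // batch position, with no unsafe code and no manual vocab-size handling.
        public static LLamaTokenDataArray CandidatesAt(SafeLLamaContextHandle ctx, int batchIndex, int topK, float topP)
        {
            var candidates = LLamaTokenDataArray.Create(ctx.GetLogitsIth(batchIndex));
            candidates.TopK(ctx, topK);
            candidates.TopP(ctx, topP);
            return candidates;
        }
    }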

LLama.Unittest/StatelessExecutorTest.cs
Lines changed: 2 additions & 1 deletion

@@ -19,6 +19,7 @@ public StatelessExecutorTest(ITestOutputHelper testOutputHelper)
         {
             ContextSize = 60,
             Seed = 1754,
+            BatchSize = 2,
         };
         _weights = LLamaWeights.LoadFromFile(_params);
     }

@@ -60,7 +61,7 @@ public async Task OutOfContext()
     {
         var executor = new StatelessExecutor(_weights, _params);

-        const string question = " Question. cats or dogs?\nAnswer: ";
+        const string question = " Question. cats or dogs?\nAnswer:";

         // The context size is set to 60. Generate more than that, forcing it to generate a coherent response
         // with a modified context

LLama/Exceptions/RuntimeError.cs
Lines changed: 20 additions & 0 deletions

@@ -1,4 +1,5 @@
 using System;
+using LLama.Native;

 namespace LLama.Exceptions;

@@ -36,4 +37,23 @@ public LoadWeightsFailedException(string modelPath)
     {
         ModelPath = modelPath;
     }
+}
+
+/// <summary>
+/// `llama_decode` return a non-zero status code
+/// </summary>
+public class LLamaDecodeError
+    : RuntimeError
+{
+    /// <summary>
+    /// The return status code
+    /// </summary>
+    public DecodeResult ReturnCode { get; }
+
+    /// <inheritdoc />
+    public LLamaDecodeError(DecodeResult returnCode)
+        : base($"llama_decode failed: '{returnCode}'")
+    {
+        ReturnCode = returnCode;
+    }
 }
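
With this type, callers can catch decode failures specifically rather than a generic RuntimeError. A minimal sketch, with the executor and prompt assumed to be set up elsewhere:

    using System;
    using LLama.Exceptions;

    try
    {
        await foreach (var piece in executor.InferAsync(prompt))
            Console.Write(piece);
    }
    catch (LLamaDecodeError ex)
    {
        // ReturnCode is the DecodeResult reported by llama_decode,
        // e.g. DecodeResult.NoKvSlot when the KV cache cannot hold the batch.
        Console.WriteLine($"Decode failed: {ex.ReturnCode}");
    }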

LLama/LLamaContext.cs
Lines changed: 14 additions & 18 deletions

@@ -293,6 +293,7 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
     /// <summary>
     /// Apply the penalty for the tokens. Please don't use it unless you fully know what it does.
     /// </summary>
+    /// <param name="logits_i"></param>
     /// <param name="lastTokens"></param>
     /// <param name="logitBias"></param>
     /// <param name="repeatLastTokensCount"></param>
@@ -301,11 +302,11 @@ public LLamaToken Sample(LLamaTokenDataArray candidates, ref float? mirostat_mu,
     /// <param name="alphaPresence"></param>
     /// <param name="penalizeNL"></param>
     /// <returns></returns>
-    public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
-        int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
-        bool penalizeNL = true)
+    public LLamaTokenDataArray ApplyPenalty(int logits_i, IEnumerable<LLamaToken> lastTokens, Dictionary<LLamaToken, float>? logitBias = null,
+        int repeatLastTokensCount = 64, float repeatPenalty = 1.1f, float alphaFrequency = .0f, float alphaPresence = .0f,
+        bool penalizeNL = true)
     {
-        var logits = NativeHandle.GetLogits();
+        var logits = NativeHandle.GetLogitsIth(logits_i);

         // Apply params.logit_bias map
         if (logitBias is not null)
@@ -348,28 +349,23 @@ public LLamaTokenDataArray ApplyPenalty(IEnumerable<LLamaToken> lastTokens, Dict
     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
-    /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-    /// - 0: success<br />
-    /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-    /// - &lt; 0: error<br />
-    /// </returns>
-    public int Decode(LLamaBatch batch)
+    public DecodeResult Decode(LLamaBatch batch)
     {
-        return NativeHandle.Decode(batch);
+        if (batch.TokenCount == 0)
+            return 0;
+        if (batch.TokenCount > Params.BatchSize)
+            throw new ArgumentException("Input contains more tokens than configured batch size", nameof(batch));
+
+        return (DecodeResult)NativeHandle.Decode(batch);
     }

     /// <summary>
     /// </summary>
     /// <param name="batch"></param>
     /// <param name="cancellationToken"></param>
-    /// <returns>Positive return values does not mean a fatal error, but rather a warning:<br />
-    /// - 0: success<br />
-    /// - 1: could not find a KV slot for the batch (try reducing the size of the batch or increase the context)<br />
-    /// - &lt; 0: error<br />
-    /// </returns>
-    public Task<int> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
+    public Task<DecodeResult> DecodeAsync(LLamaBatch batch, CancellationToken cancellationToken = default)
    {
-        return Task.Run(() => NativeHandle.Decode(batch), cancellationToken);
+        return Task.Run(() => Decode(batch), cancellationToken);
     }

     /// <summary>
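
Decode now returns a typed DecodeResult instead of a raw int, short-circuits on an empty batch, and rejects batches larger than the configured BatchSize with an ArgumentException. A minimal sketch of the caller side, assuming a LLamaContext named context and a populated LLamaBatch named batch; this is the same pattern the stateless executor below adopts for both the prompt and the per-token decode calls:

    using LLama.Exceptions;
    using LLama.Native;

    // Decode one batch and fail loudly on anything other than Ok.
    var result = await context.DecodeAsync(batch, cancellationToken);
    if (result != DecodeResult.Ok)
        throw new LLamaDecodeError(result);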

LLama/LLamaInstructExecutor.cs
Lines changed: 1 addition & 1 deletion

@@ -216,7 +216,7 @@ protected override Task InferInternal(IInferenceParams inferenceParams, InferSta
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     var mu = MirostatMu;

LLama/LLamaInteractExecutor.cs
Lines changed: 1 addition & 1 deletion

@@ -195,7 +195,7 @@ protected override async Task InferInternal(IInferenceParams inferenceParams, In
                 }
                 else
                 {
-                    var tokenDataArray = Context.ApplyPenalty(_last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
+                    var tokenDataArray = Context.ApplyPenalty(0, _last_n_tokens, inferenceParams.LogitBias, repeat_last_n,
                         inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                     var mu = MirostatMu;
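
The new first argument (logits_i) selects which row of the most recently decoded batch to read logits from before the penalties are applied. The instruct and interactive executors pass 0, while the stateless executor below passes _batch.TokenCount - 1, the position of the last token it submitted. A minimal sketch of the call shape, with lastTokens standing in for the executor's recent-token history and the penalty values being the documented defaults rather than values from this commit:

    // Read logits from batch position 0, then apply repetition/frequency/presence penalties.
    var tokenDataArray = Context.ApplyPenalty(0, lastTokens,
        logitBias: null, repeatLastTokensCount: 64, repeatPenalty: 1.1f,
        alphaFrequency: 0.0f, alphaPresence: 0.0f, penalizeNL: true);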

LLama/LLamaStatelessExecutor.cs
Lines changed: 29 additions & 11 deletions

@@ -5,7 +5,7 @@
 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
     private readonly LLamaWeights _weights;
     private readonly IContextParams _params;
     private readonly ILogger? _logger;
+    private readonly LLamaBatch _batch;

     /// <summary>
     /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
         _weights = weights;
         _params = @params;
         _logger = logger;
+        _batch = new LLamaBatch(1);

         Context = _weights.CreateContext(_params, logger);
         Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount <0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
         var lastTokens = new List<LLamaToken>(repeat_last_n);
         for (var i = 0; i < repeat_last_n; i++)
-            lastTokens.Add((LLamaToken)0);
+            lastTokens.Add(0);

         // Tokenize the prompt
         var tokens = Context.Tokenize(prompt).ToList();
         lastTokens.AddRange(tokens);
-        var n_past = 1 + tokens.Count;

-        // Evaluate the prompt
-        await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-            .ConfigureAwait(false);
+        // Evaluate the prompt, in chunks smaller than the max batch size
+        var n_past = 0;
+        var batchSize = (int)Context.Params.BatchSize;
+        for (var i = 0; i < tokens.Count; i += batchSize)
+        {
+            var n_eval = tokens.Count - i;
+            if (n_eval > batchSize)
+                n_eval = batchSize;
+
+            _batch.Clear();
+            for (var j = 0; j < n_eval; j++)
+                _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+            var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
+        }

         // Begin loop, evaluating one token at a time
         var mu = (float?)null;
@@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             LLamaToken id;
             if (inferenceParams.SamplingPipeline is not null)
             {
-                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
             }
             else
             {
                 // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);

                 // Sample a single token
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 n_past -= n_discard;
             }

-            // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-            n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                .ConfigureAwait(false);
+            // Evaluate with this new token
+            _batch.Clear();
+            _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+            var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
         }
     }
 }
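
The prompt-evaluation loop above is the heart of this change: instead of a single Eval call, the prompt is fed to llama_decode in chunks no larger than Params.BatchSize, and only the final prompt token requests logits. A minimal standalone sketch of the same chunking pattern, assuming a LLamaContext named context and a List<LLamaToken> named tokens already exist:

    using System;
    using LLama.Exceptions;
    using LLama.Native;

    var batch = new LLamaBatch(1);
    var n_past = 0;
    var batchSize = (int)context.Params.BatchSize;

    for (var i = 0; i < tokens.Count; i += batchSize)
    {
        // Never submit more tokens than the context was configured to accept per decode.
        var n_eval = Math.Min(batchSize, tokens.Count - i);

        batch.Clear();
        for (var j = 0; j < n_eval; j++)
        {
            // Request logits only for the very last prompt token; sampling starts there.
            var needsLogits = (i + j) == tokens.Count - 1;
            batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, needsLogits);
        }

        var returnCode = await context.DecodeAsync(batch);
        if (returnCode != DecodeResult.Ok)
            throw new LLamaDecodeError(returnCode);
    }

After the prompt, each generated token goes through the same Clear/Add/DecodeAsync cycle with a single-token batch, as the loop at the bottom of the diff shows.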

LLama/Native/DecodeResult.cs
Lines changed: 22 additions & 0 deletions

@@ -0,0 +1,22 @@
+namespace LLama.Native;
+
+/// <summary>
+/// Return codes from llama_decode
+/// </summary>
+public enum DecodeResult
+{
+    /// <summary>
+    /// An unspecified error
+    /// </summary>
+    Error = -1,
+
+    /// <summary>
+    /// Ok.
+    /// </summary>
+    Ok = 0,
+
+    /// <summary>
+    /// Could not find a KV slot for the batch (try reducing the size of the batch or increase the context)
+    /// </summary>
+    NoKvSlot = 1,
+}
