 using System.Linq;
 using System.Runtime.CompilerServices;
 using System.Threading;
-using System.Threading.Tasks;
+using LLama.Exceptions;
 using LLama.Native;
 using LLama.Sampling;
 using Microsoft.Extensions.Logging;
@@ -22,6 +22,7 @@ public class StatelessExecutor
     private readonly LLamaWeights _weights;
     private readonly IContextParams _params;
     private readonly ILogger? _logger;
+    private readonly LLamaBatch _batch;
 
     /// <summary>
     /// The context used by the executor when running the inference.
@@ -39,6 +40,7 @@ public StatelessExecutor(LLamaWeights weights, IContextParams @params, ILogger?
         _weights = weights;
         _params = @params;
         _logger = logger;
+        _batch = new LLamaBatch(1);
 
         Context = _weights.CreateContext(_params, logger);
         Context.Dispose();
@@ -71,16 +73,29 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
         var repeat_last_n = Math.Max(0, inferenceParams.RepeatLastTokensCount < 0 ? _weights.ContextSize : inferenceParams.RepeatLastTokensCount);
         var lastTokens = new List<LLamaToken>(repeat_last_n);
         for (var i = 0; i < repeat_last_n; i++)
-            lastTokens.Add((LLamaToken)0);
+            lastTokens.Add(0);
 
         // Tokenize the prompt
         var tokens = Context.Tokenize(prompt).ToList();
         lastTokens.AddRange(tokens);
-        var n_past = 1 + tokens.Count;
 
-        // Evaluate the prompt
-        await Task.Run(() => { Context.Eval(tokens, 1); }, cancellationToken)
-            .ConfigureAwait(false);
+        // Evaluate the prompt, in chunks smaller than the max batch size
+        var n_past = 0;
+        var batchSize = (int)Context.Params.BatchSize;
+        for (var i = 0; i < tokens.Count; i += batchSize)
+        {
+            var n_eval = tokens.Count - i;
+            if (n_eval > batchSize)
+                n_eval = batchSize;
+
+            _batch.Clear();
+            for (var j = 0; j < n_eval; j++)
+                _batch.Add(tokens[i + j], n_past++, LLamaSeqId.Zero, (i + j) == tokens.Count - 1);
+
+            var returnCode = await Context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
+        }
 
         // Begin loop, evaluating one token at a time
         var mu = (float?)null;
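The chunked prompt evaluation added above is the heart of this commit: positions are handed out from n_past++ as tokens are appended to the batch, and only the final prompt token requests logits. A minimal standalone sketch of that bookkeeping (plain C#, no LLamaSharp types; the token ids and batch size are made-up values):

using System;
using System.Collections.Generic;

class ChunkingSketch
{
    static void Main()
    {
        // Hypothetical prompt of 7 token ids and a max batch size of 3.
        var tokens = new List<int> { 10, 11, 12, 13, 14, 15, 16 };
        const int batchSize = 3;

        var n_past = 0;
        for (var i = 0; i < tokens.Count; i += batchSize)
        {
            // Chunk length: a full batchSize for every chunk except possibly the last.
            var n_eval = Math.Min(batchSize, tokens.Count - i);

            for (var j = 0; j < n_eval; j++)
            {
                // Logits are only needed to sample the token *after* the prompt,
                // so only the very last prompt token sets the flag.
                var logits = (i + j) == tokens.Count - 1;
                Console.WriteLine($"chunk@{i}: token={tokens[i + j]} pos={n_past++} logits={logits}");
            }

            // The executor calls Context.DecodeAsync(_batch, ...) once per chunk here,
            // throwing LLamaDecodeError on a non-zero return code.
        }

        // n_past == tokens.Count: the first sampled token is decoded at this position.
        Console.WriteLine($"n_past = {n_past}");
    }
}

Running it prints positions 0..6 spread over chunks of 3, 3, and 1 tokens, with logits=true only at position 6 — the index that the sampling code below reads back.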
@@ -90,12 +105,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
             LLamaToken id;
             if (inferenceParams.SamplingPipeline is not null)
             {
-                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogits(), lastTokens);
+                id = inferenceParams.SamplingPipeline.Sample(Context.NativeHandle, Context.NativeHandle.GetLogitsIth(_batch.TokenCount - 1), lastTokens);
             }
             else
             {
                 // Penalize the generated tokens by various penalties
-                var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
+                var tokenDataArray = Context.ApplyPenalty(_batch.TokenCount - 1, lastTokens, inferenceParams.LogitBias, repeat_last_n,
                     inferenceParams.RepeatPenalty, inferenceParams.FrequencyPenalty, inferenceParams.PresencePenalty, inferenceParams.PenalizeNL);
 
                 // Sample a single token
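A note on the two sampling changes in this hunk: with llama_decode, logits are exposed per batch entry, so the executor now reads them by index. _batch.TokenCount - 1 is the last token added to the most recent batch — the only one flagged with logits = true above — which is why both GetLogitsIth and ApplyPenalty take that index where the old GetLogits() call implicitly used the logits of the last evaluated token.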
@@ -136,9 +151,12 @@ public async IAsyncEnumerable<string> InferAsync(string prompt, IInferenceParams
                 n_past -= n_discard;
             }
 
-            // ReSharper disable once AccessToModifiedClosure (Justification: n_past is modified inside and outside the capture, but not concurrently)
-            n_past = await Task.Run(() => Context.Eval(tokens, n_past), cancellationToken)
-                .ConfigureAwait(false);
+            // Evaluate with this new token
+            _batch.Clear();
+            _batch.Add(id, n_past++, LLamaSeqId.Zero, true);
+            var returnCode = await context.DecodeAsync(_batch, cancellationToken);
+            if (returnCode != 0)
+                throw new LLamaDecodeError(returnCode);
         }
     }
 }
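For context, a hedged sketch of how a caller typically drives this executor end to end. The model path and parameter values are placeholders, not part of this commit:

using System;
using LLama;
using LLama.Common;

// Load weights once; the StatelessExecutor creates a fresh context per call,
// so no conversation state leaks between prompts.
var parameters = new ModelParams("path/to/model.gguf");
using var weights = LLamaWeights.LoadFromFile(parameters);

var executor = new StatelessExecutor(weights, parameters);

// Text streams out as tokens are sampled and decoded by the loop above.
await foreach (var text in executor.InferAsync(
                   "Question: what is 1 + 1? Answer:",
                   new InferenceParams { MaxTokens = 32 }))
{
    Console.Write(text);
}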