diff --git a/LLama/Native/LLamaTokenDataArray.cs b/LLama/Native/LLamaTokenDataArray.cs index dcb0b5a82..1a656c42e 100644 --- a/LLama/Native/LLamaTokenDataArray.cs +++ b/LLama/Native/LLamaTokenDataArray.cs @@ -37,11 +37,7 @@ public LLamaTokenDataArray(Memory tokens, bool isSorted = false) /// public static LLamaTokenDataArray Create(ReadOnlySpan logits) { - var candidates = new LLamaTokenData[logits.Length]; - for (var token = 0; token < logits.Length; token++) - candidates[token] = new LLamaTokenData(token, logits[token], 0.0f); - - return new LLamaTokenDataArray(candidates); + return Create(logits, new LLamaTokenData[logits.Length]); } /// @@ -49,15 +45,17 @@ public static LLamaTokenDataArray Create(ReadOnlySpan logits) /// /// The memory must not be modified while this is in use. /// - /// Temporary memory which will be used to work on these logits. Must be at least as large as logits array + /// Temporary memory which will be used to work on these logits. Must be at least as large as logits array /// - public static LLamaTokenDataArray Create(ReadOnlySpan logits, Memory temporary) + public static LLamaTokenDataArray Create(ReadOnlySpan logits, Memory buffer) { - if (temporary.Length < logits.Length) + if (buffer.Length < logits.Length) throw new ArgumentException("temporary memory is shorter than logits span"); - var candidates = temporary.Slice(0, logits.Length); + // take a slice of the output buffer which is exactly the size we need. + var candidates = buffer.Slice(0, logits.Length); var candidatesSpan = candidates.Span; + for (var token = 0; token < logits.Length; token++) candidatesSpan[token] = new LLamaTokenData(token, logits[token], 0.0f); @@ -361,7 +359,7 @@ public struct LLamaTokenDataArrayNative /// /// A pointer to an array of LlamaTokenData /// - /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use + /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use (i.e. `fixed` or `.Pin()`) private unsafe LLamaTokenData* _data; /// @@ -372,7 +370,6 @@ public struct LLamaTokenDataArrayNative /// /// A pointer to an array of LlamaTokenData /// - /// Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use public Span data { get