Skip to content

Commit

Permalink
Changed based on review feedback.
Browse files Browse the repository at this point in the history
  • Loading branch information
martindevans committed May 13, 2024
1 parent 60ddd44 commit b80f043
Showing 1 changed file with 8 additions and 11 deletions.
19 changes: 8 additions & 11 deletions LLama/Native/LLamaTokenDataArray.cs
Original file line number Diff line number Diff line change
Expand Up @@ -37,27 +37,25 @@ public LLamaTokenDataArray(Memory<LLamaTokenData> tokens, bool isSorted = false)
/// <returns></returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits)
{
var candidates = new LLamaTokenData[logits.Length];
for (var token = 0; token < logits.Length; token++)
candidates[token] = new LLamaTokenData(token, logits[token], 0.0f);

return new LLamaTokenDataArray(candidates);
return Create(logits, new LLamaTokenData[logits.Length]);
}

/// <summary>
/// Create a new LLamaTokenDataArray, copying the data from the given logits into temporary memory.
/// </summary>
/// <remarks>The memory must not be modified while this <see cref="LLamaTokenDataArray"/> is in use.</remarks>
/// <param name="logits"></param>
/// <param name="temporary">Temporary memory which will be used to work on these logits. Must be at least as large as logits array</param>
/// <param name="buffer">Temporary memory which will be used to work on these logits. Must be at least as large as logits array</param>
/// <returns></returns>
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits, Memory<LLamaTokenData> temporary)
public static LLamaTokenDataArray Create(ReadOnlySpan<float> logits, Memory<LLamaTokenData> buffer)
{
if (temporary.Length < logits.Length)
if (buffer.Length < logits.Length)
throw new ArgumentException("temporary memory is shorter than logits span");
var candidates = temporary.Slice(0, logits.Length);

// take a slice of the output buffer which is exactly the size we need.
var candidates = buffer.Slice(0, logits.Length);
var candidatesSpan = candidates.Span;

for (var token = 0; token < logits.Length; token++)
candidatesSpan[token] = new LLamaTokenData(token, logits[token], 0.0f);

Expand Down Expand Up @@ -361,7 +359,7 @@ public struct LLamaTokenDataArrayNative
/// <summary>
/// A pointer to an array of LlamaTokenData
/// </summary>
/// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
/// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use (i.e. `fixed` or `.Pin()`)</remarks>
private unsafe LLamaTokenData* _data;

/// <summary>
Expand All @@ -372,7 +370,6 @@ public struct LLamaTokenDataArrayNative
/// <summary>
/// A pointer to an array of LlamaTokenData
/// </summary>
/// <remarks>Memory must be pinned in place for all the time this LLamaTokenDataArrayNative is in use</remarks>
public Span<LLamaTokenData> data
{
get
Expand Down

0 comments on commit b80f043

Please sign in to comment.