-
Notifications
You must be signed in to change notification settings - Fork 366
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Removed `LLamaBatchSafeHandle` (using unmanaged memory, created by llama.cpp) and replaced it with a fully managed `LLamaBatch`. Modified the `BatchedDecoding` example to use the new managed batch.
- Loading branch information
1 parent
4b11fed
commit 36a9335
Showing
6 changed files
with
137 additions
and
176 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
using System; | ||
|
||
namespace LLama.Native; | ||
|
||
/// <summary>
/// A batch allows submitting multiple tokens to multiple sequences simultaneously
/// </summary>
public class LLamaBatch
{
    // One flag per token slot: non-zero means llama.cpp should compute logits for that token
    private readonly byte[] _logits;

    private readonly LLamaToken[] _tokens;
    private readonly LLamaPos[] _positions;

    // Per token slot: how many sequences the token belongs to, the sequence ids themselves,
    // and pinned pointers to each inner id array (filled in just before handing off to native code)
    private readonly int[] _sequenceIdCount;
    private readonly LLamaSeqId[][] _sequenceIds;
    private readonly IntPtr[] _sequenceIdsPtrs;

    /// <summary>
    /// The number of tokens in this batch
    /// </summary>
    public int TokenCount { get; private set; }

    /// <summary>
    /// Create a new batch for submitting inputs to llama.cpp
    /// </summary>
    /// <param name="n_tokens">Maximum number of tokens this batch can hold</param>
    /// <param name="n_seq_max">Maximum number of sequences a single token can be assigned to</param>
    /// <exception cref="ArgumentOutOfRangeException">Thrown if either capacity is not positive</exception>
    public LLamaBatch(int n_tokens, int n_seq_max)
    {
        // Fail fast on invalid capacities instead of failing later inside LLamaBatchAdd
        if (n_tokens <= 0)
            throw new ArgumentOutOfRangeException(nameof(n_tokens), "Must be > 0");
        if (n_seq_max <= 0)
            throw new ArgumentOutOfRangeException(nameof(n_seq_max), "Must be > 0");

        _logits = new byte[n_tokens];
        _tokens = new LLamaToken[n_tokens];
        _positions = new LLamaPos[n_tokens];

        _sequenceIdCount = new int[n_tokens];
        _sequenceIdsPtrs = new IntPtr[_sequenceIdCount.Length];

        _sequenceIds = new LLamaSeqId[n_tokens][];
        for (var i = 0; i < _sequenceIds.Length; i++)
            _sequenceIds[i] = new LLamaSeqId[n_seq_max];
    }

    /// <summary>
    /// Pin all managed buffers and build a <see cref="LLamaNativeBatch"/> pointing at them.
    /// The returned <see cref="GroupDisposable"/> owns the pins; dispose it once the native
    /// call has completed, after which <paramref name="batch"/> must not be used.
    /// </summary>
    internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
    {
        // This group holds all of the memory pins
        var group = new GroupDisposable();

        unsafe
        {
            batch = new LLamaNativeBatch
            {
                n_tokens = TokenCount,
                logits = (byte*)group.Add(_logits.AsMemory().Pin()).Pointer,

                n_seq_id = (int*)group.Add(_sequenceIdCount.AsMemory().Pin()).Pointer,
                pos = (LLamaPos*)group.Add(_positions.AsMemory().Pin()).Pointer,
                seq_id = (LLamaSeqId**)group.Add(_sequenceIdsPtrs.AsMemory().Pin()).Pointer,

                // embd is not currently supported, so this is always null!
                embd = null,

                // Note that if embd is **not null** then this will be null!
                tokens = (LLamaToken*)group.Add(_tokens.AsMemory().Pin()).Pointer,
            };

            // Create pointers to each of the inner sequence id arrays in turn
            for (var i = 0; i < _sequenceIdsPtrs.Length; i++)
                _sequenceIdsPtrs[i] = (IntPtr)group.Add(_sequenceIds[i].AsMemory().Pin()).Pointer;
        }

        return group;
    }

    /// <summary>
    /// Add a single token to the batch at the same position in several sequences
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequences">The set of sequences to add this token to</param>
    /// <param name="logits">Whether logits should be computed for this token</param>
    /// <exception cref="InvalidOperationException">Thrown if the batch is already full</exception>
    /// <exception cref="ArgumentException">Thrown if more sequence ids are supplied than this batch supports per token</exception>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
    {
        // Guard against silently overrunning the fixed-size buffers allocated in the constructor
        if (TokenCount >= _tokens.Length)
            throw new InvalidOperationException($"Cannot add token, batch is full (capacity {_tokens.Length})");
        if (sequences.Length > _sequenceIds[TokenCount].Length)
            throw new ArgumentException($"Too many sequence ids ({sequences.Length}, max {_sequenceIds[TokenCount].Length})", nameof(sequences));

        _tokens[TokenCount] = token;
        _positions[TokenCount] = pos;

        _sequenceIdCount[TokenCount] = sequences.Length;
        sequences.CopyTo(_sequenceIds[TokenCount]);

        _logits[TokenCount] = Convert.ToByte(logits);

        TokenCount++;
    }

    /// <summary>
    /// Add a single token to the batch at a certain position for a single sequence
    /// </summary>
    /// <remarks>https://github.com/ggerganov/llama.cpp/blob/ad939626577cd25b462e8026cc543efb71528472/common/common.cpp#L829C2-L829C2</remarks>
    /// <param name="token">The token to add</param>
    /// <param name="pos">The position to add it at</param>
    /// <param name="sequence">The sequence to add this token to</param>
    /// <param name="logits">Whether logits should be computed for this token</param>
    public void LLamaBatchAdd(LLamaToken token, LLamaPos pos, LLamaSeqId sequence, bool logits)
    {
        // Create a temporary span to contain 1 item without allocating
        Span<LLamaSeqId> sequences = stackalloc LLamaSeqId[1];
        sequences[0] = sequence;

        // Add it
        LLamaBatchAdd(token, pos, sequences, logits);
    }

    /// <summary>
    /// Set TokenCount to zero for this batch
    /// </summary>
    public void LLamaBatchClear()
    {
        TokenCount = 0;
    }
}
This file was deleted.
Oops, something went wrong.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters