Skip to content

Commit

Permalink
Merge pull request #447 from martindevans/grow_nseqmax_batch
Browse files: browse the repository at this point in the history
LLamaBatch Grow n_seq_max automatically
  • Loading branch information
martindevans authored Jan 21, 2024
2 parents 892e841 + 9fe878a commit 0074320
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 8 deletions.
2 changes: 1 addition & 1 deletion LLama.Examples/Examples/BatchedDecoding.cs
Original file line number Diff line number Diff line change
Expand Up @@ -52,7 +52,7 @@ public static async Task Run()
return;
}

var batch = new LLamaBatch(1);
var batch = new LLamaBatch();

// evaluate the initial prompt
for (var i = 0; i < prompt_tokens.Length; i++)
Expand Down
24 changes: 17 additions & 7 deletions LLama/Native/LLamaBatch.cs
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
using System;
using System.Collections.Generic;

namespace LLama.Native;

Expand Down Expand Up @@ -35,11 +34,11 @@ public class LLamaBatch
/// <summary>
/// Create a new batch for submitting inputs to llama.cpp
/// </summary>
/// <param name="n_seq_max">Max number of sequences a token can be assigned to</param>
public LLamaBatch(int n_seq_max)
public LLamaBatch()
{
// The number of tokens can be grown later, start off with a reasonable guess.
const int n_tokens = 64;
// These can both be grown later, start off with reasonable numbers.
const int n_tokens = 128;
const int n_seq_max = 1;

MaxSequences = n_seq_max;
TokenCapacity = n_tokens;
Expand All @@ -56,7 +55,7 @@ public LLamaBatch(int n_seq_max)
_sequenceIds[i] = new LLamaSeqId[MaxSequences];
}

private void Grow()
private void GrowTokenCapacity()
{
var n_tokens = TokenCount * 2;
TokenCapacity = n_tokens;
Expand All @@ -78,6 +77,15 @@ private void Grow()
}
}

/// <summary>
/// Raise <c>MaxSequences</c> and resize every array held in <c>_sequenceIds</c> to match.
/// The new limit is double the current one, or <paramref name="atLeast"/> if that is larger.
/// </summary>
/// <param name="atLeast">Minimum value that <c>MaxSequences</c> must have after growing</param>
private void GrowMaxSequences(int atLeast)
{
    // Double the current limit, but never end up below the requested minimum.
    MaxSequences = Math.Max(atLeast, MaxSequences * 2);

    // Array.Resize keeps the existing elements and pads the tail with default values.
    for (var index = 0; index < _sequenceIds.Length; index++)
    {
        Array.Resize(ref _sequenceIds[index], MaxSequences);
    }
}

internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
{
// This group holds all of the memory pins
Expand Down Expand Up @@ -120,7 +128,9 @@ internal GroupDisposable ToNativeBatch(out LLamaNativeBatch batch)
public void Add(LLamaToken token, LLamaPos pos, ReadOnlySpan<LLamaSeqId> sequences, bool logits)
{
if (TokenCount == TokenCapacity)
Grow();
GrowTokenCapacity();
if (sequences.Length > MaxSequences)
GrowMaxSequences(sequences.Length);

_tokens[TokenCount] = token;
_positions[TokenCount] = pos;
Expand Down

0 comments on commit 0074320

Please sign in to comment.