Generate All Logits #743

Merged · 1 commit · May 19, 2024

Changes from all commits
71 changes: 46 additions & 25 deletions LLama/Batched/Conversation.cs
@@ -1,6 +1,7 @@
using System;
using System.Buffers;
using System.Collections.Generic;
using System.Linq;
using System.Runtime.InteropServices;
using System.Text.Json;
using LLama.Native;
@@ -15,10 +16,19 @@ public sealed class Conversation
{
private ulong _requiredEpoch;
private LLamaPos _end;
private int _batchSampleIndex;
private bool _disposed;

/// <summary>
/// Indicates if this conversation has been "forked" and may share logits with another conversation.
/// </summary>
private bool _forked;

/// <summary>
/// Stores the indices to sample from. Contains <see cref="_batchSampleCount"/> valid items.
/// </summary>
private int[] _batchSampleIndices = new int[4];
private int _batchSampleCount;

/// <summary>
/// The executor which this conversation belongs to
/// </summary>
@@ -108,7 +118,8 @@ public Conversation Fork()
// logits, so sampling one conversation may mess up the fork! Setting the "forked" flag on both sequences ensures
// they both copy the logits before the next sampling run, to fix this issue.
_requiredEpoch = _requiredEpoch,
_batchSampleIndex = _batchSampleIndex,
_batchSampleIndices = _batchSampleIndices.ToArray(),
_batchSampleCount = _batchSampleCount,
_forked = true,

_end = _end,
@@ -128,20 +139,24 @@ public Conversation Fork()
/// <summary>
/// Get the logits from this conversation, ready for sampling
/// </summary>
/// <param name="offset">How far from the <b>end</b> of the previous prompt should logits be sampled. Any value other than 0 requires allLogits to have been set during prompting</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="CannotSampleRequiresPromptException">Thrown if this conversation was not prompted before the previous call to infer</exception>
/// <exception cref="CannotSampleRequiresInferenceException">Thrown if Infer() must be called on the executor</exception>
public Span<float> Sample()
public Span<float> Sample(int offset = 0)
{
AssertNotDisposed();

if (_requiredEpoch < Executor.Epoch)
throw new CannotSampleRequiresPromptException();
if (_requiredEpoch > Executor.Epoch)
throw new CannotSampleRequiresInferenceException();

var span = Executor.Context.NativeHandle.GetLogitsIth(_batchSampleIndex);
if (offset >= _batchSampleCount)
throw new ArgumentException("Cannot sample offset more than the previous prompt count", nameof(offset));

var index = _batchSampleIndices[_batchSampleCount - offset - 1];
var span = Executor.Context.NativeHandle.GetLogitsIth(index);

// If necessary copy the span, to protect it from modification. This is only done when
// this conversation has been forked in this epoch.
@@ -161,33 +176,21 @@ private void AssertCanBePrompted()
throw new AlreadyPromptedConversationException();
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="input"></param>
/// <returns></returns>
[Obsolete("Tokenize the text and pass the tokens instead")]
public void Prompt(string input, bool addBos, bool special)
{
AssertCanBePrompted();

Prompt(Executor.Context.Tokenize(input, addBos, special));
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(List<LLamaToken> tokens)
public void Prompt(List<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();

#if NET6_0_OR_GREATER
var span = CollectionsMarshal.AsSpan(tokens);
Prompt(span);
Prompt(span, allLogits);
#else
// Borrow an array and copy tokens into it
var arr = ArrayPool<LLamaToken>.Shared.Rent(tokens.Count);
@@ -204,15 +207,16 @@ public void Prompt(List<LLamaToken> tokens)
}
#endif
}

/// <summary>
/// Add tokens to this conversation
/// </summary>
/// <param name="tokens"></param>
/// <param name="allLogits">If true, generate logits for all tokens. If false, only generate logits for the last token.</param>
/// <returns></returns>
/// <exception cref="ObjectDisposedException"></exception>
/// <exception cref="AlreadyPromptedConversationException"></exception>
public void Prompt(ReadOnlySpan<LLamaToken> tokens)
public void Prompt(ReadOnlySpan<LLamaToken> tokens, bool allLogits = false)
{
AssertCanBePrompted();

@@ -221,8 +225,25 @@ public void Prompt(ReadOnlySpan<LLamaToken> tokens)
return;

// Add the prompt to the batch
for (var i = 0; i < tokens.Length; i++)
_batchSampleIndex = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
if (allLogits)
{
if (_batchSampleIndices.Length < tokens.Length)
_batchSampleIndices = new int[tokens.Length];

_batchSampleCount = tokens.Length;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[i] = Executor.Batch.Add(tokens[i], _end++, ConversationId, true);
}
else
{
_batchSampleCount = 1;

for (var i = 0; i < tokens.Length; i++)
_batchSampleIndices[0] = Executor.Batch.Add(tokens[i], _end++, ConversationId, i == tokens.Length - 1);
}



// Mark this conversation as needing inference/sampling
_requiredEpoch = Executor.Epoch + 1;
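
Taken together, the patch adds an allLogits flag to Prompt and an offset parameter to Sample. A minimal usage sketch (not part of the patch): the surrounding calls (BatchedExecutor, Create, Context.Tokenize, Infer) are the usual LLamaSharp batched API and are assumptions here; only Prompt(..., allLogits) and Sample(offset) come from this change.

```csharp
using System;
using System.Threading.Tasks;
using LLama;
using LLama.Batched;
using LLama.Common;

public static class AllLogitsExample
{
    public static async Task RunAsync(string modelPath)
    {
        // Standard batched-executor setup (assumed, not part of this diff).
        var parameters = new ModelParams(modelPath);
        using var weights = LLamaWeights.LoadFromFile(parameters);
        using var executor = new BatchedExecutor(weights, parameters);
        using var conversation = executor.Create();

        // New in this PR: request logits for every prompt token, not just the last one.
        var tokens = executor.Context.Tokenize("The quick brown fox");
        conversation.Prompt(tokens, allLogits: true);

        await executor.Infer();

        ReadAll(conversation, tokens.Length);
    }

    private static void ReadAll(Conversation conversation, int promptLength)
    {
        // New in this PR: offset counts back from the end of the prompt,
        // so 0 is the last prompt token, 1 the one before it, and so on.
        for (var offset = 0; offset < promptLength; offset++)
        {
            Span<float> logits = conversation.Sample(offset);
            // ... consume `logits` here (e.g. greedy argmax or log-prob) ...
        }
    }
}
```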
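And a sketch of one thing per-token logits make possible: scoring the prompt itself, the building block for prompt perplexity. This assumes the conversation was prompted with the same token array and allLogits: true, and that Infer() has completed; PromptLogProbs and LogSoftmaxAt are hypothetical helpers, not LLamaSharp API. The offset arithmetic mirrors the new indexing in Sample: the logits that predict tokens[t] are produced after tokens[t - 1], i.e. at offset = tokens.Length - t from the end of the prompt.

```csharp
using System;
using LLama.Batched;
using LLama.Native;

public static class PromptLogProbs
{
    // Log-probability of each prompt token given the tokens before it.
    public static float[] Compute(Conversation conversation, LLamaToken[] tokens)
    {
        var result = new float[tokens.Length - 1];
        for (var t = 1; t < tokens.Length; t++)
        {
            // Logits after tokens[t - 1] sit at offset (tokens.Length - t) from the end.
            var logits = conversation.Sample(tokens.Length - t);
            result[t - 1] = LogSoftmaxAt(logits, (int)tokens[t]);
        }
        return result;
    }

    // Numerically stable log-softmax evaluated at a single vocabulary index.
    private static float LogSoftmaxAt(ReadOnlySpan<float> logits, int index)
    {
        var max = float.MinValue;
        foreach (var l in logits)
            max = Math.Max(max, l);

        double sumExp = 0;
        foreach (var l in logits)
            sumExp += Math.Exp(l - max);

        return (float)(logits[index] - max - Math.Log(sumExp));
    }
}
```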