Skip to content

Commit

Permalink
Fixed decoding of large tokens (over 16 bytes) in streaming text decoder
Browse files Browse the repository at this point in the history
  • Loading branch information
martindevans committed Jan 9, 2024
1 parent 54dffe7 commit 98635a0
Show file tree
Hide file tree
Showing 4 changed files with 61 additions and 8 deletions.
53 changes: 53 additions & 0 deletions LLama.Unittest/StreamingTextDecoderTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
using System.Text;
using LLama.Common;
using Xunit.Abstractions;

namespace LLama.Unittest;

/// <summary>
/// Round-trip tests for <see cref="StreamingTokenDecoder"/>: a string is tokenized with the
/// loaded model, each token is pushed into the decoder, and the decoded text (after trimming
/// surrounding whitespace) must equal the original string.
/// </summary>
public class StreamingTextDecoderTests
    : IDisposable
{
    private readonly LLamaWeights _model;
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly ModelParams _params;

    /// <summary>
    /// Loads the model weights from <c>Constants.ModelPath</c> for use by the tests in this class.
    /// </summary>
    /// <param name="testOutputHelper">xUnit output sink; stored but not written to by the current tests.</param>
    public StreamingTextDecoderTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        _params = new ModelParams(Constants.ModelPath);
        _model = LLamaWeights.LoadFromFile(_params);
    }

    /// <summary>
    /// Disposes the model weights loaded in the constructor.
    /// </summary>
    public void Dispose() => _model.Dispose();

    /// <summary>
    /// Plain ASCII text must survive a tokenize/decode round trip.
    /// </summary>
    [Fact]
    public void DecodesSimpleText()
    {
        AssertRoundTrips("The cat sat on the mat");
    }

    /// <summary>
    /// Text made of multi-byte UTF-8 characters (CJK and emoji) must survive a tokenize/decode round trip.
    /// </summary>
    [Fact]
    public void DecodesComplexText()
    {
        AssertRoundTrips("猫坐在垫子上 😀🤨🤐😏");
    }

    /// <summary>
    /// Tokenizes <paramref name="text"/>, feeds every token to a fresh decoder and asserts that
    /// the decoded output, once trimmed, matches the input.
    /// </summary>
    /// <param name="text">The text expected to come back out of the decoder.</param>
    private void AssertRoundTrips(string text)
    {
        var streamingDecoder = new StreamingTokenDecoder(Encoding.UTF8, _model);

        var tokenIds = _model.NativeHandle.Tokenize(text, false, false, Encoding.UTF8);
        foreach (var tokenId in tokenIds)
        {
            streamingDecoder.Add(tokenId);
        }

        // The comparison is made on the trimmed output, so leading or trailing
        // whitespace in the decoded string does not fail the test.
        var decoded = streamingDecoder.Read();
        Assert.Equal(text, decoded.Trim());
    }
}
2 changes: 1 addition & 1 deletion LLama/Native/SafeLLamaContextHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public LLamaToken[] Tokenize(string text, bool add_bos, bool special, Encoding e
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
return ThrowIfDisposed().TokenToSpan(token, dest);
}
Expand Down
4 changes: 2 additions & 2 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -126,10 +126,10 @@ public void ApplyLoraFromFile(string lora, float scale, string? modelBase = null
/// <param name="token">Token to decode</param>
/// <param name="dest">A span to attempt to write into. If this is too small nothing will be written</param>
/// <returns>The size of this token. **nothing will be written** if this is larger than `dest`</returns>
public int TokenToSpan(LLamaToken token, Span<byte> dest)
public uint TokenToSpan(LLamaToken token, Span<byte> dest)
{
var length = NativeApi.llama_token_to_piece(this, token, dest);
return Math.Abs(length);
return (uint)Math.Abs(length);
}

/// <summary>
Expand Down
10 changes: 5 additions & 5 deletions LLama/StreamingTokenDecoder.cs
Original file line number Diff line number Diff line change
Expand Up @@ -113,19 +113,19 @@ static Span<byte> TokenToBytes(ref byte[] bytes, LLamaToken token, SafeLlamaMode
// Try to get bytes
var l = model.TokenToSpan(token, bytes);

// Negative length indicates that the output was too small. Expand it to twice that size and try again.
if (l < 0)
// Check if the length was larger than the buffer. If so expand the buffer and try again
if (l > bytes.Length)
{
// Return the old array to the pool and get a new one
ArrayPool<byte>.Shared.Return(bytes);
bytes = ArrayPool<byte>.Shared.Rent(-l * 2);
bytes = ArrayPool<byte>.Shared.Rent((int)(l * 2));

// Get bytes, this time it can't fail
l = model.TokenToSpan(token, bytes);
}

Debug.Assert(l >= 0);
return new Span<byte>(bytes, 0, l);
Debug.Assert(l <= bytes.Length);
return new Span<byte>(bytes, 0, (int)l);
}
}

Expand Down

0 comments on commit 98635a0

Please sign in to comment.