Skip to content

Commit

Permalink
- Implemented Tokenizer.IDsToTokens()
Browse files Browse the repository at this point in the history
- TempFixedAllocator's memory is now 128-byte-aligned

- Moved output data structures to Output directory
  • Loading branch information
budgetdevv committed Oct 11, 2024
1 parent d44b3f0 commit 907997a
Show file tree
Hide file tree
Showing 11 changed files with 191 additions and 10 deletions.
1 change: 1 addition & 0 deletions Codegen/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
using System.Threading;
using Tokenizers.NET;
using Tokenizers.NET.Collections;
using Tokenizers.NET.Outputs;

namespace Codegen
{
Expand Down
30 changes: 29 additions & 1 deletion Native/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
use std::string::String;
use std::marker::PhantomData;
use std::ptr::{ null, null_mut };
use std::slice;
use tokenizers::tokenizer::Tokenizer;
use tokenizers::Encoding;

// #[inline(always)] is used aggressively - Realistically we only have a few callsites.

#[repr(C)]
Expand Down Expand Up @@ -445,6 +445,34 @@ pub unsafe extern "C" fn tokenizer_decode_core(
return DecodeOutput::from_text(text);
}

/// Maps each token ID in `id_buffer` to its textual token and writes, per ID,
/// a `NativeBuffer<u8>` view over the token's UTF-8 bytes into `token_buffer`.
///
/// Returns a `DropHandle` that owns the backing `Vec<String>`; the buffers
/// written into `token_buffer` stay valid until the caller frees that handle
/// (pushing a `String` into the vec moves only its (ptr, len, cap) header,
/// not the heap bytes, so the recorded pointers remain valid).
///
/// # Safety
/// - `tokenizer_ptr` must point to a live `Tokenizer`.
/// - `token_buffer` must have room for at least `id_buffer.length` entries;
///   no bounds check happens here (the managed caller validates lengths).
/// - NOTE(review): `unwrap()` panics on an ID with no vocab entry, and
///   unwinding across an `extern "C"` boundary is undefined behavior —
///   TODO confirm callers only pass IDs produced by this tokenizer.
#[no_mangle]
#[inline(always)]
pub unsafe extern "C" fn ids_to_tokens(
tokenizer_ptr: *mut Tokenizer,
id_buffer: NativeBuffer<u32>,
token_buffer: NativeBuffer<NativeBuffer<u8>>)
-> *mut DropHandle<Vec<String>>
{
let tokenizer = &*tokenizer_ptr;

// Pre-size: one owned String per input ID.
let mut token_buffers = Vec::with_capacity(id_buffer.length);

// Cursor into the caller-provided output array.
let mut current_token_ptr = token_buffer.ptr.mutable;

for id in id_buffer.as_slice()
{
let mut token = tokenizer.id_to_token(*id).unwrap();

// Record a view over the String's heap bytes before moving it into the vec.
*current_token_ptr = NativeBuffer::from_mutable_vec(token.as_mut_vec());

current_token_ptr = current_token_ptr.add(1);

token_buffers.push(token);
}

// Box the vec so the managed side can free it later via free_with_handle.
return DropHandle::from_value_and_allocate_box(token_buffers);
}

#[no_mangle]
#[inline(always)]
pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
Expand Down
2 changes: 1 addition & 1 deletion Sample/Program.cs
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ private static void Main(string[] args)

foreach (var token in outputSpan)
{
const bool TEST_OVERFLOW = true;
const bool TEST_OVERFLOW = false;

if (TEST_OVERFLOW)
{
Expand Down
29 changes: 29 additions & 0 deletions Tests/DecodeTests.cs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
using System.Text;
using Allure.NUnit;
using FluentAssertions;
using Tokenizers.NET;
Expand Down Expand Up @@ -107,5 +108,33 @@ public void DecodeMutatingStressTest()
x.Should().Be(text);
}
}

[Test]
public void IDsToTokens()
{
// Round-trip check: tokenize random text, map the IDs back to token
// strings, and verify that concatenating them reproduces the input.
ref var tokenizer = ref FlorenceTokenizer;

const nuint MAX_VALUE = 500;

var reconstructed = new StringBuilder();

for (nuint length = 1; length <= MAX_VALUE; length++)
{
var text = AllocateStringWithRandomChars((int) length);

using var tokenizeResult = tokenizer.Tokenize(text, addSpecialTokens: false);

foreach (var token in tokenizer.IDsToTokens(tokenizeResult.IDs))
{
// 'Ġ' marks a leading space in the vocab (byte-level BPE style —
// presumably GPT-2-like); map it back before comparing.
reconstructed.Append(token.Replace('Ġ', ' '));
}

reconstructed.ToString().Should().Be(text);

reconstructed.Clear();
}
}
}
}
1 change: 1 addition & 0 deletions Tests/EncodeTests.cs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
using FluentAssertions;
using Tokenizers.NET;
using Tokenizers.NET.Collections;
using Tokenizers.NET.Outputs;

namespace Tests
{
Expand Down
6 changes: 6 additions & 0 deletions Tokenizers.NET/Helpers/ThrowHelpers.cs
Original file line number Diff line number Diff line change
Expand Up @@ -32,5 +32,11 @@ public static void UTF8EncodingPirated_GetMaxCharCount_OutOfRange()
{
throw new InvalidOperationException("Too many bytes. The resulting number of chars is larger than what can be returned as an int.");
}

// Thrown when an IDsToTokens() destination is shorter than its input.
[DoesNotReturn]
public static void IDsToTokens_LengthCheckFailed()
=> throw new ArgumentException("Output Span / Buffer length must be more than or equal to the input length.");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
using System.Text;
using Tokenizers.NET.Collections;

namespace Tokenizers.NET
namespace Tokenizers.NET.Outputs
{
[StructLayout(LayoutKind.Sequential)]
public readonly struct DecodeOutput: IDisposable
Expand Down
14 changes: 14 additions & 0 deletions Tokenizers.NET/Outputs/FreeHandle.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
using System;
using System.Runtime.CompilerServices;

namespace Tokenizers.NET.Outputs
{
// Disposable wrapper around a native DropHandle pointer; disposing it
// releases the native allocation via TokenizerNativeMethods.FreeWithHandle.
public readonly struct FreeHandle: IDisposable
{
// Opaque handle returned by the native side.
private readonly nint _handle;

public FreeHandle(nint handle)
{
_handle = handle;
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public void Dispose() => TokenizerNativeMethods.FreeWithHandle(_handle);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Tokenizers.NET.Collections;
#if DEBUG
using System.Diagnostics;
#endif

namespace Tokenizers.NET
namespace Tokenizers.NET.Outputs
{
public interface ITokenizeOutput
{
Expand Down
102 changes: 99 additions & 3 deletions Tokenizers.NET/Tokenizer.cs
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
using Tokenizers.NET.Collections;
using Tokenizers.NET.Enumerators;
using Tokenizers.NET.Helpers;
using Tokenizers.NET.Outputs;

namespace Tokenizers.NET
{
Expand Down Expand Up @@ -171,15 +172,28 @@ private static readonly int

private readonly int Count;

// Modern cacheline size is either 64 or 128 bytes; aligning to 128
// reduces cross-cacheline reads for SIMD instructions.
// It also satisfies the alignment of NativeBuffer<NativeBuffer<byte>>,
// letting IDsToTokens() reinterpret this memory in place and avoid
// a transient allocation.
private const int ALIGNMENT = 128;

// Debug-only guard for that reinterpret: the alignment must be an exact
// multiple of the element size being overlaid onto the buffer.
static TempFixedAllocator()
{
Debug.Assert(ALIGNMENT % sizeof(NativeBuffer<NativeBuffer<byte>>) == 0);
}

public TempFixedAllocator()
{
var maxExpectedBatches = Config.ExpectedMaxBatches.ToSignedUnchecked();

var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitialized<byte>(
TOTAL_BUFFER_SIZE
var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitializedAligned<byte>(
TOTAL_BUFFER_SIZE,
ALIGNMENT,
out var buffersPtr
);

BuffersPtr = buffers.PinnedArrayToPointer();
BuffersPtr = buffersPtr;

Count = maxExpectedBatches;

Expand Down Expand Up @@ -525,6 +539,88 @@ public DecodeOutput DecodeMutating(NativeBuffer<ulong> ids, bool skipSpecialToke
skipSpecialTokens
);
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FreeHandle IDsToTokens(NativeBuffer<uint> ids, Span<NativeBuffer<byte>> u8Strings)
{
// Pin the span's backing memory (GetReference also handles empty spans)
// and delegate to the NativeBuffer-based overload.
fixed (NativeBuffer<byte>* u8StringsPtr = &MemoryMarshal.GetReference(u8Strings))
{
var length = (nuint) u8Strings.Length;

return IDsToTokens(ids, new NativeBuffer<NativeBuffer<byte>>(u8StringsPtr, length));
}
}

[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FreeHandle IDsToTokens(
NativeBuffer<uint> ids,
NativeBuffer<NativeBuffer<byte>> tokens,
bool performSizeCheck = true)
{
// The native side writes one UTF-8 token buffer per input ID, so the
// destination must hold at least as many entries as there are IDs.
if (performSizeCheck && tokens.Length < ids.Length)
{
ThrowHelpers.IDsToTokens_LengthCheckFailed();
}

var handle = TokenizerNativeMethods.IDsToTokens(TokenizerHandle, ids, tokens);

// The FreeHandle owns the native allocation backing the token buffers.
return new FreeHandle(handle);
}

// Convenience overload: allocates one managed string per input ID and
// fills it via the Span-based overload.
public string[] IDsToTokens(NativeBuffer<uint> ids)
{
var result = new string[ids.Length];

// Length is exactly ids.Length by construction, so skip the size check.
IDsToTokens(ids, result, performSizeCheck: false);

return result;
}

// Decodes each ID in <paramref name="ids"/> into its token text and writes
// one managed string per ID into <paramref name="tokens"/>.
// Throws ArgumentException when performSizeCheck is true and tokens is
// shorter than ids. Scratch space for the native token buffers comes from
// the fixed allocator when it fits, otherwise from a temporary native
// allocation that is always freed before returning.
public void IDsToTokens(NativeBuffer<uint> ids, Span<string> tokens, bool performSizeCheck = true)
{
var inputLength = ids.Length;

if (performSizeCheck && (nuint) tokens.Length < inputLength)
{
ThrowHelpers.IDsToTokens_LengthCheckFailed();
}

var allocationSizeInBytes = (int) inputLength * sizeof(NativeBuffer<NativeBuffer<byte>>);

// NOTE(review): this compares a byte count against
// ExpectedMaxInputLength * ExpectedMaxBatches — confirm that product equals
// the allocator's capacity in bytes (TOTAL_BUFFER_SIZE), otherwise the
// fixed allocation below could be overrun.
var allocateNative = allocationSizeInBytes > (Config.ExpectedMaxInputLength * Config.ExpectedMaxBatches);

NativeBuffer<NativeBuffer<byte>> allocation;

if (!allocateNative)
{
// Reinterpret the allocator's 128-byte-aligned scratch memory as an
// array of NativeBuffer<byte>, avoiding a transient allocation
// (alignment is asserted in TempFixedAllocator's static constructor).
var ptr = Allocator.GetFullAllocationUnsafely().Ptr;

allocation = new((NativeBuffer<byte>*) ptr, inputLength);
}

else
{
allocation = new NativeMemory<NativeBuffer<byte>>(inputLength).Buffer;
}

try
{
// The native call writes one UTF-8 buffer per input ID; the returned
// handle owns the backing memory and must outlive the copy loop below.
using var freeHandle = IDsToTokens(ids, allocation, performSizeCheck: false);

// Explicitly typed ref (rather than "ref var") so the target type is
// non-nullable string, avoiding the CI nullability warning
// (string vs string?) the inferred form produced.
ref string currentToken = ref MemoryMarshal.GetReference(tokens);

foreach (var buffer in allocation)
{
// In theory, we could intern the tokenizer's vocab and greatly reduce string allocs,
// but it is what it is for now...
currentToken = Encoding.UTF8.GetString(buffer.Ptr, (int) buffer.Length);

currentToken = ref Unsafe.Add(ref currentToken, 1);
}
}

finally
{
// Free even if UTF-8 decoding throws, so the natively-allocated
// scratch buffer never leaks.
if (allocateNative)
{
NativeMemory<NativeBuffer<byte>>.FreeWithPtrUnsafely(allocation.Ptr);
}
}
}

public void Dispose()
{
Expand Down
9 changes: 9 additions & 0 deletions Tokenizers.NET/TokenizerNativeMethods.cs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using Tokenizers.NET.Collections;
using Tokenizers.NET.Outputs;

namespace Tokenizers.NET
{
Expand Down Expand Up @@ -118,6 +119,14 @@ public static DecodeOutput TokenizerDecode(
[LibraryImport(DLL_NAME, EntryPoint = "tokenizer_decode_skip_special_tokens")]
private static partial DecodeOutput TokenizerDecodeSkipSpecialTokens(nint tokenizerPtr, NativeBuffer<uint> idBuffer);

// P/Invoke for the native ids_to_tokens(): fills tokenBuffer with one
// NativeBuffer<byte> (UTF-8 token bytes) per ID and returns a pointer to
// the native DropHandle that owns the backing strings; the caller must
// release it via FreeWithHandle.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[LibraryImport(DLL_NAME, EntryPoint = "ids_to_tokens")]
public static partial nint IDsToTokens(
nint tokenizerPtr,
NativeBuffer<uint> idBuffer,
NativeBuffer<NativeBuffer<byte>> tokenBuffer
);

// Releases a native DropHandle pointer previously returned by a native
// call such as ids_to_tokens.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
[LibraryImport(DLL_NAME, EntryPoint = "free_with_handle")]
public static partial void FreeWithHandle(nint handle);
Expand Down

0 comments on commit 907997a

Please sign in to comment.