diff --git a/Codegen/Program.cs b/Codegen/Program.cs
index 71d63fb..14b91ef 100644
--- a/Codegen/Program.cs
+++ b/Codegen/Program.cs
@@ -5,6 +5,7 @@
 using System.Threading;
 using Tokenizers.NET;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Codegen
 {
diff --git a/Native/src/lib.rs b/Native/src/lib.rs
index 8ecb6c6..68440a4 100644
--- a/Native/src/lib.rs
+++ b/Native/src/lib.rs
@@ -1,9 +1,9 @@
+use std::string::String;
 use std::marker::PhantomData;
 use std::ptr::{ null, null_mut };
 use std::slice;
 use tokenizers::tokenizer::Tokenizer;
 use tokenizers::Encoding;
-
 // #[inline(always)] is used aggressively - Realistically we only have a few callsites.
 
 #[repr(C)]
@@ -445,6 +445,34 @@ pub unsafe extern "C" fn tokenizer_decode_core(
     return DecodeOutput::from_text(text);
 }
 
+#[no_mangle]
+#[inline(always)]
+pub unsafe extern "C" fn ids_to_tokens(
+    tokenizer_ptr: *mut Tokenizer,
+    id_buffer: NativeBuffer<u32>,
+    token_buffer: NativeBuffer<NativeBuffer<u8>>)
+    -> *mut DropHandle<Vec<String>>
+{
+    let tokenizer = &*tokenizer_ptr;
+
+    let mut token_buffers = Vec::with_capacity(id_buffer.length);
+
+    let mut current_token_ptr = token_buffer.ptr.mutable;
+
+    for id in id_buffer.as_slice()
+    {
+        let mut token = tokenizer.id_to_token(*id).unwrap();
+
+        *current_token_ptr = NativeBuffer::from_mutable_vec(token.as_mut_vec());
+
+        current_token_ptr = current_token_ptr.add(1);
+
+        token_buffers.push(token);
+    }
+
+    return DropHandle::from_value_and_allocate_box(token_buffers);
+}
+
 #[no_mangle]
 #[inline(always)]
 pub unsafe extern "C" fn free_with_handle(handle: *mut DropHandle<()>)
diff --git a/Sample/Program.cs b/Sample/Program.cs
index 5ca8f2b..d83ce46 100644
--- a/Sample/Program.cs
+++ b/Sample/Program.cs
@@ -47,7 +47,7 @@ private static void Main(string[] args)
 
             foreach (var token in outputSpan)
             {
-                const bool TEST_OVERFLOW = true;
+                const bool TEST_OVERFLOW = false;
 
                 if (TEST_OVERFLOW)
                 {
diff --git a/Tests/DecodeTests.cs b/Tests/DecodeTests.cs
index 719506e..7e861e6 100644
--- a/Tests/DecodeTests.cs
+++ b/Tests/DecodeTests.cs
@@ -1,3 +1,4 @@
+using System.Text;
 using Allure.NUnit;
 using FluentAssertions;
 using Tokenizers.NET;
@@ -107,5 +108,33 @@ public void DecodeMutatingStressTest()
                 x.Should().Be(text);
             }
         }
+
+        [Test]
+        public void IDsToTokens()
+        {
+            ref var tokenizer = ref FlorenceTokenizer;
+
+            const nuint MAX_VALUE = 500;
+
+            var stringBuilder = new StringBuilder();
+
+            for (nuint i = 1; i <= MAX_VALUE; i++)
+            {
+                var text = AllocateStringWithRandomChars((int) i);
+
+                using var tokenizeResult = tokenizer.Tokenize(text, addSpecialTokens: false);
+
+                var tokens = tokenizer.IDsToTokens(tokenizeResult.IDs);
+
+                foreach (var token in tokens)
+                {
+                    stringBuilder.Append(token.Replace('Ġ', ' '));
+                }
+
+                stringBuilder.ToString().Should().Be(text);
+
+                stringBuilder.Clear();
+            }
+        }
     }
}
\ No newline at end of file
diff --git a/Tests/EncodeTests.cs b/Tests/EncodeTests.cs
index 15e41b9..1f86c28 100644
--- a/Tests/EncodeTests.cs
+++ b/Tests/EncodeTests.cs
@@ -3,6 +3,7 @@
 using FluentAssertions;
 using Tokenizers.NET;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Tests
 {
diff --git a/Tokenizers.NET/Helpers/ThrowHelpers.cs b/Tokenizers.NET/Helpers/ThrowHelpers.cs
index 796d6bd..70e8687 100644
--- a/Tokenizers.NET/Helpers/ThrowHelpers.cs
+++ b/Tokenizers.NET/Helpers/ThrowHelpers.cs
@@ -32,5 +32,11 @@ public static void UTF8EncodingPirated_GetMaxCharCount_OutOfRange()
         {
             throw new InvalidOperationException("Too many bytes. The resulting number of chars is larger than what can be returned as an int.");
         }
+
+        [DoesNotReturn]
+        public static void IDsToTokens_LengthCheckFailed()
+        {
+            throw new ArgumentException("Output Span / Buffer length must be more than or equal to the input length.");
+        }
     }
}
\ No newline at end of file
diff --git a/Tokenizers.NET/DecodeOutput.cs b/Tokenizers.NET/Outputs/DecodeOutput.cs
similarity index 95%
rename from Tokenizers.NET/DecodeOutput.cs
rename to Tokenizers.NET/Outputs/DecodeOutput.cs
index 8425e55..088cc99 100644
--- a/Tokenizers.NET/DecodeOutput.cs
+++ b/Tokenizers.NET/Outputs/DecodeOutput.cs
@@ -4,7 +4,7 @@
 using System.Text;
 using Tokenizers.NET.Collections;
 
-namespace Tokenizers.NET
+namespace Tokenizers.NET.Outputs
 {
     [StructLayout(LayoutKind.Sequential)]
     public readonly struct DecodeOutput: IDisposable
diff --git a/Tokenizers.NET/Outputs/FreeHandle.cs b/Tokenizers.NET/Outputs/FreeHandle.cs
new file mode 100644
index 0000000..8cfeda7
--- /dev/null
+++ b/Tokenizers.NET/Outputs/FreeHandle.cs
@@ -0,0 +1,14 @@
+using System;
+using System.Runtime.CompilerServices;
+
+namespace Tokenizers.NET.Outputs
+{
+    public readonly struct FreeHandle(nint handle): IDisposable
+    {
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public void Dispose()
+        {
+            TokenizerNativeMethods.FreeWithHandle(handle);
+        }
+    }
+}
\ No newline at end of file
diff --git a/Tokenizers.NET/TokenizeOutput.cs b/Tokenizers.NET/Outputs/TokenizeOutput.cs
similarity index 99%
rename from Tokenizers.NET/TokenizeOutput.cs
rename to Tokenizers.NET/Outputs/TokenizeOutput.cs
index 9aca418..dc2f8df 100644
--- a/Tokenizers.NET/TokenizeOutput.cs
+++ b/Tokenizers.NET/Outputs/TokenizeOutput.cs
@@ -2,11 +2,8 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using Tokenizers.NET.Collections;
-#if DEBUG
-using System.Diagnostics;
-#endif
 
-namespace Tokenizers.NET
+namespace Tokenizers.NET.Outputs
 {
     public interface ITokenizeOutput
     {
diff --git a/Tokenizers.NET/Tokenizer.cs b/Tokenizers.NET/Tokenizer.cs
index 45d5400..a256e9d 100644
--- a/Tokenizers.NET/Tokenizer.cs
+++ b/Tokenizers.NET/Tokenizer.cs
@@ -9,6 +9,7 @@
 using Tokenizers.NET.Collections;
 using Tokenizers.NET.Enumerators;
 using Tokenizers.NET.Helpers;
+using Tokenizers.NET.Outputs;
 
 namespace Tokenizers.NET
 {
@@ -171,15 +172,28 @@ private static readonly int
 
             private readonly int Count;
 
+            // Modern cacheline size is either 64 or 128 bytes,
+            // reducing cross-cacheline reads for SIMD instructions.
+            // This should also satisfy the alignment for NativeBuffer<NativeBuffer<byte>>,
+            // enabling us to reinterpret the memory in IDsToTokens() to avoid allocation.
+            private const int ALIGNMENT = 128;
+
+            static TempFixedAllocator()
+            {
+                Debug.Assert(ALIGNMENT % sizeof(NativeBuffer<NativeBuffer<byte>>) == 0);
+            }
+
             public TempFixedAllocator()
             {
                 var maxExpectedBatches = Config.ExpectedMaxBatches.ToSignedUnchecked();
 
-                var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitialized(
-                    TOTAL_BUFFER_SIZE
+                var buffers = Buffers = AllocationHelpers.AllocatePinnedUninitializedAligned(
+                    TOTAL_BUFFER_SIZE,
+                    ALIGNMENT,
+                    out var buffersPtr
                 );
 
-                BuffersPtr = buffers.PinnedArrayToPointer();
+                BuffersPtr = buffersPtr;
 
                 Count = maxExpectedBatches;
 
@@ -525,6 +539,88 @@ public DecodeOutput DecodeMutating(NativeBuffer<uint> ids, bool skipSpecialToke
                 skipSpecialTokens
             );
         }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public FreeHandle IDsToTokens(NativeBuffer<uint> ids, Span<NativeBuffer<byte>> u8Strings)
+        {
+            fixed (NativeBuffer<byte>* ptr = &MemoryMarshal.GetReference(u8Strings))
+            {
+                var u8StringsBuffer = new NativeBuffer<NativeBuffer<byte>>(ptr, (nuint) u8Strings.Length);
+
+                return IDsToTokens(ids, u8StringsBuffer);
+            }
+        }
+
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        public FreeHandle IDsToTokens(
+            NativeBuffer<uint> ids,
+            NativeBuffer<NativeBuffer<byte>> tokens,
+            bool performSizeCheck = true)
+        {
+            if (performSizeCheck && tokens.Length < ids.Length)
+            {
+                ThrowHelpers.IDsToTokens_LengthCheckFailed();
+            }
+
+            var tokenizerHandle = TokenizerHandle;
+
+            return new(TokenizerNativeMethods.IDsToTokens(tokenizerHandle, ids, tokens));
+        }
+
+        public string[] IDsToTokens(NativeBuffer<uint> ids)
+        {
+            var tokens = new string[ids.Length];
+
+            IDsToTokens(ids, tokens, performSizeCheck: false);
+
+            return tokens;
+        }
+
+        public void IDsToTokens(NativeBuffer<uint> ids, Span<string> tokens, bool performSizeCheck = true)
+        {
+            var inputLength = ids.Length;
+
+            if (performSizeCheck && (nuint) tokens.Length < inputLength)
+            {
+                ThrowHelpers.IDsToTokens_LengthCheckFailed();
+            }
+
+            var allocationSizeInBytes = (int) inputLength * sizeof(NativeBuffer<NativeBuffer<byte>>);
+
+            var allocateNative = allocationSizeInBytes > (Config.ExpectedMaxInputLength * Config.ExpectedMaxBatches);
+
+            NativeBuffer<NativeBuffer<byte>> allocation;
+
+            if (!allocateNative)
+            {
+                var ptr = Allocator.GetFullAllocationUnsafely().Ptr;
+
+                allocation = new((NativeBuffer<byte>*) ptr, inputLength);
+            }
+
+            else
+            {
+                allocation = new NativeMemory<NativeBuffer<byte>>(inputLength).Buffer;
+            }
+
+            using var freeHandle = IDsToTokens(ids, allocation, performSizeCheck: false);
+
+            ref var currentToken = ref MemoryMarshal.GetReference(tokens);
+
+            foreach (var buffer in allocation)
+            {
+                // In theory, we could intern the tokenizer's vocab and greatly reduce string allocs,
+                // but it is what it is for now...
+                currentToken = Encoding.UTF8.GetString(buffer.Ptr, (int) buffer.Length);
+
+                currentToken = ref Unsafe.Add(ref currentToken, 1);
+            }
+
+            if (allocateNative)
+            {
+                NativeMemory<NativeBuffer<byte>>.FreeWithPtrUnsafely(allocation.Ptr);
+            }
+        }
 
         public void Dispose()
         {
diff --git a/Tokenizers.NET/TokenizerNativeMethods.cs b/Tokenizers.NET/TokenizerNativeMethods.cs
index ece49e3..435fe1c 100644
--- a/Tokenizers.NET/TokenizerNativeMethods.cs
+++ b/Tokenizers.NET/TokenizerNativeMethods.cs
@@ -1,6 +1,7 @@
 using System.Runtime.CompilerServices;
 using System.Runtime.InteropServices;
 using Tokenizers.NET.Collections;
+using Tokenizers.NET.Outputs;
 
 namespace Tokenizers.NET
 {
@@ -118,6 +119,14 @@ public static DecodeOutput TokenizerDecode(
         [LibraryImport(DLL_NAME, EntryPoint = "tokenizer_decode_skip_special_tokens")]
         private static partial DecodeOutput TokenizerDecodeSkipSpecialTokens(nint tokenizerPtr, NativeBuffer<uint> idBuffer);
 
+        [MethodImpl(MethodImplOptions.AggressiveInlining)]
+        [LibraryImport(DLL_NAME, EntryPoint = "ids_to_tokens")]
+        public static partial nint IDsToTokens(
+            nint tokenizerPtr,
+            NativeBuffer<uint> idBuffer,
+            NativeBuffer<NativeBuffer<byte>> tokenBuffer
+        );
+
         [MethodImpl(MethodImplOptions.AggressiveInlining)]
        [LibraryImport(DLL_NAME, EntryPoint = "free_with_handle")]
        public static partial void FreeWithHandle(nint handle);
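
Usage sketch for the new IDsToTokens API introduced above. This is illustrative only and not part of the patch: it assumes an already-configured Tokenizer instance named "tokenizer" (set up the same way as in the tests) and an arbitrary example input string.

    // Tokenize, then map the resulting IDs back to their token strings.
    using var output = tokenizer.Tokenize("Hello world", addSpecialTokens: false);

    // Allocating overload: returns one managed string per ID.
    string[] tokens = tokenizer.IDsToTokens(output.IDs);

    // Caller-provided destination: ThrowHelpers.IDsToTokens_LengthCheckFailed() fires
    // when the destination is shorter than the input, unless performSizeCheck is disabled.
    var destination = new string[(int) output.IDs.Length];
    tokenizer.IDsToTokens(output.IDs, destination);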