diff --git a/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs b/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs index e2d58d8977..19dad0092d 100644 --- a/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs +++ b/src/Core/Silk.NET.Core.Tests/TestSilkMarshal.cs @@ -1,4 +1,10 @@ +// Licensed to the .NET Foundation under one or more agreements. +// The .NET Foundation licenses this file to you under the MIT license. + +using System; using System.Collections.Generic; +using System.Runtime.InteropServices; +using System.Text; using Silk.NET.Core.Native; using Xunit; @@ -15,6 +21,44 @@ public class TestSilkMarshal NativeStringEncoding.LPWStr, }; + private readonly Encoding lpwStrEncoding = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) + ? Encoding.Unicode + : Encoding.UTF32; + + private readonly int lpwStrCharacterWidth = RuntimeInformation.IsOSPlatform(OSPlatform.Windows) ? 2 : 4; + + [Fact] + public unsafe void TestEncodingToLPWStr() + { + var input = "Hello world 🧵"; + + var expectedByteCount = lpwStrEncoding.GetByteCount(input); + var expected = new byte[expectedByteCount + lpwStrCharacterWidth]; + lpwStrEncoding.GetBytes(input, expected); + + var pointer = SilkMarshal.StringToPtr(input, NativeStringEncoding.LPWStr); + var pointerByteCount = lpwStrCharacterWidth * (int) SilkMarshal.StringLength(pointer, NativeStringEncoding.LPWStr); + + Assert.Equal(expected, new Span((void*)pointer, pointerByteCount + lpwStrCharacterWidth)); + } + + [Fact] + public unsafe void TestEncodingFromLPWStr() + { + var expected = "Hello world 🧵"; + + var inputByteCount = lpwStrEncoding.GetByteCount(expected); + var input = new byte[inputByteCount + lpwStrCharacterWidth]; + lpwStrEncoding.GetBytes(expected, input); + + fixed (byte* pInput = input) + { + var output = SilkMarshal.PtrToString((nint)pInput, NativeStringEncoding.LPWStr); + + Assert.Equal(expected, output); + } + } + [Fact] public void TestEncodingString() { diff --git a/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs b/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs index c759a97408..e1989d4ea6 100644 --- a/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs +++ b/src/Core/Silk.NET.Core/Native/NativeStringEncoding.cs @@ -9,6 +9,9 @@ public enum NativeStringEncoding LPStr = UnmanagedType.LPStr, LPTStr = UnmanagedType.LPTStr, LPUTF8Str = UnmanagedType.LPUTF8Str, + /// + /// On Windows, a null-terminated UTF-16 string. On other platforms, a null-terminated UTF-32 string. + /// LPWStr = UnmanagedType.LPWStr, WinString = UnmanagedType.WinString, Ansi = LPStr, diff --git a/src/Core/Silk.NET.Core/Native/SilkMarshal.cs b/src/Core/Silk.NET.Core/Native/SilkMarshal.cs index 9bebfc1d5e..9d6b9be839 100644 --- a/src/Core/Silk.NET.Core/Native/SilkMarshal.cs +++ b/src/Core/Silk.NET.Core/Native/SilkMarshal.cs @@ -144,7 +144,8 @@ public static int GetMaxSizeOf(string? input, NativeStringEncoding encoding = Na NativeStringEncoding.BStr => -1, NativeStringEncoding.LPStr or NativeStringEncoding.LPTStr or NativeStringEncoding.LPUTF8Str => (input is null ? 0 : Encoding.UTF8.GetMaxByteCount(input.Length)) + 1, - NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 2, + NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows) => ((input?.Length ?? 0) + 1) * 2, + NativeStringEncoding.LPWStr => ((input?.Length ?? 0) + 1) * 4, _ => -1 }; @@ -188,29 +189,38 @@ public static unsafe int StringIntoSpan int convertedBytes; fixed (char* firstChar = input) + fixed (byte* bytes = span) { - fixed (byte* bytes = span) - { - convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1); - } + convertedBytes = Encoding.UTF8.GetBytes(firstChar, input.Length, bytes, span.Length - 1); + bytes[convertedBytes] = 0; } - span[convertedBytes] = 0; - return ++convertedBytes; + return convertedBytes + 1; } - case NativeStringEncoding.LPWStr: + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): { fixed (char* firstChar = input) + fixed (byte* bytes = span) { - fixed (byte* bytes = span) - { - Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2); - ((char*)bytes)[input.Length] = default; - } + Buffer.MemoryCopy(firstChar, bytes, span.Length, input.Length * 2); + ((char*)bytes)[input.Length] = default; } return input.Length + 1; } + case NativeStringEncoding.LPWStr: + { + int convertedBytes; + + fixed (char* firstChar = input) + fixed (byte* bytes = span) + { + convertedBytes = Encoding.UTF32.GetBytes(firstChar, input.Length, bytes, span.Length - 4); + ((uint*)bytes)[convertedBytes / 4] = 0; + } + + return convertedBytes + 4; + } default: { ThrowInvalidEncoding(); @@ -311,7 +321,19 @@ static unsafe string BStrToString(nint ptr) => new string((char*) ptr, 0, (int) (*((uint*) ptr - 1) / sizeof(char))); static unsafe string AnsiToString(nint ptr) => new string((sbyte*) ptr); - static unsafe string WideToString(nint ptr) => new string((char*) ptr); + + static unsafe string WideToString(nint ptr) + { + if (RuntimeInformation.IsOSPlatform(OSPlatform.Windows)) + { + return new string((char*) ptr); + } + else + { + var length = StringLength(ptr, NativeStringEncoding.LPWStr); + return Encoding.UTF32.GetString((byte*) ptr, 4 * (int) length); + } + }; } /// @@ -524,15 +546,41 @@ Func customUnmarshaller /// #if NET6_0_OR_GREATER [MethodImpl(MethodImplOptions.AggressiveInlining)] - public static unsafe nuint StringLength( + public static unsafe nuint StringLength + ( nint ptr, NativeStringEncoding encoding = NativeStringEncoding.Ansi - ) => - (nuint)( - encoding == NativeStringEncoding.LPWStr - ? MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length - : MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length - ); + ) + { + switch (encoding) + { + default: + { + return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((byte*)ptr).Length; + } + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + return (nuint)MemoryMarshal.CreateReadOnlySpanFromNullTerminated((char*)ptr).Length; + } + case NativeStringEncoding.LPWStr: + { + // No int overload for CreateReadOnlySpanFromNullTerminated + if (ptr == 0) + { + return 0; + } + + nuint length = 0; + while (((uint*) ptr)![length] != 0) + { + length++; + } + + return length; + } + } + } + #else public static unsafe nuint StringLength( nint ptr, @@ -543,15 +591,40 @@ public static unsafe nuint StringLength( { return 0; } - nuint ret; - for ( - ret = 0; - encoding == NativeStringEncoding.LPWStr - ? ((char*)ptr)![ret] != 0 - : ((byte*)ptr)![ret] != 0; - ret++ - ) { } - return ret; + + nuint length = 0; + switch (encoding) + { + default: + { + while (((byte*) ptr)![length] != 0) + { + length++; + } + + break; + } + case NativeStringEncoding.LPWStr when RuntimeInformation.IsOSPlatform(OSPlatform.Windows): + { + while (((char*) ptr)![length] != 0) + { + length++; + } + + break; + } + case NativeStringEncoding.LPWStr: + { + while (((uint*) ptr)![length] != 0) + { + length++; + } + + break; + } + } + + return length; } #endif