From 31249167b82b8ea87dc185e955b3af2da6392b62 Mon Sep 17 00:00:00 2001 From: Peter Song Date: Thu, 21 Nov 2024 00:19:22 -0800 Subject: [PATCH] Fix bug in StreamingUtf8JsonReader, add tests --- .../Transform/StreamingUtf8JsonReader.cs | 25 ++++-- .../StreamingUtf8JsonReaderTests.cs | 82 +++++++++++++++++++ 2 files changed, 102 insertions(+), 5 deletions(-) create mode 100644 sdk/test/UnitTests/Custom/Marshalling/StreamingUtf8JsonReaderTests.cs diff --git a/sdk/src/Core/Amazon.Runtime/Internal/Transform/StreamingUtf8JsonReader.cs b/sdk/src/Core/Amazon.Runtime/Internal/Transform/StreamingUtf8JsonReader.cs index 8e490554a1ad..5a21326b2a91 100644 --- a/sdk/src/Core/Amazon.Runtime/Internal/Transform/StreamingUtf8JsonReader.cs +++ b/sdk/src/Core/Amazon.Runtime/Internal/Transform/StreamingUtf8JsonReader.cs @@ -41,12 +41,15 @@ public Utf8JsonReader Reader } private Stream _stream; + // due to all the resizing that happens with the byte array, pooling is not used here because the + // byte array returned by Array.Resize is not owned by the pool. private byte[] _buffer; public StreamingUtf8JsonReader(Stream stream) { _stream = stream; - _buffer = ArrayPool.Shared.Rent(4096); + // the default for STJ's deserializer is 16,384 but we'll leave this at 4096 for now. + _buffer = new byte[4096]; // need to initialize the reader even if the buffer is empty because auto-default of unassigned fields is only // supported in C# 11+ _reader = new Utf8JsonReader(_buffer); @@ -95,10 +98,15 @@ public void PassReaderByRef(RefAction action) /// true if there is more data, false otherwise. 
public bool Read() { + // hasMoreData can return false if the value starts in one buffer and leaks into the next buffer bool hasMoreData = _reader.Read(); - if (!hasMoreData) + + while (!hasMoreData) { GetMoreBytesFromStream(_stream, ref _buffer, ref _reader); + if (_reader.IsFinalBlock) + break; + hasMoreData = _reader.Read(); } @@ -109,24 +117,31 @@ public bool Read() private static void GetMoreBytesFromStream(Stream stream, ref byte[] buffer, ref Utf8JsonReader reader) { int bytesRead; + bool resized = false; + // if Read() returned false and we are here that means that we couldn't fully parse the JSON token + // because it was too large to fit in the remainder of the buffer. if (reader.BytesConsumed < buffer.Length) { - ReadOnlySpan leftover = buffer.AsSpan((int)reader.BytesConsumed); + ReadOnlySpan leftover = buffer.AsSpan().Slice((int)reader.BytesConsumed); + int previousBufferLength = buffer.Length; if (leftover.Length == buffer.Length) { + resized = true; Array.Resize(ref buffer, buffer.Length * 2); } leftover.CopyTo(buffer); bytesRead = stream.Read(buffer, leftover.Length, buffer.Length - leftover.Length); + // remove null bytes if they exist, since we don't know when the stream will end and we could have doubled the buffer size. + // otherwise the json reader will throw an exception. 
+ if (resized) + Array.Resize(ref buffer, bytesRead + previousBufferLength); } else { bytesRead = stream.Read(buffer, 0, buffer.Length); } - if (bytesRead == 0) - ArrayPool.Shared.Return(buffer); reader = new Utf8JsonReader(buffer, isFinalBlock: bytesRead == 0, reader.CurrentState); } diff --git a/sdk/test/UnitTests/Custom/Marshalling/StreamingUtf8JsonReaderTests.cs b/sdk/test/UnitTests/Custom/Marshalling/StreamingUtf8JsonReaderTests.cs new file mode 100644 index 000000000000..ac734eb04f31 --- /dev/null +++ b/sdk/test/UnitTests/Custom/Marshalling/StreamingUtf8JsonReaderTests.cs @@ -0,0 +1,82 @@ +using AWSSDK_DotNet.CommonTest.Utils; +using AWSSDK_DotNet.UnitTests; +using Microsoft.VisualStudio.TestTools.UnitTesting; +using System; +using System.IO; +using System.Collections.Generic; +using System.Linq; +using System.Text; +using System.Text.Json; +using System.Threading.Tasks; +using Amazon.Runtime.Internal.Transform; +namespace AWSSDK.UnitTests +{ + /// + /// Protocol Tests already exist to test the marshalling and unmarshalling of requests and responses in json, but they don't test very + /// large payloads, which would trigger the buffer-refill logic in GetMoreBytesFromStream. + /// This class just tests the wrapper class StreamingUtf8JsonReader. + /// + [TestClass] + public class StreamingUtf8JsonReaderTests + { + [TestMethod] + public void HandlesUtf8BOM() + { + // we can't use reflection to access the private fields of StreamingUtf8JsonReader since it is a ref struct so we have to test it this way. 
+ var a = Convert.ToByte('{'); + var b = Convert.ToByte('x'); + var c = Convert.ToByte(':'); + var d = Convert.ToByte('y'); + var e = Convert.ToByte('}'); + + byte[] payload = new byte[] { 0xEF, 0xBB, 0xBF, a, b, c ,d, e}; + MemoryStream stream = new MemoryStream(payload); + StreamingUtf8JsonReader reader = new StreamingUtf8JsonReader(stream); + bool firstIteration = true; + while (reader.Read()) + { + if (firstIteration) + { + // make sure the BOM was removed + Assert.IsTrue(reader.Reader.TokenType == JsonTokenType.StartObject); + firstIteration = false; + return; + } + } + } + // This method tests that if the json token starts in one buffer but continues into the next buffer + // the reader can handle parsing it correctly. + [TestMethod] + public void Utf8JsonReaderHandlesJsonTokenThatSpansMultipleBuffers() + { + // Arrange + // here we're creating a json string that is greater than 4096 bytes to test the GetMoreBytesFromStream logic + var sb = new StringBuilder(); + sb.Append("{ \"key\": \""); + sb.Append(new string('x', 7500)); // String with 7500 'x' characters + sb.Append("\" }"); + string largeJson = sb.ToString(); + + byte[] payload = Encoding.UTF8.GetBytes(largeJson); + using (var stream = new MemoryStream(payload)) + { + var reader = new StreamingUtf8JsonReader(stream); + string key = null, value = null; + + while (reader.Read()) + { + if (reader.Reader.TokenType == JsonTokenType.PropertyName) + { + key = reader.Reader.GetString(); + } + else if (reader.Reader.TokenType == JsonTokenType.String) + { + value = reader.Reader.GetString(); + } + } + Assert.AreEqual("key", key); + Assert.AreEqual(new string('x', 7500), value); + } + } + } +}