Skip to content

Commit

Permalink
Switch implementation to avoid throwing at start of stream if no data…
Browse files Browse the repository at this point in the history
… can be read
  • Loading branch information
voidstar69 committed Sep 19, 2024
1 parent 947d227 commit d5215ec
Showing 1 changed file with 22 additions and 13 deletions.
35 changes: 22 additions & 13 deletions csharp/src/Apache.Arrow/Ipc/ArrowStreamReaderImplementation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -55,9 +55,6 @@ public override RecordBatch ReadNextRecordBatch()

protected async ValueTask<RecordBatch> ReadRecordBatchAsync(CancellationToken cancellationToken = default)
{
if (BaseStream.Length == 0)
return null;

await ReadSchemaAsync().ConfigureAwait(false);

ReadResult result = default;
Expand All @@ -71,7 +68,7 @@ protected async ValueTask<RecordBatch> ReadRecordBatchAsync(CancellationToken ca

protected async ValueTask<ReadResult> ReadMessageAsync(CancellationToken cancellationToken)
{
int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, cancellationToken)
int messageLength = await ReadMessageLengthAsync(throwOnFullRead: false, returnOnEmptyStream: false, cancellationToken)
.ConfigureAwait(false);

if (messageLength == 0)
Expand Down Expand Up @@ -106,9 +103,6 @@ protected async ValueTask<ReadResult> ReadMessageAsync(CancellationToken cancell

protected RecordBatch ReadRecordBatch()
{
if (BaseStream.Length == 0)
return null;

ReadSchema();

ReadResult result = default;
Expand All @@ -122,8 +116,7 @@ protected RecordBatch ReadRecordBatch()

protected ReadResult ReadMessage()
{
int messageLength = ReadMessageLength(throwOnFullRead: false);

int messageLength = ReadMessageLength(throwOnFullRead: false, returnOnEmptyStream: false);
if (messageLength == 0)
{
// reached end
Expand Down Expand Up @@ -166,8 +159,12 @@ public override async ValueTask ReadSchemaAsync(CancellationToken cancellationTo
}

// Figure out length of schema
int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, cancellationToken)
int schemaMessageLength = await ReadMessageLengthAsync(throwOnFullRead: true, returnOnEmptyStream: true, cancellationToken)
.ConfigureAwait(false);
if (schemaMessageLength == 0)
{
return;
}

using (ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, out Memory<byte> buff))
{
Expand All @@ -188,7 +185,11 @@ public override void ReadSchema()
}

// Figure out length of schema
int schemaMessageLength = ReadMessageLength(throwOnFullRead: true);
int schemaMessageLength = ReadMessageLength(throwOnFullRead: true, returnOnEmptyStream: true);
if(schemaMessageLength == 0)
{
return;
}

using (ArrayPool<byte>.Shared.RentReturn(schemaMessageLength, out Memory<byte> buff))
{
Expand All @@ -200,13 +201,17 @@ public override void ReadSchema()
}
}

private async ValueTask<int> ReadMessageLengthAsync(bool throwOnFullRead, CancellationToken cancellationToken = default)
private async ValueTask<int> ReadMessageLengthAsync(bool throwOnFullRead, bool returnOnEmptyStream, CancellationToken cancellationToken = default)
{
int messageLength = 0;
using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> lengthBuffer))
{
int bytesRead = await BaseStream.ReadFullBufferAsync(lengthBuffer, cancellationToken)
.ConfigureAwait(false);
if (bytesRead == 0)
{
return 0;
}
if (throwOnFullRead)
{
EnsureFullRead(lengthBuffer, bytesRead);
Expand Down Expand Up @@ -239,12 +244,16 @@ private async ValueTask<int> ReadMessageLengthAsync(bool throwOnFullRead, Cancel
return messageLength;
}

private int ReadMessageLength(bool throwOnFullRead)
private int ReadMessageLength(bool throwOnFullRead, bool returnOnEmptyStream)
{
int messageLength = 0;
using (ArrayPool<byte>.Shared.RentReturn(4, out Memory<byte> lengthBuffer))
{
int bytesRead = BaseStream.ReadFullBuffer(lengthBuffer);
if (bytesRead == 0 && returnOnEmptyStream)
{
return 0;
}
if (throwOnFullRead)
{
EnsureFullRead(lengthBuffer, bytesRead);
Expand Down

0 comments on commit d5215ec

Please sign in to comment.