Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve materializing input stream in the client #25171

Merged
merged 8 commits into from
Feb 27, 2025
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@
import okhttp3.ResponseBody;

import java.io.IOException;
import java.io.Reader;
import java.io.InputStream;
import java.io.UncheckedIOException;
import java.util.Optional;

Expand Down Expand Up @@ -115,16 +115,16 @@ public static <T> JsonResponse<T> execute(TrinoJsonCodec<T> codec, Call.Factory
if (isJson(responseBody.contentType())) {
T value = null;
IllegalArgumentException exception = null;
MaterializingReader reader = new MaterializingReader(responseBody.charStream(), 128 * 1024);
try (Reader ignored = reader) {
// Parse from input stream, response is either of unknown size or too large to materialize. Raw response body
// will not be available if parsing fails
value = codec.fromJson(reader);
MaterializingInputStream stream = new MaterializingInputStream(responseBody.byteStream(), 8 * 1024);
try (InputStream ignored = stream) {
// Parse from input stream, response is either of unknown size or too large to materialize.
// 8K of the response body will be available if parsing fails.
value = codec.fromJson(stream);
}
catch (JsonProcessingException e) {
exception = new IllegalArgumentException(format("Unable to create %s from JSON response:\n[%s]", codec.getType(), reader.getHeadString()), e);
exception = new IllegalArgumentException(format("Unable to create %s from JSON response:\n[%s]", codec.getType(), stream.getHeadString()), e);
}
return new JsonResponse<>(response.code(), response.headers(), reader.getHeadString(), value, exception);
return new JsonResponse<>(response.code(), response.headers(), stream.getHeadString(), value, exception);
}
return new JsonResponse<>(response.code(), response.headers(), responseBody.string());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,25 +13,26 @@
*/
package io.trino.client;

import java.io.FilterReader;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.Reader;
import java.io.InputStream;

import static com.google.common.base.Verify.verify;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;

class MaterializingReader
extends FilterReader
class MaterializingInputStream
extends FilterInputStream
{
private final char[] headChars;
private final byte[] head;
private int remaining;
private int currentOffset;

protected MaterializingReader(Reader reader, int maxHeadChars)
protected MaterializingInputStream(InputStream stream, int maxBytes)
{
super(reader);
verify(maxHeadChars > 0 && maxHeadChars <= 128 * 1024, "maxHeadChars must be between 1 and 128 KB");
this.headChars = new char[maxHeadChars];
super(stream);
verify(maxBytes > 0 && maxBytes <= 8 * 1024, "maxBytes must be between 1B and 8 KB");
this.head = new byte[maxBytes];
}

@Override
Expand All @@ -40,8 +41,8 @@ public int read()
{
int value = super.read();
if (value != -1) {
if (currentOffset < headChars.length) {
headChars[currentOffset++] = (char) value;
if (currentOffset < head.length) {
head[currentOffset++] = (byte) value;
}
else {
remaining++;
Expand All @@ -51,26 +52,33 @@ public int read()
}

@Override
public int read(char[] cbuf, int off, int len)
public int read(byte[] buffer, int off, int len)
throws IOException
{
int read = super.read(cbuf, off, len);
int read = super.read(buffer, off, len);
if (read > 0) {
int copyLength = Math.min(read, headChars.length - currentOffset);
int copyLength = Math.min(read, head.length - currentOffset);
if (read > copyLength) {
remaining += read - copyLength;
}
if (copyLength > 0) {
System.arraycopy(cbuf, off, headChars, currentOffset, copyLength);
System.arraycopy(buffer, off, head, currentOffset, copyLength);
currentOffset += copyLength;
}
}
return read;
}

@Override
public int read(byte[] buffer)
throws IOException
{
return read(buffer, 0, buffer.length);
}

public String getHeadString()
{
return String.valueOf(headChars, 0, currentOffset) + (remaining > 0 ? format("... [" + bytesOmitted(remaining) + "]", remaining) : "");
return new String(head, 0, currentOffset, UTF_8) + (remaining > 0 ? format("... [" + bytesOmitted(remaining) + "]", remaining) : "");
}

private String bytesOmitted(long bytes)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,25 @@
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.StringWriter;

import static java.nio.charset.StandardCharsets.UTF_8;
import static org.assertj.core.api.Assertions.assertThat;

class TestMaterializingReader
class TestMaterializingInputStream
{
@Test
void testHeadBufferOverflow()
throws IOException
{
InputStream stream = new ByteArrayInputStream("abcd".repeat(1337).getBytes(UTF_8));
MaterializingReader reader = new MaterializingReader(new InputStreamReader(stream, UTF_8), 4);
MaterializingInputStream reader = new MaterializingInputStream(stream, 4);

int remainingBytes = 4 * 1337 - 4;

reader.transferTo(new StringWriter()); // Trigger reading
reader.transferTo(new ByteArrayOutputStream()); // Trigger reading
assertThat(reader.getHeadString())
.isEqualTo("abcd... [" + remainingBytes + " more bytes]");
}
Expand All @@ -45,9 +44,9 @@ void testHeadBufferNotFullyUsed()
throws IOException
{
InputStream stream = new ByteArrayInputStream("abcdabc".getBytes(UTF_8));
MaterializingReader reader = new MaterializingReader(new InputStreamReader(stream, UTF_8), 8);
MaterializingInputStream reader = new MaterializingInputStream(stream, 8);

reader.transferTo(new StringWriter()); // Trigger reading
reader.transferTo(new ByteArrayOutputStream()); // Trigger reading
assertThat(reader.getHeadString()).isEqualTo("abcdabc");
}

Expand All @@ -56,9 +55,9 @@ void testHeadBufferFullyUsed()
throws IOException
{
InputStream stream = new ByteArrayInputStream("a".repeat(8).getBytes(UTF_8));
MaterializingReader reader = new MaterializingReader(new InputStreamReader(stream, UTF_8), 8);
MaterializingInputStream reader = new MaterializingInputStream(stream, 8);

reader.transferTo(new StringWriter()); // Trigger reading
reader.transferTo(new ByteArrayOutputStream()); // Trigger reading
assertThat(reader.getHeadString()).isEqualTo("a".repeat(8));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -120,9 +120,4 @@ private boolean hasSpoolingMetadata(Page page, int outputColumnsSize)
{
return page.getChannelCount() == outputColumnsSize + 1 && page.getPositionCount() == 1 && !page.getBlock(outputColumnsSize).isNull(0);
}

public static QueryDataProducer createSpooledQueryDataProducer(QueryDataEncoder.Factory encoder)
{
return new SpoolingQueryDataProducer(encoder);
}
}