Skip to content

Commit

Permalink
Revert "fix: addressing reviews"
Browse files Browse the repository at this point in the history
This reverts commit 4eb1836.
  • Loading branch information
vibhatha committed Dec 14, 2023
1 parent 4eb1836 commit 551263a
Show file tree
Hide file tree
Showing 5 changed files with 76 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,14 +28,14 @@
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;

import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.memory.RootAllocator;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.GenerateSampleData;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.compression.NoCompressionCodec;
import org.apache.arrow.vector.dictionary.Dictionary;
Expand All @@ -56,6 +56,7 @@
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.junit.jupiter.api.Disabled;

public class TestArrowReaderWriterWithCompression {

Expand Down Expand Up @@ -93,10 +94,9 @@ private void createAndWriteArrowFile(DictionaryProvider provider,
final int rowCount = 10;
GenerateSampleData.generateTestData(root.getVector(0), rowCount);
root.setRowCount(rowCount);
CompressionCodec codec = CommonsCompressionFactory.INSTANCE.createCodec(codecType,
/*compressionLevel=*/7);

try (final ArrowFileWriter writer = new ArrowFileWriter(root, provider, Channels.newChannel(out),
new HashMap<>(), IpcOption.DEFAULT, codec)) {
new HashMap<>(), IpcOption.DEFAULT, CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
Expand Down Expand Up @@ -137,10 +137,10 @@ private List<Field> createFields(Dictionary dictionary, BufferAllocator allocato
private File writeArrowStream(VectorSchemaRoot root, DictionaryProvider provider,
CompressionUtil.CodecType codecType) throws IOException {
File tempFile = File.createTempFile("dictionary_compression", ".arrow");
CompressionCodec codec = CommonsCompressionFactory.INSTANCE.createCodec(codecType, 7);
try (FileOutputStream fileOut = new FileOutputStream(tempFile);
ArrowStreamWriter writer = new ArrowStreamWriter(root, provider,
Channels.newChannel(fileOut), IpcOption.DEFAULT, codec)) {
Channels.newChannel(fileOut), IpcOption.DEFAULT,
CommonsCompressionFactory.INSTANCE, codecType, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
Expand All @@ -149,6 +149,7 @@ private File writeArrowStream(VectorSchemaRoot root, DictionaryProvider provider
}

@Test
@Disabled
public void testArrowFileZstdRoundTrip() throws Exception {
createAndWriteArrowFile(null, CompressionUtil.CodecType.ZSTD);
// with compression
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@
import java.util.Collection;
import java.util.HashMap;
import java.util.List;
import java.util.Optional;
import java.util.function.BiConsumer;
import java.util.stream.Stream;

Expand Down Expand Up @@ -232,13 +233,13 @@ private static Stream<CompressionUtil.CodecType> codecTypes() {

@ParameterizedTest
@MethodSource("codecTypes")
void testReadWriteStream(CompressionUtil.CodecType codecType) throws Exception {
withRoot(codecType, (factory, root) -> {
void testReadWriteStream(CompressionUtil.CodecType codec) throws Exception {
withRoot(codec, (factory, root) -> {
ByteArrayOutputStream compressedStream = new ByteArrayOutputStream();
CompressionCodec codec = factory.createCodec(codecType, /*compressionLevel=*/7);
try (final ArrowStreamWriter writer = new ArrowStreamWriter(
root, new DictionaryProvider.MapDictionaryProvider(),
Channels.newChannel(compressedStream), IpcOption.DEFAULT, codec)) {
Channels.newChannel(compressedStream),
IpcOption.DEFAULT, factory, codec, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
Expand All @@ -259,14 +260,13 @@ void testReadWriteStream(CompressionUtil.CodecType codecType) throws Exception {

@ParameterizedTest
@MethodSource("codecTypes")
void testReadWriteFile(CompressionUtil.CodecType codecType) throws Exception {
withRoot(codecType, (factory, root) -> {
void testReadWriteFile(CompressionUtil.CodecType codec) throws Exception {
withRoot(codec, (factory, root) -> {
ByteArrayOutputStream compressedStream = new ByteArrayOutputStream();
CompressionCodec codec = factory.createCodec(codecType, /*compressionLevel=*/7);
try (final ArrowFileWriter writer = new ArrowFileWriter(
root, new DictionaryProvider.MapDictionaryProvider(),
Channels.newChannel(compressedStream),
new HashMap<>(), IpcOption.DEFAULT, codec)) {
new HashMap<>(), IpcOption.DEFAULT, factory, codec, Optional.of(7))) {
writer.start();
writer.writeBatch();
writer.end();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,11 +22,13 @@
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.ArrowBlock;
Expand Down Expand Up @@ -74,8 +76,15 @@ public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, Writa
}

public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
Map<String, String> metaData, IpcOption option, CompressionCodec codec) {
super(root, provider, out, option, codec);
Map<String, String> metaData, IpcOption option, CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType) {
this(root, provider, out, metaData, option, compressionFactory, codecType, Optional.empty());
}

public ArrowFileWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
Map<String, String> metaData, IpcOption option, CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType, Optional<Integer> compressionLevel) {
super(root, provider, out, option, compressionFactory, codecType, compressionLevel);
this.metaData = metaData;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,15 @@
import java.nio.channels.WritableByteChannel;
import java.util.HashMap;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.vector.FieldVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.compression.CompressionCodec;
import org.apache.arrow.vector.compression.CompressionUtil;
import org.apache.arrow.vector.dictionary.Dictionary;
import org.apache.arrow.vector.dictionary.DictionaryProvider;
import org.apache.arrow.vector.ipc.message.IpcOption;
Expand Down Expand Up @@ -80,13 +82,33 @@ public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, Wri
* @param root Existing VectorSchemaRoot with vectors to be written.
* @param provider DictionaryProvider for any vectors that are dictionary encoded.
* (Optional, can be null)
* @param option IPC write options
* @param compressionFactory Compression codec factory
* @param codecType Codec type
* @param out WritableByteChannel for writing.
*/
public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
IpcOption option, CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType) {
this(root, provider, out, option, compressionFactory, codecType, Optional.empty());
}

/**
* Construct an ArrowStreamWriter with compression enabled.
*
* @param root Existing VectorSchemaRoot with vectors to be written.
* @param provider DictionaryProvider for any vectors that are dictionary encoded.
* (Optional, can be null)
* @param option IPC write options
* @param codec Compression codec
* @param compressionFactory Compression codec factory
* @param codecType Codec type
* @param compressionLevel Compression level
* @param out WritableByteChannel for writing.
*/
public ArrowStreamWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
IpcOption option, CompressionCodec codec) {
super(root, provider, out, option, codec);
IpcOption option, CompressionCodec.Factory compressionFactory,
CompressionUtil.CodecType codecType, Optional<Integer> compressionLevel) {
super(root, provider, out, option, compressionFactory, codecType, compressionLevel);
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;

import org.apache.arrow.vector.FieldVector;
Expand Down Expand Up @@ -59,9 +60,12 @@ public abstract class ArrowWriter implements AutoCloseable {
private final VectorUnloader unloader;
private final DictionaryProvider dictionaryProvider;
private final Set<Long> dictionaryIdsUsed = new HashSet<>();

private final CompressionCodec.Factory compressionFactory;
private final CompressionUtil.CodecType codecType;
private final Optional<Integer> compressionLevel;
private boolean started = false;
private boolean ended = false;
private final CompressionCodec codec;

protected IpcOption option;

Expand All @@ -70,8 +74,8 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
}

protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option) {
this(root, provider, out, option, NoCompressionCodec.Factory.INSTANCE.createCodec(
CompressionUtil.CodecType.NO_COMPRESSION));
this(root, provider, out, option, NoCompressionCodec.Factory.INSTANCE, CompressionUtil.CodecType.NO_COMPRESSION,
Optional.empty());
}

/**
Expand All @@ -81,16 +85,22 @@ protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, Writab
* @param provider where to find the dictionaries
* @param out the output where to write
* @param option IPC write options
* @param codec the compression codec
* @param compressionFactory Compression codec factory
* @param codecType Compression codec
* @param compressionLevel Compression level
*/
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out,
IpcOption option, CompressionCodec codec) {
protected ArrowWriter(VectorSchemaRoot root, DictionaryProvider provider, WritableByteChannel out, IpcOption option,
CompressionCodec.Factory compressionFactory, CompressionUtil.CodecType codecType,
Optional<Integer> compressionLevel) {
this.out = new WriteChannel(out);
this.option = option;
this.dictionaryProvider = provider;
this.codec = codec;

this.compressionFactory = compressionFactory;
this.codecType = codecType;
this.compressionLevel = compressionLevel;
this.unloader = new VectorUnloader(root, /*includeNullCount*/ true,
codec, /*alignBuffers*/ true);
getCodec(), /*alignBuffers*/ true);

List<Field> fields = new ArrayList<>(root.getSchema().getFields().size());

Expand Down Expand Up @@ -126,7 +136,7 @@ protected void writeDictionaryBatch(Dictionary dictionary) throws IOException {
Collections.singletonList(vector.getField()),
Collections.singletonList(vector),
count);
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, codec,
VectorUnloader unloader = new VectorUnloader(dictRoot, /*includeNullCount*/ true, getCodec(),
/*alignBuffers*/ true);
ArrowRecordBatch batch = unloader.getRecordBatch();
ArrowDictionaryBatch dictionaryBatch = new ArrowDictionaryBatch(id, batch, false);
Expand Down Expand Up @@ -168,6 +178,12 @@ public long bytesWritten() {
return out.getCurrentPosition();
}

private CompressionCodec getCodec() {
return this.compressionLevel.isPresent() ?
this.compressionFactory.createCodec(this.codecType, this.compressionLevel.get()) :
this.compressionFactory.createCodec(this.codecType);
}

private void ensureStarted() throws IOException {
if (!started) {
started = true;
Expand Down

0 comments on commit 551263a

Please sign in to comment.