Skip to content

Commit

Permalink
apacheGH-40942: [Java] Implement C Data Interface for StringView (apa…
Browse files Browse the repository at this point in the history
…che#41967)

### Rationale for this change

The recent addition of `Utf8View` and `BinaryView` support to Java also requires C Data Interface support for these types so that they can be exchanged with other systems.

### What changes are included in this PR?

- [X] Adding core functionality for C Data interface for `Utf8View` and `BinaryView`
- [X] Adding `RoundtripTest`
- [X] Adding `StreamingTest`

### Are these changes tested?

Yes, with new tests. 

### Are there any user-facing changes?

No
* GitHub Issue: apache#40942

Authored-by: Vibhatha Abeykoon <[email protected]>
Signed-off-by: David Li <[email protected]>
  • Loading branch information
vibhatha authored Jun 21, 2024
1 parent 27bbc3c commit 3711657
Show file tree
Hide file tree
Showing 21 changed files with 931 additions and 54 deletions.
1 change: 0 additions & 1 deletion dev/archery/archery/integration/datagen.py
Original file line number Diff line number Diff line change
Expand Up @@ -1932,7 +1932,6 @@ def _temp_path():
.skip_tester('Rust'),

generate_binary_view_case()
.skip_tester('Java')
.skip_tester('JS')
.skip_tester('nanoarrow')
.skip_tester('Rust'),
Expand Down
10 changes: 4 additions & 6 deletions java/c/src/main/java/org/apache/arrow/c/ArrayExporter.java
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,6 @@ public void close() {

void export(ArrowArray array, FieldVector vector, DictionaryProvider dictionaryProvider) {
List<FieldVector> children = vector.getChildrenFromFields();
List<ArrowBuf> buffers = vector.getFieldBuffers();
int valueCount = vector.getValueCount();
int nullCount = vector.getNullCount();
DictionaryEncoding dictionaryEncoding = vector.getField().getDictionary();
Expand All @@ -89,11 +88,10 @@ void export(ArrowArray array, FieldVector vector, DictionaryProvider dictionaryP
}
}

if (buffers != null) {
data.buffers = new ArrayList<>(buffers.size());
data.buffers_ptrs = allocator.buffer((long) buffers.size() * Long.BYTES);
vector.exportCDataBuffers(data.buffers, data.buffers_ptrs, NULL);
}
data.buffers = new ArrayList<>(vector.getExportedCDataBufferCount());
data.buffers_ptrs =
allocator.buffer((long) (vector.getExportedCDataBufferCount()) * Long.BYTES);
vector.exportCDataBuffers(data.buffers, data.buffers_ptrs, NULL);

if (dictionaryEncoding != null) {
Dictionary dictionary = dictionaryProvider.lookup(dictionaryEncoding.getId());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import org.apache.arrow.memory.BufferAllocator;
import org.apache.arrow.util.AutoCloseables;
import org.apache.arrow.util.VisibleForTesting;
import org.apache.arrow.vector.BaseVariableWidthViewVector;
import org.apache.arrow.vector.DateDayVector;
import org.apache.arrow.vector.DateMilliVector;
import org.apache.arrow.vector.DurationVector;
Expand All @@ -51,7 +52,6 @@
import org.apache.arrow.vector.complex.UnionVector;
import org.apache.arrow.vector.ipc.message.ArrowFieldNode;
import org.apache.arrow.vector.types.pojo.ArrowType;
import org.apache.arrow.vector.types.pojo.ArrowType.ListView;
import org.apache.arrow.vector.util.DataSizeRoundingUtil;

/** Import buffers from a C Data Interface struct. */
Expand Down Expand Up @@ -227,10 +227,37 @@ public List<ArrowBuf> visit(ArrowType.Utf8 type) {
}
}

/**
 * Imports the buffers backing a variable-width view type.
 *
 * <p>The incoming buffers are read using this layout: index 0 is the validity buffer, index 1 is
 * the view buffer, indices 2 through N-2 are the variadic data buffers, and the last index (N-1)
 * holds one 64-bit size per variadic data buffer.
 *
 * @param type the view type whose buffers are being imported
 * @return the validity buffer, the view buffer, then every variadic data buffer in order
 */
private List<ArrowBuf> visitVariableWidthView(ArrowType type) {
  final int viewIndex = 1;
  final int firstDataIndex = 2;
  final int sizesIndex = this.buffers.length - 1;
  // Everything except validity, views and the trailing sizes buffer is a variadic data buffer.
  final long dataBufferCount = this.buffers.length - 3;

  // The sizes buffer is imported first because each data buffer's length is read from it.
  final ArrowBuf sizes = importBuffer(type, sizesIndex, dataBufferCount * Long.BYTES);
  final ArrowBuf views =
      importFixedBytes(type, viewIndex, BaseVariableWidthViewVector.ELEMENT_SIZE);

  final List<ArrowBuf> imported = new ArrayList<>();
  imported.add(maybeImportBitmap(type));
  imported.add(views);

  for (int bufferIndex = firstDataIndex; bufferIndex < sizesIndex; bufferIndex++) {
    // Slot k of the sizes buffer describes the data buffer at index k + firstDataIndex.
    final long sizeOffset = (long) (bufferIndex - firstDataIndex) * Long.BYTES;
    imported.add(importBuffer(type, bufferIndex, sizes.getLong(sizeOffset)));
  }

  return imported;
}

@Override
public List<ArrowBuf> visit(ArrowType.Utf8View type) {
throw new UnsupportedOperationException(
"Importing buffers for view type: " + type + " not supported");
return visitVariableWidthView(type);
}

@Override
Expand Down Expand Up @@ -270,8 +297,7 @@ public List<ArrowBuf> visit(ArrowType.Binary type) {

@Override
public List<ArrowBuf> visit(ArrowType.BinaryView type) {
throw new UnsupportedOperationException(
"Importing buffers for view type: " + type + " not supported");
return visitVariableWidthView(type);
}

@Override
Expand Down Expand Up @@ -373,7 +399,7 @@ public List<ArrowBuf> visit(ArrowType.Duration type) {
}

@Override
public List<ArrowBuf> visit(ListView type) {
public List<ArrowBuf> visit(ArrowType.ListView type) {
throw new UnsupportedOperationException(
"Importing buffers for view type: " + type + " not supported");
}
Expand Down
8 changes: 8 additions & 0 deletions java/c/src/main/java/org/apache/arrow/c/Format.java
Original file line number Diff line number Diff line change
Expand Up @@ -225,6 +225,10 @@ static String asString(ArrowType arrowType) {
}
case Utf8:
return "u";
case Utf8View:
return "vu";
case BinaryView:
return "vz";
case NONE:
throw new IllegalArgumentException("Arrow type ID is NONE");
default:
Expand Down Expand Up @@ -305,6 +309,10 @@ static ArrowType asType(String format, long flags)
case "+m":
boolean keysSorted = (flags & Flags.ARROW_FLAG_MAP_KEYS_SORTED) != 0;
return new ArrowType.Map(keysSorted);
case "vu":
return new ArrowType.Utf8View();
case "vz":
return new ArrowType.BinaryView();
default:
String[] parts = format.split(":", 2);
if (parts.length == 2) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ private void appendNodes(
int expectedBufferCount =
(int) (TypeLayout.getTypeBufferCount(vector.getField().getType()) + variadicBufferCount);
// only update variadicBufferCounts for vectors that have variadic buffers
if (variadicBufferCount > 0) {
if (vector instanceof BaseVariableWidthViewVector) {
variadicBufferCounts.add(variadicBufferCount);
}
if (fieldBuffers.size() != expectedBufferCount) {
Expand Down
56 changes: 56 additions & 0 deletions java/c/src/test/java/org/apache/arrow/c/DictionaryTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -260,6 +260,28 @@ private void createStructVector(StructVector vector) {
vector.setValueCount(2);
}

/**
 * Fills {@code vector} with two rows spread over two nullable children: a view-varchar child
 * {@code "f0"} and an int child {@code "f1"}.
 *
 * <p>The string payloads are encoded explicitly as UTF-8 so the bytes written do not depend on
 * the platform default charset.
 *
 * @param vector the struct vector to populate; its value count is set to 2
 */
private void createStructVectorInline(StructVector vector) {
  final ViewVarCharVector child1 =
      vector.addOrGet(
          "f0", FieldType.nullable(MinorType.VIEWVARCHAR.getType()), ViewVarCharVector.class);
  final IntVector child2 =
      vector.addOrGet("f1", FieldType.nullable(MinorType.INT.getType()), IntVector.class);

  // Write the values to child 1
  child1.allocateNew();
  child1.set(0, "012345678".getBytes(java.nio.charset.StandardCharsets.UTF_8));
  child1.set(1, "01234".getBytes(java.nio.charset.StandardCharsets.UTF_8));
  vector.setIndexDefined(0);

  // Write the values to child 2
  child2.allocateNew();
  child2.set(0, 10);
  child2.set(1, 11);
  vector.setIndexDefined(1);

  vector.setValueCount(2);
}

@Test
public void testVectorLoadUnloadOnStructVector() {
try (final StructVector structVector1 = StructVector.empty("struct", allocator)) {
Expand Down Expand Up @@ -293,4 +315,38 @@ public void testVectorLoadUnloadOnStructVector() {
}
}
}

/**
 * Unloads a struct vector built by {@code createStructVectorInline} into a record batch, checks
 * the reported variadic buffer counts, loads the batch back and compares the children.
 */
@Test
public void testVectorLoadUnloadOnStructVectorWithInline() {
  try (final StructVector source = StructVector.empty("struct", allocator)) {
    createStructVectorInline(source);
    final Field structField = source.getField();
    final Schema childSchema = new Schema(structField.getChildren());
    final StructVectorUnloader unloader = new StructVectorUnloader(source);

    try (ArrowRecordBatch batch = unloader.getRecordBatch()) {
      try (BufferAllocator loadAllocator =
          allocator.newChildAllocator("struct", 0, Long.MAX_VALUE)) {
        // The record batch must report exactly one variadic buffer count, and it must be 0.
        assertFalse(batch.getVariadicBufferCounts().isEmpty());
        assertEquals(1, batch.getVariadicBufferCounts().size());
        assertEquals(0, batch.getVariadicBufferCounts().get(0));

        final StructVectorLoader loader = new StructVectorLoader(childSchema);
        try (StructVector loaded = loader.load(loadAllocator, batch)) {
          // TODO: compare the two struct vectors directly once
          // https://github.com/apache/arrow/issues/41933 is fixed; until then the children
          // are compared one by one.
          for (String childName : new String[] {"f0", "f1"}) {
            assertTrue(
                VectorEqualsVisitor.vectorEquals(
                    source.getChild(childName), loaded.getChild(childName)),
                "vectors are not equivalent");
          }
        }
      }
    }
  }
}
}
75 changes: 75 additions & 0 deletions java/c/src/test/java/org/apache/arrow/c/RoundtripTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,8 @@
import org.apache.arrow.vector.VarBinaryVector;
import org.apache.arrow.vector.VarCharVector;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.ViewVarBinaryVector;
import org.apache.arrow.vector.ViewVarCharVector;
import org.apache.arrow.vector.ZeroVector;
import org.apache.arrow.vector.compare.VectorEqualsVisitor;
import org.apache.arrow.vector.complex.FixedSizeListVector;
Expand Down Expand Up @@ -524,6 +526,79 @@ public void testVarBinaryVector() {
}
}

/**
 * Concatenates {@code str} with itself {@code repetition} times.
 *
 * @param str the text to repeat
 * @param repetition how many copies to append; zero or a negative value yields an empty string
 * @return the repeated text
 */
private String generateString(String str, int repetition) {
  final StringBuilder repeated = new StringBuilder();
  int remaining = repetition;
  while (remaining > 0) {
    repeated.append(str);
    remaining--;
  }
  return repeated.toString();
}

/**
 * Round-trips view vectors (both {@code ViewVarCharVector} and {@code ViewVarBinaryVector}) that
 * hold short values, longer values, and a large set of values of growing length.
 */
@Test
public void testViewVector() {
  final byte[] shortFirst = "abc".getBytes(StandardCharsets.UTF_8);
  final byte[] shortSecond = "def".getBytes(StandardCharsets.UTF_8);
  final byte[] longFirst = "01234567890123".getBytes(StandardCharsets.UTF_8);
  final byte[] longSecond = "01234567890123567".getBytes(StandardCharsets.UTF_8);

  // ViewVarCharVector with short strings
  try (final ViewVarCharVector shortChars = new ViewVarCharVector("v1", allocator)) {
    setVector(shortChars, shortFirst, shortSecond, null);
    assertTrue(roundtrip(shortChars, ViewVarCharVector.class));
  }

  // ViewVarCharVector with long strings
  try (final ViewVarCharVector longChars = new ViewVarCharVector("v2", allocator)) {
    setVector(longChars, longFirst, longSecond, null);
    assertTrue(roundtrip(longChars, ViewVarCharVector.class));
  }

  // ViewVarBinaryVector with short values
  try (final ViewVarBinaryVector shortBinary = new ViewVarBinaryVector("v3", allocator)) {
    setVector(shortBinary, shortFirst, shortSecond, null);
    assertTrue(roundtrip(shortBinary, ViewVarBinaryVector.class));
  }

  // ViewVarBinaryVector with long values
  try (final ViewVarBinaryVector longBinary = new ViewVarBinaryVector("v4", allocator)) {
    setVector(longBinary, longFirst, longSecond, null);
    assertTrue(roundtrip(longBinary, ViewVarBinaryVector.class));
  }

  // 500 values of growing length: the k-th value is the decimal digits of 0, 1, ..., k-1
  // concatenated ("0", "01", "012", ...). Each value extends the previous one by one number.
  final List<byte[]> growingValues = new ArrayList<>();
  final StringBuilder digitRun = new StringBuilder();
  for (int next = 0; next < 500; next++) {
    digitRun.append(next);
    growingValues.add(digitRun.toString().getBytes(StandardCharsets.UTF_8));
  }

  // ViewVarCharVector with both short and long strings, spanning multiple data buffers
  try (final ViewVarCharVector mixedChars = new ViewVarCharVector("v5", allocator)) {
    setVector(mixedChars, growingValues.toArray(new byte[0][]));
    assertTrue(roundtrip(mixedChars, ViewVarCharVector.class));
  }

  // ViewVarBinaryVector with both short and long values, spanning multiple data buffers
  try (final ViewVarBinaryVector mixedBinary = new ViewVarBinaryVector("v6", allocator)) {
    setVector(mixedBinary, growingValues.toArray(new byte[0][]));
    assertTrue(roundtrip(mixedBinary, ViewVarBinaryVector.class));
  }
}

@Test
public void testVarCharVector() {
try (final VarCharVector vector = new VarCharVector("v", allocator)) {
Expand Down
86 changes: 86 additions & 0 deletions java/c/src/test/java/org/apache/arrow/c/StreamTest.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@
import org.apache.arrow.vector.VectorLoader;
import org.apache.arrow.vector.VectorSchemaRoot;
import org.apache.arrow.vector.VectorUnloader;
import org.apache.arrow.vector.ViewVarBinaryVector;
import org.apache.arrow.vector.ViewVarCharVector;
import org.apache.arrow.vector.compare.Range;
import org.apache.arrow.vector.compare.RangeEqualsVisitor;
import org.apache.arrow.vector.dictionary.Dictionary;
Expand Down Expand Up @@ -134,6 +136,90 @@ public void roundtripStrings() throws Exception {
}
}

/**
 * Round-trips two record batches that pair an int column with a Utf8View column: the first
 * batch is fully populated, the second has nulls in rows 1 and 3.
 */
@Test
public void roundtripStringViews() throws Exception {
  final Field intField = Field.nullable("ints", new ArrowType.Int(32, true));
  final Field viewField = Field.nullable("string_views", new ArrowType.Utf8View());
  final Schema schema = new Schema(Arrays.asList(intField, viewField));

  final byte[] empty = "".getBytes(StandardCharsets.UTF_8);
  final byte[] oneChar = "a".getBytes(StandardCharsets.UTF_8);
  final byte[] fourteenChars = "bc1234567890bc".getBytes(StandardCharsets.UTF_8);
  final byte[] eighteenChars = "defg1234567890defg".getBytes(StandardCharsets.UTF_8);

  final List<ArrowRecordBatch> batches = new ArrayList<>();
  try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
    final IntVector intColumn = (IntVector) root.getVector(0);
    final ViewVarCharVector viewColumn = (ViewVarCharVector) root.getVector(1);
    final VectorUnloader unloader = new VectorUnloader(root);

    // First batch: every row is set.
    root.allocateNew();
    intColumn.setSafe(0, 1);
    intColumn.setSafe(1, 2);
    intColumn.setSafe(2, 4);
    intColumn.setSafe(3, 8);
    viewColumn.setSafe(0, empty);
    viewColumn.setSafe(1, oneChar);
    viewColumn.setSafe(2, fourteenChars);
    viewColumn.setSafe(3, eighteenChars);
    root.setRowCount(4);
    batches.add(unloader.getRecordBatch());

    // Second batch: rows 1 and 3 are null in both columns.
    root.allocateNew();
    intColumn.setSafe(0, 1);
    intColumn.setNull(1);
    intColumn.setSafe(2, 4);
    intColumn.setNull(3);
    viewColumn.setSafe(0, empty);
    viewColumn.setNull(1);
    viewColumn.setSafe(2, fourteenChars);
    viewColumn.setNull(3);
    root.setRowCount(4);
    batches.add(unloader.getRecordBatch());

    roundtrip(schema, batches);
  }
}

/**
 * Round-trips two record batches that pair an int column with a BinaryView column: the first
 * batch is fully populated, the second has nulls in rows 1 and 3.
 */
@Test
public void roundtripBinaryViews() throws Exception {
  final Field intField = Field.nullable("ints", new ArrowType.Int(32, true));
  final Field viewField = Field.nullable("binary_views", new ArrowType.BinaryView());
  final Schema schema = new Schema(Arrays.asList(intField, viewField));

  // ASCII payloads; the UTF-8 encoding of each string is the byte sequence under test.
  final byte[] empty = new byte[0];
  final byte[] oneByte = "a".getBytes(StandardCharsets.UTF_8);
  final byte[] fourteenBytes = "bc1234567890bc".getBytes(StandardCharsets.UTF_8);
  final byte[] eighteenBytes = "defg1234567890defg".getBytes(StandardCharsets.UTF_8);

  final List<ArrowRecordBatch> batches = new ArrayList<>();
  try (final VectorSchemaRoot root = VectorSchemaRoot.create(schema, allocator)) {
    final IntVector intColumn = (IntVector) root.getVector(0);
    final ViewVarBinaryVector binaryColumn = (ViewVarBinaryVector) root.getVector(1);
    final VectorUnloader unloader = new VectorUnloader(root);

    // First batch: every row is set.
    root.allocateNew();
    intColumn.setSafe(0, 1);
    intColumn.setSafe(1, 2);
    intColumn.setSafe(2, 4);
    intColumn.setSafe(3, 8);
    binaryColumn.setSafe(0, empty);
    binaryColumn.setSafe(1, oneByte);
    binaryColumn.setSafe(2, fourteenBytes);
    binaryColumn.setSafe(3, eighteenBytes);
    root.setRowCount(4);
    batches.add(unloader.getRecordBatch());

    // Second batch: rows 1 and 3 are null in both columns.
    root.allocateNew();
    intColumn.setSafe(0, 1);
    intColumn.setNull(1);
    intColumn.setSafe(2, 4);
    intColumn.setNull(3);
    binaryColumn.setSafe(0, empty);
    binaryColumn.setNull(1);
    binaryColumn.setSafe(2, fourteenBytes);
    binaryColumn.setNull(3);
    root.setRowCount(4);
    batches.add(unloader.getRecordBatch());

    roundtrip(schema, batches);
  }
}

@Test
public void roundtripDictionary() throws Exception {
final ArrowType.Int indexType = new ArrowType.Int(32, true);
Expand Down
Loading

0 comments on commit 3711657

Please sign in to comment.