diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java index ffedc150339..5454d008b92 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultJsonRow.java @@ -194,7 +194,7 @@ public int getSize() { @Override public ColumnVector getElements() { - return new DefaultGenericVector(arrayType.getElementType(), elements); + return DefaultGenericVector.fromArray(arrayType.getElementType(), elements); } }; } @@ -206,7 +206,7 @@ public ColumnVector getElements() { throw new RuntimeException("MapType with a key type of `String` is supported, " + "received a key type: " + mapType.getKeyType()); } - List keys = new ArrayList<>(jsonValue.size()); + List keys = new ArrayList<>(jsonValue.size()); List values = new ArrayList<>(jsonValue.size()); final Iterator> iter = jsonValue.fields(); @@ -229,12 +229,12 @@ public int getSize() { @Override public ColumnVector getKeys() { - return new DefaultGenericVector(mapType.getKeyType(), keys.toArray()); + return DefaultGenericVector.fromList(mapType.getKeyType(), keys); } @Override public ColumnVector getValues() { - return new DefaultGenericVector(mapType.getValueType(), values.toArray()); + return DefaultGenericVector.fromList(mapType.getValueType(), values); } }; } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java index f701aaeb03d..31cc7f5e0c0 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java @@ -15,22 +15,21 @@ */ package io.delta.kernel.defaults.internal.data; -import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.function.Function; -import static java.util.Objects.requireNonNull; import io.delta.kernel.data.*; -import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; +import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector; /** * {@link ColumnarBatch} wrapper around list of {@link Row} objects. + * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to + * generate data in true columnar format than wrapping a set of rows with a columnar batch + * interface. */ public class DefaultRowBasedColumnarBatch implements ColumnarBatch { private final StructType schema; @@ -71,7 +70,7 @@ public ColumnVector getColumnVector(int ordinal) { if (!columnVectors.get(ordinal).isPresent()) { final StructField field = schema.at(ordinal); - final ColumnVector vector = new SubFieldColumnVector( + final ColumnVector vector = new DefaultSubFieldVector( getSize(), field.getDataType(), ordinal, @@ -119,156 +118,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) { newSchema, newColumnVectorArr); } - - /** - * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referncing - * any nested level column vector from a set of rows. - * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to - * generate data in true columnar format than wrapping a set of rows with a columnar batch - * interface. - */ - private static class SubFieldColumnVector implements ColumnVector { - private final int size; - private final DataType dataType; - private final int columnOrdinal; - private final Function rowIdToRowAccessor; - - /** - * Create an instance of {@link SubFieldColumnVector} - * - * @param size Number of elements in the vector - * @param dataType Datatype of the vector - * @param columnOrdinal Ordinal of the column represented by this vector in the rows - * returned by {@link #rowIdToRowAccessor} - * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given - * rowId - */ - SubFieldColumnVector( - int size, - DataType dataType, - int columnOrdinal, - Function rowIdToRowAccessor) { - checkArgument(size >= 0, "invalid size: %s", size); - this.size = size; - checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal); - this.columnOrdinal = columnOrdinal; - this.rowIdToRowAccessor = - requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null"); - this.dataType = requireNonNull(dataType, "dataType is null"); - } - - @Override - public DataType getDataType() { - return dataType; - } - - @Override - public int getSize() { - return size; - } - - @Override - public void close() { /* nothing to close */ } - - @Override - public boolean isNullAt(int rowId) { - assertValidRowId(rowId); - Row row = rowIdToRowAccessor.apply(rowId); - return row == null || row.isNullAt(columnOrdinal); - } - - @Override - public boolean getBoolean(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal); - } - - @Override - public byte getByte(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal); - } - - @Override - public short getShort(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal); - } - - @Override - public int getInt(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal); - } - - @Override - public long getLong(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal); - } - - @Override - public float getFloat(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal); - } - - @Override - public double getDouble(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal); - } - - @Override - public byte[] getBinary(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal); - } - - @Override - public String getString(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal); - } - - @Override - public BigDecimal getDecimal(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal); - } - - @Override - public MapValue getMap(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal); - } - - @Override - public Row getStruct(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal); - } - - @Override - public ArrayValue getArray(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); - } - - @Override - public ColumnVector getChild(int childOrdinal) { - StructType structType = (StructType) dataType; - StructField childField = structType.at(childOrdinal); - return new SubFieldColumnVector( - size, - childField.getDataType(), - childOrdinal, - (rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal))); - } - - private void assertValidRowId(int rowId) { - checkArgument(rowId < size, - "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1)); - } - } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java index b96005bbf80..20d7fb819bc 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultConstantVector.java @@ -15,109 +15,13 @@ */ package io.delta.kernel.defaults.internal.data.vector; -import java.math.BigDecimal; - -import io.delta.kernel.data.ArrayValue; -import io.delta.kernel.data.ColumnVector; -import io.delta.kernel.data.MapValue; -import io.delta.kernel.data.Row; import io.delta.kernel.types.DataType; public class DefaultConstantVector - implements ColumnVector { - private final DataType dataType; - private final int numRows; - private final Object value; + extends DefaultGenericVector { public DefaultConstantVector(DataType dataType, int numRows, Object value) { // TODO: Validate datatype and value object type - this.dataType = dataType; - this.numRows = numRows; - this.value = value; - } - - @Override - public DataType getDataType() { - return dataType; - } - - @Override - public int getSize() { - return numRows; - } - - @Override - public void close() { - // nothing to close - } - - @Override - public boolean isNullAt(int rowId) { - return value == null; - } - - @Override - public boolean getBoolean(int rowId) { - return (boolean) value; - } - - @Override - public byte getByte(int rowId) { - return (byte) value; - } - - @Override - public short getShort(int rowId) { - return (short) value; - } - - @Override - public int getInt(int rowId) { - return (int) value; - } - - @Override - public long getLong(int rowId) { - return (long) value; - } - - @Override - public float getFloat(int rowId) { - return (float) value; - } - - @Override - public double getDouble(int rowId) { - return (double) value; - } - - @Override - public byte[] getBinary(int rowId) { - return (byte[]) value; - } - - @Override - public String getString(int rowId) { - return (String) value; - } - - @Override - public BigDecimal getDecimal(int rowId) { - return (BigDecimal) value; - } - - @Override - public MapValue getMap(int rowId) { - return (MapValue) value; - } - - @Override - public Row getStruct(int rowId) { - return (Row) value; - } - - @Override - public ArrayValue getArray(int rowId) { - return (ArrayValue) value; + super(numRows, dataType, (rowId) -> value); } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java index 0253b82c93c..817757e34e5 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -16,24 +16,40 @@ package io.delta.kernel.defaults.internal.data.vector; import java.math.BigDecimal; +import java.util.List; +import java.util.function.Function; import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnVector; import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; /** * Generic column vector implementation to expose an array of objects as a column vector. */ public class DefaultGenericVector implements ColumnVector { + public static DefaultGenericVector fromArray(DataType dataType, Object[] elements) { + return new DefaultGenericVector(elements.length, dataType, rowId -> elements[rowId]); + } + + public static DefaultGenericVector fromList(DataType dataType, List elements) { + return new DefaultGenericVector(elements.size(), dataType, rowId -> elements.get(rowId)); + } + + private final int size; private final DataType dataType; - private final Object[] values; + private final Function rowIdToValueAccessor; - public DefaultGenericVector(DataType dataType, Object[] values) { + protected DefaultGenericVector( + int size, + DataType dataType, + Function rowIdToValueAccessor) { + this.size = size; this.dataType = dataType; - this.values = values; + this.rowIdToValueAccessor = rowIdToValueAccessor; } @Override @@ -43,7 +59,7 @@ public DataType getDataType() { @Override public int getSize() { - return values.length; + return size; } @Override @@ -53,87 +69,112 @@ public void close() { @Override public boolean isNullAt(int rowId) { - return values[rowId] == null; + assertValidRowId(rowId); + return rowIdToValueAccessor.apply(rowId) == null; } @Override public boolean getBoolean(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(BooleanType.class, "boolean"); - return (boolean) values[rowId]; + return (boolean) rowIdToValueAccessor.apply(rowId); } @Override public byte getByte(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(ByteType.class, "byte"); - return (byte) values[rowId]; + return (byte) rowIdToValueAccessor.apply(rowId); } @Override public short getShort(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(ShortType.class, "short"); - return (short) values[rowId]; + return (short) rowIdToValueAccessor.apply(rowId); } @Override public int getInt(int rowId) { - throwIfUnsafeAccess(IntegerType.class, "integer"); - return (int) values[rowId]; + assertValidRowId(rowId); + throwIfUnsafeAccess(IntegerType.class, DateType.class, dataType.toString()); + return (int) rowIdToValueAccessor.apply(rowId); } @Override public long getLong(int rowId) { - throwIfUnsafeAccess(LongType.class, "long"); - return (long) values[rowId]; + assertValidRowId(rowId); + throwIfUnsafeAccess(LongType.class, TimestampType.class, dataType.toString()); + return (long) rowIdToValueAccessor.apply(rowId); } @Override public float getFloat(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(FloatType.class, "float"); - return (float) values[rowId]; + return (float) rowIdToValueAccessor.apply(rowId); } @Override public double getDouble(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(DoubleType.class, "double"); - return (double) values[rowId]; + return (double) rowIdToValueAccessor.apply(rowId); } @Override public String getString(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(StringType.class, "string"); - return (String) values[rowId]; + return (String) rowIdToValueAccessor.apply(rowId); } @Override public BigDecimal getDecimal(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(DecimalType.class, "decimal"); - return (BigDecimal) values[rowId]; + return (BigDecimal) rowIdToValueAccessor.apply(rowId); } @Override public byte[] getBinary(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(BinaryType.class, "binary"); - return (byte[]) values[rowId]; + return (byte[]) rowIdToValueAccessor.apply(rowId); } @Override public Row getStruct(int rowId) { + assertValidRowId(rowId); throwIfUnsafeAccess(StructType.class, "struct"); - return (Row) values[rowId]; + return (Row) rowIdToValueAccessor.apply(rowId); } @Override public ArrayValue getArray(int rowId) { + assertValidRowId(rowId); // TODO: not sufficient check, also need to check the element type throwIfUnsafeAccess(ArrayType.class, "array"); - return (ArrayValue) values[rowId]; + return (ArrayValue) rowIdToValueAccessor.apply(rowId); } @Override public MapValue getMap(int rowId) { + assertValidRowId(rowId); // TODO: not sufficient check, also need to check the element types throwIfUnsafeAccess(MapType.class, "map"); - return (MapValue) values[rowId]; + return (MapValue) rowIdToValueAccessor.apply(rowId); + } + + @Override + public ColumnVector getChild(int ordinal) { + throwIfUnsafeAccess(StructType.class, "struct"); + StructType structType = (StructType) dataType; + return new DefaultSubFieldVector( + getSize(), + structType.at(ordinal).getDataType(), + ordinal, + (rowId) -> (Row) rowIdToValueAccessor.apply(rowId)); } private void throwIfUnsafeAccess( Class expDataType, String accessType) { @@ -144,4 +185,24 @@ private void throwIfUnsafeAccess( Class expDataType, String dataType); throw new UnsupportedOperationException(msg); } - }} + } + + private void throwIfUnsafeAccess( + Class expDataType1, + Class expDataType2, + String accessType) { + if (!(expDataType1.isAssignableFrom(dataType.getClass()) || + expDataType2.isAssignableFrom(dataType.getClass()))) { + String msg = String.format( + "Trying to access a `%s` value from vector of type `%s`", + accessType, + dataType); + throw new UnsupportedOperationException(msg); + } + } + + private void assertValidRowId(int rowId) { + checkArgument(rowId < size, + "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1)); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java new file mode 100644 index 00000000000..7d0edbb7c6e --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java @@ -0,0 +1,179 @@ +/* + * Copyright (2023) The Delta Lake Project Authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.delta.kernel.defaults.internal.data.vector; + +import java.math.BigDecimal; +import java.util.function.Function; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; + +/** + * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referencing + * any nested level column vector from a set of rows. + */ +public class DefaultSubFieldVector implements ColumnVector { + private final int size; + private final DataType dataType; + private final int columnOrdinal; + private final Function rowIdToRowAccessor; + + /** + * Create an instance of {@link DefaultSubFieldVector} + * + * @param size Number of elements in the vector + * @param dataType Datatype of the vector + * @param columnOrdinal Ordinal of the column represented by this vector in the rows + * returned by {@link #rowIdToRowAccessor} + * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given + * rowId + */ + public DefaultSubFieldVector( + int size, + DataType dataType, + int columnOrdinal, + Function rowIdToRowAccessor) { + checkArgument(size >= 0, "invalid size: %s", size); + this.size = size; + checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal); + this.columnOrdinal = columnOrdinal; + this.rowIdToRowAccessor = + requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null"); + this.dataType = requireNonNull(dataType, "dataType is null"); + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return size; + } + + @Override + public void close() { /* nothing to close */ } + + @Override + public boolean isNullAt(int rowId) { + assertValidRowId(rowId); + Row row = rowIdToRowAccessor.apply(rowId); + return row == null || row.isNullAt(columnOrdinal); + } + + @Override + public boolean getBoolean(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal); + } + + @Override + public byte getByte(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal); + } + + @Override + public short getShort(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal); + } + + @Override + public int getInt(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal); + } + + @Override + public long getLong(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal); + } + + @Override + public float getFloat(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal); + } + + @Override + public double getDouble(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal); + } + + @Override + public byte[] getBinary(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal); + } + + @Override + public String getString(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal); + } + + @Override + public BigDecimal getDecimal(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal); + } + + @Override + public MapValue getMap(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal); + } + + @Override + public Row getStruct(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal); + } + + @Override + public ArrayValue getArray(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); + } + + @Override + public ColumnVector getChild(int childOrdinal) { + StructType structType = (StructType) dataType; + StructField childField = structType.at(childOrdinal); + return new DefaultSubFieldVector( + size, + childField.getDataType(), + childOrdinal, + (rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal))); + } + + private void assertValidRowId(int rowId) { + checkArgument(rowId < size, + "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1)); + } +} diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java index 5b061f4c879..144e9438cb3 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -144,6 +144,11 @@ public ArrayValue getArray(int rowId) { return underlyingVector.getArray(offset + rowId); } + @Override + public ColumnVector getChild(int ordinal) { + return new DefaultViewVector(underlyingVector.getChild(ordinal), offset, offset + size); + } + private void checkValidRowId(int rowId) { checkArgument(rowId >= 0 && rowId < size, String.format( diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java index 81300a60f69..d736ee98554 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/client/TestDefaultJsonHandler.java @@ -21,10 +21,12 @@ import org.apache.hadoop.conf.Configuration; import org.junit.Test; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertTrue; import io.delta.kernel.client.FileReadContext; import io.delta.kernel.client.FileSystemClient; import io.delta.kernel.client.JsonHandler; +import io.delta.kernel.data.ArrayValue; import io.delta.kernel.data.ColumnarBatch; import io.delta.kernel.data.FileDataReadResult; import io.delta.kernel.data.Row; @@ -156,7 +158,9 @@ public void parseNestedComplexTypes() throws IOException { " \"array\": [0, 1, null]," + " \"nested_array\": [[\"a\", \"b\"], [\"c\"], []]," + " \"map\": {\"a\": true, \"b\": false},\n" + - " \"nested_map\": {\"a\": {\"one\": [], \"two\": [1, 2, 3]}, \"b\": {}}\n" + + " \"nested_map\": {\"a\": {\"one\": [], \"two\": [1, 2, 3]}, \"b\": {}},\n" + + " \"array_of_struct\": " + + "[{\"field1\": \"foo\", \"field2\": 3}, {\"field1\": null}]\n" + "}"; StructType schema = new StructType() .add("array", new ArrayType(IntegerType.INTEGER, true)) @@ -171,7 +175,15 @@ public void parseNestedComplexTypes() throws IOException { true ), true - )); + ) + ).add("array_of_struct", + new ArrayType( + new StructType() + .add("field1", StringType.STRING, true) + .add("field2", IntegerType.INTEGER, true), + true + ) + ); ColumnarBatch batch = JSON_HANDLER.parseJson(singletonStringColumnVector(json), schema); try (CloseableIterator rows = batch.getRows()) { @@ -202,6 +214,18 @@ public void parseNestedComplexTypes() throws IOException { } }; assertEquals(exp3, VectorUtils.toJavaMap(result.getMap(3))); + ArrayValue arrayOfStruct = result.getArray(4); + assertEquals(arrayOfStruct.getSize(), 2); + // check getStruct + assertEquals("foo", arrayOfStruct.getElements().getStruct(0).getString(0)); + assertEquals(3, arrayOfStruct.getElements().getStruct(0).getInt(1)); + assertTrue(arrayOfStruct.getElements().getStruct(1).isNullAt(0)); + assertTrue(arrayOfStruct.getElements().getStruct(1).isNullAt(1)); + // check getChild + assertEquals("foo", arrayOfStruct.getElements().getChild(0).getString(0)); + assertEquals(3, arrayOfStruct.getElements().getChild(1).getInt(0)); + assertTrue(arrayOfStruct.getElements().getChild(0).isNullAt(1)); + assertTrue(arrayOfStruct.getElements().getChild(1).isNullAt(1)); } } diff --git a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java index c46a02665b5..4c661913bba 100644 --- a/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java +++ b/kernel/kernel-defaults/src/test/java/io/delta/kernel/defaults/internal/parquet/TestParquetBatchReader.java @@ -364,8 +364,11 @@ private static void verifyRowFromAllTypesFile( assertEquals(2, arrayValue.getSize()); assertEquals(2, elementVector.getSize()); assertTrue(elementVector.getDataType() instanceof StructType); + // check getStruct Row item0 = elementVector.getStruct(0); assertEquals(rowId, item0.getLong(0)); + // also check DefaultViewVector implements getChild + assertEquals(rowId, elementVector.getChild(0).getLong(0)); assertTrue(elementVector.isNullAt(1)); break; } diff --git a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala index c8e2b1044d2..ff05f91f6c2 100644 --- a/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala +++ b/kernel/kernel-defaults/src/test/scala/io/delta/kernel/defaults/utils/TestUtils.scala @@ -332,18 +332,13 @@ trait TestUtils extends Assertions { new MapValue() { override def getSize: Int = map.size - override def getKeys = new DefaultGenericVector( - keyType, keys.toArray) + override def getKeys = DefaultGenericVector.fromArray(keyType, keys.toArray) - override def getValues = new DefaultGenericVector( - valueType, values.toArray) + override def getValues = DefaultGenericVector.fromArray(valueType, values.toArray) } } } - new DefaultGenericVector( - dataType, - mapValues.map(getMapValue).toArray - ) + DefaultGenericVector.fromArray(dataType, mapValues.map(getMapValue).toArray) } }