diff --git a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java index 2987d8d8ed8..89347f7d650 100644 --- a/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java +++ b/kernel/kernel-api/src/test/java/io/delta/kernel/internal/types/JsonHandlerTestImpl.java @@ -20,6 +20,7 @@ import java.util.Iterator; import java.util.List; import java.util.Map; +import java.util.stream.Collectors; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; @@ -42,6 +43,7 @@ import io.delta.kernel.types.StructType; import io.delta.kernel.utils.CloseableIterator; import io.delta.kernel.utils.Utils; +import static io.delta.kernel.internal.util.InternalUtils.checkArgument; /** * Implementation of {@link JsonHandler} for testing Delta Kernel APIs @@ -407,5 +409,151 @@ public ArrayValue getArray(int rowId) { public MapValue getMap(int rowId) { return (MapValue) values.get(rowId); } + + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(dataType instanceof StructType); + StructType structType = (StructType) dataType; + List rows = values.stream() + .map(row -> (Row) row) + .collect(Collectors.toList()); + return new RowBasedVector( + structType.at(ordinal).getDataType(), + rows, + ordinal + ); + } + } + + /** + * Wrapper around list of {@link Row}s to expose the rows as a column vector + */ + private static class RowBasedVector implements ColumnVector { + private final DataType dataType; + private final List rows; + private final int columnOrdinal; + + RowBasedVector(DataType dataType, List rows, int columnOrdinal) { + this.dataType = dataType; + this.rows = rows; + this.columnOrdinal = columnOrdinal; + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return rows.size(); + } + + @Override + public void close() { /* nothing to close */ } + + @Override + public boolean isNullAt(int rowId) { + assertValidRowId(rowId); + if (rows.get(rowId) == null) { + return true; + } + return rows.get(rowId).isNullAt(columnOrdinal); + } + + @Override + public boolean getBoolean(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getBoolean(columnOrdinal); + } + + @Override + public byte getByte(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getByte(columnOrdinal); + } + + @Override + public short getShort(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getShort(columnOrdinal); + } + + @Override + public int getInt(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getInt(columnOrdinal); + } + + @Override + public long getLong(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getLong(columnOrdinal); + } + + @Override + public float getFloat(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getFloat(columnOrdinal); + } + + @Override + public double getDouble(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getDouble(columnOrdinal); + } + + @Override + public String getString(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getString(columnOrdinal); + } + + @Override + public BigDecimal getDecimal(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getDecimal(columnOrdinal); + } + + @Override + public byte[] getBinary(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getBinary(columnOrdinal); + } + + @Override + public MapValue getMap(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getMap(columnOrdinal); + } + + @Override + public ArrayValue getArray(int rowId) { + assertValidRowId(rowId); + return rows.get(rowId).getArray(columnOrdinal); + } + + @Override + public ColumnVector getChild(int ordinal) { + List childRows = rows.stream() + .map(row -> { + if (row == null || row.isNullAt(columnOrdinal)) { + return null; + } else { + return row.getStruct(columnOrdinal); + } + }).collect(Collectors.toList()); + StructType structType = (StructType) dataType; + return new RowBasedVector( + structType.at(ordinal).getDataType(), + childRows, + ordinal + ); + } + + private void assertValidRowId(int rowId) { + checkArgument(rowId < rows.size(), + "Invalid rowId: " + rowId + ", max allowed rowId is: " + (rows.size() - 1)); + } } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java index f701aaeb03d..e888fdc6616 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/DefaultRowBasedColumnarBatch.java @@ -15,19 +15,15 @@ */ package io.delta.kernel.defaults.internal.data; -import java.math.BigDecimal; import java.util.ArrayList; import java.util.List; import java.util.Optional; -import java.util.function.Function; -import static java.util.Objects.requireNonNull; import io.delta.kernel.data.*; -import io.delta.kernel.types.DataType; import io.delta.kernel.types.StructField; import io.delta.kernel.types.StructType; -import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; +import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector; /** * {@link ColumnarBatch} wrapper around list of {@link Row} objects. @@ -71,7 +67,7 @@ public ColumnVector getColumnVector(int ordinal) { if (!columnVectors.get(ordinal).isPresent()) { final StructField field = schema.at(ordinal); - final ColumnVector vector = new SubFieldColumnVector( + final ColumnVector vector = new DefaultSubFieldVector( getSize(), field.getDataType(), ordinal, @@ -119,156 +115,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) { newSchema, newColumnVectorArr); } - - /** - * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referncing - * any nested level column vector from a set of rows. - * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to - * generate data in true columnar format than wrapping a set of rows with a columnar batch - * interface. - */ - private static class SubFieldColumnVector implements ColumnVector { - private final int size; - private final DataType dataType; - private final int columnOrdinal; - private final Function rowIdToRowAccessor; - - /** - * Create an instance of {@link SubFieldColumnVector} - * - * @param size Number of elements in the vector - * @param dataType Datatype of the vector - * @param columnOrdinal Ordinal of the column represented by this vector in the rows - * returned by {@link #rowIdToRowAccessor} - * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given - * rowId - */ - SubFieldColumnVector( - int size, - DataType dataType, - int columnOrdinal, - Function rowIdToRowAccessor) { - checkArgument(size >= 0, "invalid size: %s", size); - this.size = size; - checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal); - this.columnOrdinal = columnOrdinal; - this.rowIdToRowAccessor = - requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null"); - this.dataType = requireNonNull(dataType, "dataType is null"); - } - - @Override - public DataType getDataType() { - return dataType; - } - - @Override - public int getSize() { - return size; - } - - @Override - public void close() { /* nothing to close */ } - - @Override - public boolean isNullAt(int rowId) { - assertValidRowId(rowId); - Row row = rowIdToRowAccessor.apply(rowId); - return row == null || row.isNullAt(columnOrdinal); - } - - @Override - public boolean getBoolean(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal); - } - - @Override - public byte getByte(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal); - } - - @Override - public short getShort(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal); - } - - @Override - public int getInt(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal); - } - - @Override - public long getLong(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal); - } - - @Override - public float getFloat(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal); - } - - @Override - public double getDouble(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal); - } - - @Override - public byte[] getBinary(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal); - } - - @Override - public String getString(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal); - } - - @Override - public BigDecimal getDecimal(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal); - } - - @Override - public MapValue getMap(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal); - } - - @Override - public Row getStruct(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal); - } - - @Override - public ArrayValue getArray(int rowId) { - assertValidRowId(rowId); - return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); - } - - @Override - public ColumnVector getChild(int childOrdinal) { - StructType structType = (StructType) dataType; - StructField childField = structType.at(childOrdinal); - return new SubFieldColumnVector( - size, - childField.getDataType(), - childOrdinal, - (rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal))); - } - - private void assertValidRowId(int rowId) { - checkArgument(rowId < size, - "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1)); - } - } } diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java index 0253b82c93c..39bb4f3245b 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultGenericVector.java @@ -22,6 +22,7 @@ import io.delta.kernel.data.MapValue; import io.delta.kernel.data.Row; import io.delta.kernel.types.*; +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; /** * Generic column vector implementation to expose an array of objects as a column vector. @@ -136,6 +137,18 @@ public MapValue getMap(int rowId) { return (MapValue) values[rowId]; } + @Override + public ColumnVector getChild(int ordinal) { + checkArgument(dataType instanceof StructType); + StructType structType = (StructType) dataType; + return new DefaultSubFieldVector( + getSize(), + structType.at(ordinal).getDataType(), + ordinal, + (rowId) -> (Row) values[rowId]); + + } + private void throwIfUnsafeAccess( Class expDataType, String accessType) { if (!expDataType.isAssignableFrom(dataType.getClass())) { String msg = String.format( diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java new file mode 100644 index 00000000000..5674355e400 --- /dev/null +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultSubFieldVector.java @@ -0,0 +1,167 @@ +package io.delta.kernel.defaults.internal.data.vector; + +import java.math.BigDecimal; +import java.util.function.Function; +import static java.util.Objects.requireNonNull; + +import io.delta.kernel.data.ArrayValue; +import io.delta.kernel.data.ColumnVector; +import io.delta.kernel.data.MapValue; +import io.delta.kernel.data.Row; +import io.delta.kernel.types.DataType; +import io.delta.kernel.types.StructField; +import io.delta.kernel.types.StructType; + +import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument; + +/** + * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referencing + * any nested level column vector from a set of rows. + * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to + * generate data in true columnar format than wrapping a set of rows with a columnar batch + * interface. + */ +public class DefaultSubFieldVector implements ColumnVector { + private final int size; + private final DataType dataType; + private final int columnOrdinal; + private final Function rowIdToRowAccessor; + + /** + * Create an instance of {@link DefaultSubFieldVector} + * + * @param size Number of elements in the vector + * @param dataType Datatype of the vector + * @param columnOrdinal Ordinal of the column represented by this vector in the rows + * returned by {@link #rowIdToRowAccessor} + * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given + * rowId + */ + public DefaultSubFieldVector( + int size, + DataType dataType, + int columnOrdinal, + Function rowIdToRowAccessor) { + checkArgument(size >= 0, "invalid size: %s", size); + this.size = size; + checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal); + this.columnOrdinal = columnOrdinal; + this.rowIdToRowAccessor = + requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null"); + this.dataType = requireNonNull(dataType, "dataType is null"); + } + + @Override + public DataType getDataType() { + return dataType; + } + + @Override + public int getSize() { + return size; + } + + @Override + public void close() { /* nothing to close */ } + + @Override + public boolean isNullAt(int rowId) { + assertValidRowId(rowId); + Row row = rowIdToRowAccessor.apply(rowId); + return row == null || row.isNullAt(columnOrdinal); + } + + @Override + public boolean getBoolean(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal); + } + + @Override + public byte getByte(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal); + } + + @Override + public short getShort(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal); + } + + @Override + public int getInt(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal); + } + + @Override + public long getLong(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal); + } + + @Override + public float getFloat(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal); + } + + @Override + public double getDouble(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal); + } + + @Override + public byte[] getBinary(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal); + } + + @Override + public String getString(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal); + } + + @Override + public BigDecimal getDecimal(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal); + } + + @Override + public MapValue getMap(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal); + } + + @Override + public Row getStruct(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal); + } + + @Override + public ArrayValue getArray(int rowId) { + assertValidRowId(rowId); + return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal); + } + + @Override + public ColumnVector getChild(int childOrdinal) { + StructType structType = (StructType) dataType; + StructField childField = structType.at(childOrdinal); + return new DefaultSubFieldVector( + size, + childField.getDataType(), + childOrdinal, + (rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal))); + } + + private void assertValidRowId(int rowId) { + checkArgument(rowId < size, + "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1)); + } +} \ No newline at end of file diff --git a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java index 5b061f4c879..144e9438cb3 100644 --- a/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java +++ b/kernel/kernel-defaults/src/main/java/io/delta/kernel/defaults/internal/data/vector/DefaultViewVector.java @@ -144,6 +144,11 @@ public ArrayValue getArray(int rowId) { return underlyingVector.getArray(offset + rowId); } + @Override + public ColumnVector getChild(int ordinal) { + return new DefaultViewVector(underlyingVector.getChild(ordinal), offset, offset + size); + } + private void checkValidRowId(int rowId) { checkArgument(rowId >= 0 && rowId < size, String.format(