Skip to content

Commit

Permalink
[Kernel] Implement getChild for a few remaining column vectors (#2133)
Browse files Browse the repository at this point in the history
## Description

Provides implementations for `getChild` for column vectors that are missing them.

## How was this patch tested?

Adds simple tests for `DefaultViewVector` and `DefaultGenericVector` (used by complex types in the JSON handler).
#2131 also is based off these changes and uses `getChild` instead of `getStruct` everywhere in the code.
  • Loading branch information
allisonport-db authored Oct 10, 2023
1 parent 2117d50 commit cd02359
Show file tree
Hide file tree
Showing 9 changed files with 309 additions and 291 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public int getSize() {

@Override
public ColumnVector getElements() {
return new DefaultGenericVector(arrayType.getElementType(), elements);
return DefaultGenericVector.fromArray(arrayType.getElementType(), elements);
}
};
}
Expand All @@ -206,7 +206,7 @@ public ColumnVector getElements() {
throw new RuntimeException("MapType with a key type of `String` is supported, " +
"received a key type: " + mapType.getKeyType());
}
List<String> keys = new ArrayList<>(jsonValue.size());
List<Object> keys = new ArrayList<>(jsonValue.size());
List<Object> values = new ArrayList<>(jsonValue.size());
final Iterator<Map.Entry<String, JsonNode>> iter = jsonValue.fields();

Expand All @@ -229,12 +229,12 @@ public int getSize() {

@Override
public ColumnVector getKeys() {
return new DefaultGenericVector(mapType.getKeyType(), keys.toArray());
return DefaultGenericVector.fromList(mapType.getKeyType(), keys);
}

@Override
public ColumnVector getValues() {
return new DefaultGenericVector(mapType.getValueType(), values.toArray());
return DefaultGenericVector.fromList(mapType.getValueType(), values);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@
*/
package io.delta.kernel.defaults.internal.data;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;

import io.delta.kernel.data.*;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector;

/**
* {@link ColumnarBatch} wrapper around list of {@link Row} objects.
* TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
* generate data in true columnar format than wrapping a set of rows with a columnar batch
* interface.
*/
public class DefaultRowBasedColumnarBatch implements ColumnarBatch {
private final StructType schema;
Expand Down Expand Up @@ -71,7 +70,7 @@ public ColumnVector getColumnVector(int ordinal) {

if (!columnVectors.get(ordinal).isPresent()) {
final StructField field = schema.at(ordinal);
final ColumnVector vector = new SubFieldColumnVector(
final ColumnVector vector = new DefaultSubFieldVector(
getSize(),
field.getDataType(),
ordinal,
Expand Down Expand Up @@ -119,156 +118,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) {
newSchema,
newColumnVectorArr);
}

/**
* {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referncing
* any nested level column vector from a set of rows.
* TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
* generate data in true columnar format than wrapping a set of rows with a columnar batch
* interface.
*/
private static class SubFieldColumnVector implements ColumnVector {
private final int size;
private final DataType dataType;
private final int columnOrdinal;
private final Function<Integer, Row> rowIdToRowAccessor;

/**
* Create an instance of {@link SubFieldColumnVector}
*
* @param size Number of elements in the vector
* @param dataType Datatype of the vector
* @param columnOrdinal Ordinal of the column represented by this vector in the rows
* returned by {@link #rowIdToRowAccessor}
* @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given
* rowId
*/
SubFieldColumnVector(
int size,
DataType dataType,
int columnOrdinal,
Function<Integer, Row> rowIdToRowAccessor) {
checkArgument(size >= 0, "invalid size: %s", size);
this.size = size;
checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal);
this.columnOrdinal = columnOrdinal;
this.rowIdToRowAccessor =
requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null");
this.dataType = requireNonNull(dataType, "dataType is null");
}

@Override
public DataType getDataType() {
return dataType;
}

@Override
public int getSize() {
return size;
}

@Override
public void close() { /* nothing to close */ }

@Override
public boolean isNullAt(int rowId) {
assertValidRowId(rowId);
Row row = rowIdToRowAccessor.apply(rowId);
return row == null || row.isNullAt(columnOrdinal);
}

@Override
public boolean getBoolean(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal);
}

@Override
public byte getByte(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal);
}

@Override
public short getShort(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal);
}

@Override
public int getInt(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal);
}

@Override
public long getLong(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal);
}

@Override
public float getFloat(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal);
}

@Override
public double getDouble(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal);
}

@Override
public byte[] getBinary(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal);
}

@Override
public String getString(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal);
}

@Override
public BigDecimal getDecimal(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal);
}

@Override
public MapValue getMap(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal);
}

@Override
public Row getStruct(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal);
}

@Override
public ArrayValue getArray(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal);
}

@Override
public ColumnVector getChild(int childOrdinal) {
StructType structType = (StructType) dataType;
StructField childField = structType.at(childOrdinal);
return new SubFieldColumnVector(
size,
childField.getDataType(),
childOrdinal,
(rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal)));
}

private void assertValidRowId(int rowId) {
checkArgument(rowId < size,
"Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,109 +15,13 @@
*/
package io.delta.kernel.defaults.internal.data.vector;

import java.math.BigDecimal;

import io.delta.kernel.data.ArrayValue;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.types.DataType;

public class DefaultConstantVector
implements ColumnVector {
private final DataType dataType;
private final int numRows;
private final Object value;
extends DefaultGenericVector {

public DefaultConstantVector(DataType dataType, int numRows, Object value) {
// TODO: Validate datatype and value object type
this.dataType = dataType;
this.numRows = numRows;
this.value = value;
}

@Override
public DataType getDataType() {
return dataType;
}

@Override
public int getSize() {
return numRows;
}

@Override
public void close() {
// nothing to close
}

@Override
public boolean isNullAt(int rowId) {
return value == null;
}

@Override
public boolean getBoolean(int rowId) {
return (boolean) value;
}

@Override
public byte getByte(int rowId) {
return (byte) value;
}

@Override
public short getShort(int rowId) {
return (short) value;
}

@Override
public int getInt(int rowId) {
return (int) value;
}

@Override
public long getLong(int rowId) {
return (long) value;
}

@Override
public float getFloat(int rowId) {
return (float) value;
}

@Override
public double getDouble(int rowId) {
return (double) value;
}

@Override
public byte[] getBinary(int rowId) {
return (byte[]) value;
}

@Override
public String getString(int rowId) {
return (String) value;
}

@Override
public BigDecimal getDecimal(int rowId) {
return (BigDecimal) value;
}

@Override
public MapValue getMap(int rowId) {
return (MapValue) value;
}

@Override
public Row getStruct(int rowId) {
return (Row) value;
}

@Override
public ArrayValue getArray(int rowId) {
return (ArrayValue) value;
super(numRows, dataType, (rowId) -> value);
}
}
Loading

0 comments on commit cd02359

Please sign in to comment.