Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Implement getChild for a few remaining column vectors #2133

Merged
merged 7 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -194,7 +194,7 @@ public int getSize() {

@Override
public ColumnVector getElements() {
return new DefaultGenericVector(arrayType.getElementType(), elements);
return DefaultGenericVector.fromArray(arrayType.getElementType(), elements);
}
};
}
Expand All @@ -206,7 +206,7 @@ public ColumnVector getElements() {
throw new RuntimeException("MapType with a key type of `String` is supported, " +
"received a key type: " + mapType.getKeyType());
}
List<String> keys = new ArrayList<>(jsonValue.size());
List<Object> keys = new ArrayList<>(jsonValue.size());
List<Object> values = new ArrayList<>(jsonValue.size());
final Iterator<Map.Entry<String, JsonNode>> iter = jsonValue.fields();

Expand All @@ -229,12 +229,12 @@ public int getSize() {

@Override
public ColumnVector getKeys() {
return new DefaultGenericVector(mapType.getKeyType(), keys.toArray());
return DefaultGenericVector.fromList(mapType.getKeyType(), keys);
}

@Override
public ColumnVector getValues() {
return new DefaultGenericVector(mapType.getValueType(), values.toArray());
return DefaultGenericVector.fromList(mapType.getValueType(), values);
}
};
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,21 @@
*/
package io.delta.kernel.defaults.internal.data;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;

import io.delta.kernel.data.*;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector;

/**
* {@link ColumnarBatch} wrapper around list of {@link Row} objects.
* TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
* generate data in true columnar format rather than wrapping a set of rows with a columnar
* batch interface.
*/
public class DefaultRowBasedColumnarBatch implements ColumnarBatch {
private final StructType schema;
Expand Down Expand Up @@ -71,7 +70,7 @@ public ColumnVector getColumnVector(int ordinal) {

if (!columnVectors.get(ordinal).isPresent()) {
final StructField field = schema.at(ordinal);
final ColumnVector vector = new SubFieldColumnVector(
final ColumnVector vector = new DefaultSubFieldVector(
getSize(),
field.getDataType(),
ordinal,
Expand Down Expand Up @@ -119,156 +118,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) {
newSchema,
newColumnVectorArr);
}

/**
 * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referencing
 * any nested level column vector from a set of rows.
 * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
 * generate data in true columnar format rather than wrapping a set of rows with a columnar
 * batch interface.
 */
private static class SubFieldColumnVector implements ColumnVector {
    private final int size;
    private final DataType dataType;
    private final int columnOrdinal;
    private final Function<Integer, Row> rowIdToRowAccessor;

    /**
     * Create an instance of {@link SubFieldColumnVector}
     *
     * @param size               Number of elements in the vector
     * @param dataType           Datatype of the vector
     * @param columnOrdinal      Ordinal of the column represented by this vector in the rows
     *                           returned by {@link #rowIdToRowAccessor}
     * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given
     *                           rowId. It may return {@code null}, in which case the value at
     *                           that rowId is reported as null by {@link #isNullAt(int)}.
     */
    SubFieldColumnVector(
            int size,
            DataType dataType,
            int columnOrdinal,
            Function<Integer, Row> rowIdToRowAccessor) {
        checkArgument(size >= 0, "invalid size: %s", size);
        this.size = size;
        checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal);
        this.columnOrdinal = columnOrdinal;
        this.rowIdToRowAccessor =
            requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null");
        this.dataType = requireNonNull(dataType, "dataType is null");
    }

    @Override
    public DataType getDataType() {
        return dataType;
    }

    @Override
    public int getSize() {
        return size;
    }

    @Override
    public void close() { /* nothing to close */ }

    /**
     * A value is null when either the backing row is missing or the row reports the column
     * as null.
     */
    @Override
    public boolean isNullAt(int rowId) {
        Row row = rowAt(rowId);
        return row == null || row.isNullAt(columnOrdinal);
    }

    // The typed accessors below do not guard against a missing (null) row; callers are
    // expected to consult isNullAt(rowId) before reading a value.

    @Override
    public boolean getBoolean(int rowId) {
        return rowAt(rowId).getBoolean(columnOrdinal);
    }

    @Override
    public byte getByte(int rowId) {
        return rowAt(rowId).getByte(columnOrdinal);
    }

    @Override
    public short getShort(int rowId) {
        return rowAt(rowId).getShort(columnOrdinal);
    }

    @Override
    public int getInt(int rowId) {
        return rowAt(rowId).getInt(columnOrdinal);
    }

    @Override
    public long getLong(int rowId) {
        return rowAt(rowId).getLong(columnOrdinal);
    }

    @Override
    public float getFloat(int rowId) {
        return rowAt(rowId).getFloat(columnOrdinal);
    }

    @Override
    public double getDouble(int rowId) {
        return rowAt(rowId).getDouble(columnOrdinal);
    }

    @Override
    public byte[] getBinary(int rowId) {
        return rowAt(rowId).getBinary(columnOrdinal);
    }

    @Override
    public String getString(int rowId) {
        return rowAt(rowId).getString(columnOrdinal);
    }

    @Override
    public BigDecimal getDecimal(int rowId) {
        return rowAt(rowId).getDecimal(columnOrdinal);
    }

    @Override
    public MapValue getMap(int rowId) {
        return rowAt(rowId).getMap(columnOrdinal);
    }

    @Override
    public Row getStruct(int rowId) {
        return rowAt(rowId).getStruct(columnOrdinal);
    }

    @Override
    public ArrayValue getArray(int rowId) {
        return rowAt(rowId).getArray(columnOrdinal);
    }

    /**
     * Returns a vector over the child field at {@code childOrdinal} of this struct-typed
     * vector. The cast below fails with {@link ClassCastException} when this vector's data
     * type is not a {@link StructType}.
     *
     * @param childOrdinal Ordinal of the child field within this vector's struct type
     * @return vector of the same size exposing the child field's values
     */
    @Override
    public ColumnVector getChild(int childOrdinal) {
        StructType structType = (StructType) dataType;
        StructField childField = structType.at(childOrdinal);
        return new SubFieldColumnVector(
            size,
            childField.getDataType(),
            childOrdinal,
            (rowId) -> {
                // Propagate nulls instead of dereferencing them: when the parent row is
                // missing or the struct itself is null, the child has no row either, and
                // the child's isNullAt reports null for that rowId.
                Row parent = rowIdToRowAccessor.apply(rowId);
                if (parent == null || parent.isNullAt(columnOrdinal)) {
                    return null;
                }
                return parent.getStruct(columnOrdinal);
            });
    }

    /** Validates {@code rowId} and returns the backing row, which may be {@code null}. */
    private Row rowAt(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId);
    }

    /** Rejects row ids outside of {@code [0, size)}. */
    private void assertValidRowId(int rowId) {
        checkArgument(rowId >= 0 && rowId < size,
            "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,109 +15,13 @@
*/
package io.delta.kernel.defaults.internal.data.vector;

import java.math.BigDecimal;

import io.delta.kernel.data.ArrayValue;
import io.delta.kernel.data.ColumnVector;
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.types.DataType;

/**
 * {@link ColumnVector} in which every row holds the same value: the accessors below ignore
 * the given row id and return the single {@code value} supplied at construction.
 *
 * NOTE(review): this listing appears to interleave two versions of the class - a standalone
 * {@code implements ColumnVector} implementation and one that {@code extends
 * DefaultGenericVector} and only calls {@code super(...)} - confirm which lines are current
 * before relying on it.
 */
public class DefaultConstantVector
implements ColumnVector {
private final DataType dataType;
private final int numRows;
private final Object value;
extends DefaultGenericVector {

/**
 * @param dataType Data type reported by {@link #getDataType()}
 * @param numRows  Number of rows reported by {@link #getSize()}
 * @param value    Value returned for every row; {@code null} makes every row null
 */
public DefaultConstantVector(DataType dataType, int numRows, Object value) {
// TODO: Validate datatype and value object type
this.dataType = dataType;
this.numRows = numRows;
this.value = value;
}

@Override
public DataType getDataType() {
return dataType;
}

@Override
public int getSize() {
return numRows;
}

@Override
public void close() {
// nothing to close
}

// Either all rows are null or none are: the answer depends only on the constant value.
@Override
public boolean isNullAt(int rowId) {
return value == null;
}

// The typed getters below cast the stored Object to the requested type without checking
// it against dataType (see the TODO in the constructor), and rowId is not range-checked.

@Override
public boolean getBoolean(int rowId) {
return (boolean) value;
}

@Override
public byte getByte(int rowId) {
return (byte) value;
}

@Override
public short getShort(int rowId) {
return (short) value;
}

@Override
public int getInt(int rowId) {
return (int) value;
}

@Override
public long getLong(int rowId) {
return (long) value;
}

@Override
public float getFloat(int rowId) {
return (float) value;
}

@Override
public double getDouble(int rowId) {
return (double) value;
}

@Override
public byte[] getBinary(int rowId) {
return (byte[]) value;
}

@Override
public String getString(int rowId) {
return (String) value;
}

@Override
public BigDecimal getDecimal(int rowId) {
return (BigDecimal) value;
}

@Override
public MapValue getMap(int rowId) {
return (MapValue) value;
}

@Override
public Row getStruct(int rowId) {
return (Row) value;
}

@Override
public ArrayValue getArray(int rowId) {
return (ArrayValue) value;
// Passes the parent an accessor that yields the same value for any rowId.
super(numRows, dataType, (rowId) -> value);
}
}
Loading