Skip to content

Commit

Permalink
implement getChild
Browse files Browse the repository at this point in the history
  • Loading branch information
allisonport-db committed Oct 3, 2023
1 parent 5f9b98e commit 950cca3
Show file tree
Hide file tree
Showing 5 changed files with 335 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -42,6 +43,7 @@
import io.delta.kernel.types.StructType;
import io.delta.kernel.utils.CloseableIterator;
import io.delta.kernel.utils.Utils;
import static io.delta.kernel.internal.util.InternalUtils.checkArgument;

/**
* Implementation of {@link JsonHandler} for testing Delta Kernel APIs
Expand Down Expand Up @@ -407,5 +409,151 @@ public ArrayValue getArray(int rowId) {
public MapValue getMap(int rowId) {
return (MapValue) values.get(rowId);
}

@Override
public ColumnVector getChild(int ordinal) {
checkArgument(dataType instanceof StructType);
StructType structType = (StructType) dataType;
List<Row> rows = values.stream()
.map(row -> (Row) row)
.collect(Collectors.toList());
return new RowBasedVector(
structType.at(ordinal).getDataType(),
rows,
ordinal
);
}
}

/**
* Wrapper around list of {@link Row}s to expose the rows as a column vector
*/
private static class RowBasedVector implements ColumnVector {
private final DataType dataType;
private final List<Row> rows;
private final int columnOrdinal;

RowBasedVector(DataType dataType, List<Row> rows, int columnOrdinal) {
this.dataType = dataType;
this.rows = rows;
this.columnOrdinal = columnOrdinal;
}

@Override
public DataType getDataType() {
return dataType;
}

@Override
public int getSize() {
return rows.size();
}

@Override
public void close() { /* nothing to close */ }

@Override
public boolean isNullAt(int rowId) {
assertValidRowId(rowId);
if (rows.get(rowId) == null) {
return true;
}
return rows.get(rowId).isNullAt(columnOrdinal);
}

@Override
public boolean getBoolean(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getBoolean(columnOrdinal);
}

@Override
public byte getByte(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getByte(columnOrdinal);
}

@Override
public short getShort(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getShort(columnOrdinal);
}

@Override
public int getInt(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getInt(columnOrdinal);
}

@Override
public long getLong(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getLong(columnOrdinal);
}

@Override
public float getFloat(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getFloat(columnOrdinal);
}

@Override
public double getDouble(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getDouble(columnOrdinal);
}

@Override
public String getString(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getString(columnOrdinal);
}

@Override
public BigDecimal getDecimal(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getDecimal(columnOrdinal);
}

@Override
public byte[] getBinary(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getBinary(columnOrdinal);
}

@Override
public MapValue getMap(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getMap(columnOrdinal);
}

@Override
public ArrayValue getArray(int rowId) {
assertValidRowId(rowId);
return rows.get(rowId).getArray(columnOrdinal);
}

@Override
public ColumnVector getChild(int ordinal) {
List<Row> childRows = rows.stream()
.map(row -> {
if (row == null || row.isNullAt(columnOrdinal)) {
return null;
} else {
return row.getStruct(columnOrdinal);
}
}).collect(Collectors.toList());
StructType structType = (StructType) dataType;
return new RowBasedVector(
structType.at(ordinal).getDataType(),
childRows,
ordinal
);
}

private void assertValidRowId(int rowId) {
checkArgument(rowId < rows.size(),
"Invalid rowId: " + rowId + ", max allowed rowId is: " + (rows.size() - 1));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
*/
package io.delta.kernel.defaults.internal.data;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;

import io.delta.kernel.data.*;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector;

/**
* {@link ColumnarBatch} wrapper around list of {@link Row} objects.
Expand Down Expand Up @@ -71,7 +67,7 @@ public ColumnVector getColumnVector(int ordinal) {

if (!columnVectors.get(ordinal).isPresent()) {
final StructField field = schema.at(ordinal);
final ColumnVector vector = new SubFieldColumnVector(
final ColumnVector vector = new DefaultSubFieldVector(
getSize(),
field.getDataType(),
ordinal,
Expand Down Expand Up @@ -119,156 +115,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) {
newSchema,
newColumnVectorArr);
}

/**
* {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referncing
* any nested level column vector from a set of rows.
* TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
* generate data in true columnar format than wrapping a set of rows with a columnar batch
* interface.
*/
private static class SubFieldColumnVector implements ColumnVector {
private final int size;
private final DataType dataType;
private final int columnOrdinal;
private final Function<Integer, Row> rowIdToRowAccessor;

/**
* Create an instance of {@link SubFieldColumnVector}
*
* @param size Number of elements in the vector
* @param dataType Datatype of the vector
* @param columnOrdinal Ordinal of the column represented by this vector in the rows
* returned by {@link #rowIdToRowAccessor}
* @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given
* rowId
*/
SubFieldColumnVector(
int size,
DataType dataType,
int columnOrdinal,
Function<Integer, Row> rowIdToRowAccessor) {
checkArgument(size >= 0, "invalid size: %s", size);
this.size = size;
checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal);
this.columnOrdinal = columnOrdinal;
this.rowIdToRowAccessor =
requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null");
this.dataType = requireNonNull(dataType, "dataType is null");
}

@Override
public DataType getDataType() {
return dataType;
}

@Override
public int getSize() {
return size;
}

@Override
public void close() { /* nothing to close */ }

@Override
public boolean isNullAt(int rowId) {
assertValidRowId(rowId);
Row row = rowIdToRowAccessor.apply(rowId);
return row == null || row.isNullAt(columnOrdinal);
}

@Override
public boolean getBoolean(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal);
}

@Override
public byte getByte(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal);
}

@Override
public short getShort(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal);
}

@Override
public int getInt(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal);
}

@Override
public long getLong(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal);
}

@Override
public float getFloat(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal);
}

@Override
public double getDouble(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal);
}

@Override
public byte[] getBinary(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal);
}

@Override
public String getString(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal);
}

@Override
public BigDecimal getDecimal(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal);
}

@Override
public MapValue getMap(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal);
}

@Override
public Row getStruct(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal);
}

@Override
public ArrayValue getArray(int rowId) {
assertValidRowId(rowId);
return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal);
}

@Override
public ColumnVector getChild(int childOrdinal) {
StructType structType = (StructType) dataType;
StructField childField = structType.at(childOrdinal);
return new SubFieldColumnVector(
size,
childField.getDataType(),
childOrdinal,
(rowId) -> (rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal)));
}

private void assertValidRowId(int rowId) {
checkArgument(rowId < size,
"Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1));
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.types.*;
import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;

/**
* Generic column vector implementation to expose an array of objects as a column vector.
Expand Down Expand Up @@ -136,6 +137,18 @@ public MapValue getMap(int rowId) {
return (MapValue) values[rowId];
}

@Override
public ColumnVector getChild(int ordinal) {
checkArgument(dataType instanceof StructType);
StructType structType = (StructType) dataType;
return new DefaultSubFieldVector(
getSize(),
structType.at(ordinal).getDataType(),
ordinal,
(rowId) -> (Row) values[rowId]);

}

private void throwIfUnsafeAccess( Class<? extends DataType> expDataType, String accessType) {
if (!expDataType.isAssignableFrom(dataType.getClass())) {
String msg = String.format(
Expand Down
Loading

0 comments on commit 950cca3

Please sign in to comment.