Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Kernel] Implement getChild for a few remaining column vectors #2133

Merged
merged 7 commits into from
Oct 10, 2023
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;

import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
Expand All @@ -42,6 +43,7 @@
import io.delta.kernel.types.StructType;
import io.delta.kernel.utils.CloseableIterator;
import io.delta.kernel.utils.Utils;
import static io.delta.kernel.internal.util.InternalUtils.checkArgument;

/**
* Implementation of {@link JsonHandler} for testing Delta Kernel APIs
Expand Down Expand Up @@ -407,5 +409,151 @@ public ArrayValue getArray(int rowId) {
public MapValue getMap(int rowId) {
    // Fetch the value stored for this row and view it as a map; a stored value that is
    // not a MapValue fails here with a ClassCastException.
    final Object stored = values.get(rowId);
    return (MapValue) stored;
}

@Override
public ColumnVector getChild(int ordinal) {
allisonport-db marked this conversation as resolved.
Show resolved Hide resolved
checkArgument(dataType instanceof StructType);
StructType structType = (StructType) dataType;
List<Row> rows = values.stream()
.map(row -> (Row) row)
.collect(Collectors.toList());
return new RowBasedVector(
structType.at(ordinal).getDataType(),
rows,
ordinal
);
}
}

/**
 * Wrapper around a list of {@link Row}s that exposes one column of those rows as a
 * {@link ColumnVector}. A {@code null} entry in the list is reported as a null value by
 * {@link #isNullAt(int)}.
 */
private static class RowBasedVector implements ColumnVector {
    private final DataType dataType;
    private final List<Row> rows;
    private final int columnOrdinal;

    /**
     * @param dataType data type of the column exposed by this vector
     * @param rows rows backing this vector; individual entries may be {@code null}
     * @param columnOrdinal ordinal, within each row, of the column exposed by this vector
     */
    RowBasedVector(DataType dataType, List<Row> rows, int columnOrdinal) {
        this.dataType = dataType;
        this.rows = rows;
        this.columnOrdinal = columnOrdinal;
    }

    @Override
    public DataType getDataType() {
        return dataType;
    }

    @Override
    public int getSize() {
        return rows.size();
    }

    @Override
    public void close() { /* nothing to close */ }

    /**
     * A value is null when either the whole row is missing or the row reports null for the
     * column exposed by this vector.
     */
    @Override
    public boolean isNullAt(int rowId) {
        assertValidRowId(rowId);
        final Row row = rows.get(rowId);
        return row == null || row.isNullAt(columnOrdinal);
    }

    @Override
    public boolean getBoolean(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getBoolean(columnOrdinal);
    }

    @Override
    public byte getByte(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getByte(columnOrdinal);
    }

    @Override
    public short getShort(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getShort(columnOrdinal);
    }

    @Override
    public int getInt(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getInt(columnOrdinal);
    }

    @Override
    public long getLong(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getLong(columnOrdinal);
    }

    @Override
    public float getFloat(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getFloat(columnOrdinal);
    }

    @Override
    public double getDouble(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getDouble(columnOrdinal);
    }

    @Override
    public String getString(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getString(columnOrdinal);
    }

    @Override
    public BigDecimal getDecimal(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getDecimal(columnOrdinal);
    }

    @Override
    public byte[] getBinary(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getBinary(columnOrdinal);
    }

    @Override
    public MapValue getMap(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getMap(columnOrdinal);
    }

    @Override
    public ArrayValue getArray(int rowId) {
        assertValidRowId(rowId);
        return rows.get(rowId).getArray(columnOrdinal);
    }

    /**
     * Returns a vector over the {@code ordinal}-th field of the struct column exposed by this
     * vector. Rows that are missing, or whose struct value is null, become {@code null}
     * entries in the child vector.
     *
     * @param ordinal ordinal of the child field within this vector's struct type
     */
    @Override
    public ColumnVector getChild(int ordinal) {
        // Fail with a clear argument error instead of a ClassCastException on the cast below.
        checkArgument(dataType instanceof StructType,
            "Cannot get a child vector of a non-struct vector, data type: " + dataType);
        final StructType structType = (StructType) dataType;

        final List<Row> childRows = rows.stream()
            .map(row -> {
                if (row == null || row.isNullAt(columnOrdinal)) {
                    return null;
                }
                return row.getStruct(columnOrdinal);
            })
            .collect(Collectors.toList());

        return new RowBasedVector(
            structType.at(ordinal).getDataType(),
            childRows,
            ordinal);
    }

    /** Rejects row ids outside {@code [0, rows.size())}. */
    private void assertValidRowId(int rowId) {
        checkArgument(rowId >= 0 && rowId < rows.size(),
            "Invalid rowId: " + rowId + ", max allowed rowId is: " + (rows.size() - 1));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -15,19 +15,15 @@
*/
package io.delta.kernel.defaults.internal.data;

import java.math.BigDecimal;
import java.util.ArrayList;
import java.util.List;
import java.util.Optional;
import java.util.function.Function;
import static java.util.Objects.requireNonNull;

import io.delta.kernel.data.*;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructField;
import io.delta.kernel.types.StructType;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;
import io.delta.kernel.defaults.internal.data.vector.DefaultSubFieldVector;

/**
* {@link ColumnarBatch} wrapper around list of {@link Row} objects.
Expand Down Expand Up @@ -71,7 +67,7 @@ public ColumnVector getColumnVector(int ordinal) {

if (!columnVectors.get(ordinal).isPresent()) {
final StructField field = schema.at(ordinal);
final ColumnVector vector = new SubFieldColumnVector(
final ColumnVector vector = new DefaultSubFieldVector(
getSize(),
field.getDataType(),
ordinal,
Expand Down Expand Up @@ -119,156 +115,4 @@ public ColumnarBatch withDeletedColumnAt(int ordinal) {
newSchema,
newColumnVectorArr);
}

/**
 * {@link ColumnVector} wrapper on top of {@link Row} objects. This wrapper allows referencing
 * any nested level column vector from a set of rows.
 * TODO: We should change the {@link io.delta.kernel.defaults.client.DefaultJsonHandler} to
 * generate data in true columnar format than wrapping a set of rows with a columnar batch
 * interface.
 */
private static class SubFieldColumnVector implements ColumnVector {
    private final int size;
    private final DataType dataType;
    private final int columnOrdinal;
    private final Function<Integer, Row> rowIdToRowAccessor;

    /**
     * Create an instance of {@link SubFieldColumnVector}
     *
     * @param size Number of elements in the vector
     * @param dataType Datatype of the vector
     * @param columnOrdinal Ordinal of the column represented by this vector in the rows
     *                      returned by {@link #rowIdToRowAccessor}
     * @param rowIdToRowAccessor {@link Function} that returns a {@link Row} object for given
     *                           rowId; it may return {@code null}, which is reported as a
     *                           null value by {@link #isNullAt(int)}
     */
    SubFieldColumnVector(
            int size,
            DataType dataType,
            int columnOrdinal,
            Function<Integer, Row> rowIdToRowAccessor) {
        checkArgument(size >= 0, "invalid size: %s", size);
        this.size = size;
        checkArgument(columnOrdinal >= 0, "invalid column ordinal: %s", columnOrdinal);
        this.columnOrdinal = columnOrdinal;
        this.rowIdToRowAccessor =
            requireNonNull(rowIdToRowAccessor, "rowIdToRowAccessor is null");
        this.dataType = requireNonNull(dataType, "dataType is null");
    }

    @Override
    public DataType getDataType() {
        return dataType;
    }

    @Override
    public int getSize() {
        return size;
    }

    @Override
    public void close() { /* nothing to close */ }

    @Override
    public boolean isNullAt(int rowId) {
        assertValidRowId(rowId);
        Row row = rowIdToRowAccessor.apply(rowId);
        return row == null || row.isNullAt(columnOrdinal);
    }

    @Override
    public boolean getBoolean(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getBoolean(columnOrdinal);
    }

    @Override
    public byte getByte(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getByte(columnOrdinal);
    }

    @Override
    public short getShort(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getShort(columnOrdinal);
    }

    @Override
    public int getInt(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getInt(columnOrdinal);
    }

    @Override
    public long getLong(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getLong(columnOrdinal);
    }

    @Override
    public float getFloat(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getFloat(columnOrdinal);
    }

    @Override
    public double getDouble(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getDouble(columnOrdinal);
    }

    @Override
    public byte[] getBinary(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getBinary(columnOrdinal);
    }

    @Override
    public String getString(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getString(columnOrdinal);
    }

    @Override
    public BigDecimal getDecimal(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getDecimal(columnOrdinal);
    }

    @Override
    public MapValue getMap(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getMap(columnOrdinal);
    }

    @Override
    public Row getStruct(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getStruct(columnOrdinal);
    }

    @Override
    public ArrayValue getArray(int rowId) {
        assertValidRowId(rowId);
        return rowIdToRowAccessor.apply(rowId).getArray(columnOrdinal);
    }

    /**
     * Returns a vector over the {@code childOrdinal}-th field of the struct column represented
     * by this vector.
     *
     * @param childOrdinal ordinal of the child field within this vector's struct type
     */
    @Override
    public ColumnVector getChild(int childOrdinal) {
        // Fail with a clear argument error instead of a ClassCastException on the cast below.
        checkArgument(dataType instanceof StructType,
            "Cannot get a child vector of a non-struct vector, data type: " + dataType);
        StructType structType = (StructType) dataType;
        StructField childField = structType.at(childOrdinal);
        return new SubFieldColumnVector(
            size,
            childField.getDataType(),
            childOrdinal,
            (rowId) -> {
                // The parent row may be missing or hold a null struct (see isNullAt); hand
                // the child a null row in that case rather than dereferencing it.
                Row parent = rowIdToRowAccessor.apply(rowId);
                if (parent == null || parent.isNullAt(columnOrdinal)) {
                    return null;
                }
                return parent.getStruct(columnOrdinal);
            });
    }

    /** Rejects row ids outside {@code [0, size)}. */
    private void assertValidRowId(int rowId) {
        checkArgument(rowId >= 0 && rowId < size,
            "Invalid rowId: " + rowId + ", max allowed rowId is: " + (size - 1));
    }
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,9 @@
import io.delta.kernel.data.MapValue;
import io.delta.kernel.data.Row;
import io.delta.kernel.types.DataType;
import io.delta.kernel.types.StructType;

import static io.delta.kernel.defaults.internal.DefaultKernelUtils.checkArgument;

public class DefaultConstantVector
implements ColumnVector {
Expand Down Expand Up @@ -120,4 +123,16 @@ public Row getStruct(int rowId) {
public ArrayValue getArray(int rowId) {
    // The same constant is returned for every rowId; a constant that is not an
    // ArrayValue fails here with a ClassCastException.
    final Object constant = value;
    return (ArrayValue) constant;
}

@Override
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It seems incomplete to leave DefaultConstantVector without a getChild implementation, even though this code path should never be used. I think the possible options are:

  1. Restrict DefaultConstantVector to only allow StructType for null values (for NonExistentColumnConverter)
  2. Provide this implementation and add a test somewhere (not sure where to put this?)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Option 1 + restrict it for Map and Array types for non-null values only?

public ColumnVector getChild(int ordinal) {
    // Child vectors are only defined for struct-typed vectors.
    checkArgument(dataType instanceof StructType);
    final DataType childType = ((StructType) dataType).at(ordinal).getDataType();

    // Every rowId maps to the same constant value, viewed as a Row; the cast is
    // evaluated lazily, each time the child vector asks for a row.
    return new DefaultSubFieldVector(numRows, childType, ordinal, rowId -> (Row) value);
}
}
Loading
Loading