From 8e5e8d915eff1f769d5b4cd2ad666fd7ff90166e Mon Sep 17 00:00:00 2001 From: "Xiaotian (Jackie) Jiang" <17555551+Jackie-Jiang@users.noreply.github.com> Date: Sat, 8 Jun 2024 21:23:37 -0700 Subject: [PATCH] Fix array literal handling (#13345) --- .../request/context/ExpressionContext.java | 2 + .../request/context/LiteralContext.java | 135 ++++++++++---- .../common/utils/request/RequestUtils.java | 30 ++-- pinot-common/src/main/proto/expressions.proto | 25 +++ .../ArrayLiteralTransformFunction.java | 37 ++-- .../function/TransformFunctionFactory.java | 7 +- .../HistogramAggregationFunction.java | 46 +++-- .../executor/ServerQueryExecutorV1Impl.java | 9 +- .../parser/CalciteRexExpressionParser.java | 17 +- .../query/planner/logical/RexExpression.java | 5 +- .../serde/ProtoExpressionToRexExpression.java | 58 ++++++ .../serde/RexExpressionToProtoExpression.java | 25 +++ .../planner/serde/RexExpressionSerDeTest.java | 165 ++++++++++++++++++ 13 files changed, 457 insertions(+), 104 deletions(-) create mode 100644 pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java index 927ab4eb69f8..d52c0091e0cc 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/ExpressionContext.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.common.request.context; +import com.google.common.annotations.VisibleForTesting; import java.util.Objects; import java.util.Set; import javax.annotation.Nullable; @@ -51,6 +52,7 @@ public static ExpressionContext forLiteral(Literal literal) { return forLiteral(new LiteralContext(literal)); } + @VisibleForTesting public static ExpressionContext forLiteral(DataType type, @Nullable Object value) { return forLiteral(new LiteralContext(type, value)); } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java index eb0667296fc8..0a2b8ad6e152 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/LiteralContext.java @@ -18,8 +18,12 @@ */ package org.apache.pinot.common.request.context; +import com.google.common.annotations.VisibleForTesting; +import com.google.common.base.Preconditions; import java.math.BigDecimal; import java.sql.Timestamp; +import java.util.Arrays; +import java.util.List; import java.util.Objects; import javax.annotation.Nullable; import org.apache.pinot.common.request.Literal; @@ -52,101 +56,147 @@ public class LiteralContext { private String _stringValue; private byte[] _bytesValue; - public LiteralContext(DataType type, Object value) { - _type = type; - _value = value; - _pinotDataType = getPinotDataType(type); - } - public LiteralContext(Literal literal) { switch (literal.getSetField()) { + case NULL_VALUE: + _type = DataType.UNKNOWN; + _value = null; + _pinotDataType = null; + break; case BOOL_VALUE: _type = DataType.BOOLEAN; _value = literal.getBoolValue(); + _pinotDataType = PinotDataType.BOOLEAN; break; case INT_VALUE: _type = DataType.INT; _value = literal.getIntValue(); + _pinotDataType = PinotDataType.INTEGER; break; case LONG_VALUE: _type = DataType.LONG; _value = literal.getLongValue(); + _pinotDataType = PinotDataType.LONG; break; case FLOAT_VALUE: _type = DataType.FLOAT; _value = Float.intBitsToFloat(literal.getFloatValue()); + _pinotDataType = PinotDataType.FLOAT; break; case DOUBLE_VALUE: _type = DataType.DOUBLE; _value = literal.getDoubleValue(); + _pinotDataType = PinotDataType.DOUBLE; break; case BIG_DECIMAL_VALUE: _type = DataType.BIG_DECIMAL; _value = BigDecimalUtils.deserialize(literal.getBigDecimalValue()); + _pinotDataType = PinotDataType.BIG_DECIMAL; break; case STRING_VALUE: _type = DataType.STRING; _value = literal.getStringValue(); + _pinotDataType = PinotDataType.STRING; break; case BINARY_VALUE: _type = DataType.BYTES; _value = literal.getBinaryValue(); + _pinotDataType = PinotDataType.BYTES; break; - // TODO: Revisit the type handling and whether we should convert value to primitive array for ARRAY types - case INT_ARRAY_VALUE: + case INT_ARRAY_VALUE: { _type = DataType.INT; - _value = literal.getIntArrayValue(); + List valueList = literal.getIntArrayValue(); + int numValues = valueList.size(); + int[] values = new int[numValues]; + for (int i = 0; i < numValues; i++) { + values[i] = valueList.get(i); + } + _value = values; + _pinotDataType = PinotDataType.PRIMITIVE_INT_ARRAY; break; - case LONG_ARRAY_VALUE: + } + case LONG_ARRAY_VALUE: { _type = DataType.LONG; - _value = literal.getLongArrayValue(); + List valueList = literal.getLongArrayValue(); + int numValues = valueList.size(); + long[] values = new long[numValues]; + for (int i = 0; i < numValues; i++) { + values[i] = valueList.get(i); + } + _value = values; + _pinotDataType = PinotDataType.PRIMITIVE_LONG_ARRAY; break; - // TODO: Revisit the FLOAT_ARRAY handling. Currently the values are stored as int bits. - case FLOAT_ARRAY_VALUE: + } + case FLOAT_ARRAY_VALUE: { _type = DataType.FLOAT; - _value = literal.getFloatArrayValue(); + List valueList = literal.getFloatArrayValue(); + int numValues = valueList.size(); + float[] values = new float[numValues]; + for (int i = 0; i < numValues; i++) { + values[i] = Float.intBitsToFloat(valueList.get(i)); + } + _value = values; + _pinotDataType = PinotDataType.PRIMITIVE_FLOAT_ARRAY; break; - case DOUBLE_ARRAY_VALUE: + } + case DOUBLE_ARRAY_VALUE: { _type = DataType.DOUBLE; - _value = literal.getDoubleArrayValue(); + List valueList = literal.getDoubleArrayValue(); + int numValues = valueList.size(); + double[] values = new double[numValues]; + for (int i = 0; i < numValues; i++) { + values[i] = valueList.get(i); + } + _value = values; + _pinotDataType = PinotDataType.PRIMITIVE_DOUBLE_ARRAY; break; + } case STRING_ARRAY_VALUE: _type = DataType.STRING; - _value = literal.getStringArrayValue(); - break; - case NULL_VALUE: - _type = DataType.UNKNOWN; - _value = null; + _value = literal.getStringArrayValue().toArray(new String[0]); + _pinotDataType = PinotDataType.STRING_ARRAY; break; default: throw new IllegalStateException("Unsupported field type: " + literal.getSetField()); } - _pinotDataType = getPinotDataType(_type); + } + + @VisibleForTesting + public LiteralContext(DataType type, @Nullable Object value) { + _type = type; + _value = value; + _pinotDataType = getPinotDataType(type, value); } @Nullable - private static PinotDataType getPinotDataType(DataType type) { + private static PinotDataType getPinotDataType(DataType type, @Nullable Object value) { + if (value == null) { + return null; + } + if (type == DataType.BYTES) { + Preconditions.checkState(value.getClass().getComponentType() == byte.class, "Bytes array is not supported"); + return PinotDataType.BYTES; + } + boolean singleValue = !value.getClass().isArray(); switch (type) { case BOOLEAN: + Preconditions.checkState(singleValue, "Boolean array is not supported"); return PinotDataType.BOOLEAN; case INT: - return PinotDataType.INTEGER; + return singleValue ? PinotDataType.INTEGER : PinotDataType.PRIMITIVE_INT_ARRAY; case LONG: - return PinotDataType.LONG; + return singleValue ? PinotDataType.LONG : PinotDataType.PRIMITIVE_LONG_ARRAY; case FLOAT: - return PinotDataType.FLOAT; + return singleValue ? PinotDataType.FLOAT : PinotDataType.PRIMITIVE_FLOAT_ARRAY; case DOUBLE: - return PinotDataType.DOUBLE; + return singleValue ? PinotDataType.DOUBLE : PinotDataType.PRIMITIVE_DOUBLE_ARRAY; case BIG_DECIMAL: + Preconditions.checkState(singleValue, "BigDecimal array is not supported"); return PinotDataType.BIG_DECIMAL; case STRING: - return PinotDataType.STRING; - case BYTES: - return PinotDataType.BYTES; - case UNKNOWN: - return null; + return singleValue ? PinotDataType.STRING : PinotDataType.STRING_ARRAY; default: - throw new IllegalStateException("Unsupported data type: " + type); + throw new IllegalStateException("Unsupported DataType: " + type); } } @@ -159,6 +209,10 @@ public Object getValue() { return _value; } + public boolean isSingleValue() { + return _pinotDataType == null || _pinotDataType.isSingleValue(); + } + public boolean getBooleanValue() { Boolean booleanValue = _booleanValue; if (booleanValue == null) { @@ -281,8 +335,21 @@ public String toString() { // https://github.com/apache/pinot/pull/11762) if (isNull()) { return "'null'"; - } else { + } + if (isSingleValue()) { return "'" + getStringValue() + "'"; } + switch (_pinotDataType) { + case PRIMITIVE_INT_ARRAY: + return "'" + Arrays.toString((int[]) _value) + "'"; + case PRIMITIVE_LONG_ARRAY: + return "'" + Arrays.toString((long[]) _value) + "'"; + case PRIMITIVE_FLOAT_ARRAY: + return "'" + Arrays.toString((float[]) _value) + "'"; + case PRIMITIVE_DOUBLE_ARRAY: + return "'" + Arrays.toString((double[]) _value) + "'"; + default: + throw new IllegalStateException("Unsupported PinotDataType: " + _pinotDataType); + } } } diff --git a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java index 44a0931957cf..7cc83877311e 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/utils/request/RequestUtils.java @@ -170,48 +170,48 @@ public static Literal getLiteral(@Nullable Object object) { return getNullLiteral(); } if (object instanceof Boolean) { - return RequestUtils.getLiteral((boolean) object); + return getLiteral((boolean) object); } if (object instanceof Integer) { - return RequestUtils.getLiteral((int) object); + return getLiteral((int) object); } if (object instanceof Long) { - return RequestUtils.getLiteral((long) object); + return getLiteral((long) object); } if (object instanceof Float) { - return RequestUtils.getLiteral((float) object); + return getLiteral((float) object); } if (object instanceof Double) { - return RequestUtils.getLiteral((double) object); + return getLiteral((double) object); } if (object instanceof BigDecimal) { - return RequestUtils.getLiteral((BigDecimal) object); + return getLiteral((BigDecimal) object); } if (object instanceof Timestamp) { - return RequestUtils.getLiteral(((Timestamp) object).getTime()); + return getLiteral(((Timestamp) object).getTime()); } if (object instanceof String) { - return RequestUtils.getLiteral((String) object); + return getLiteral((String) object); } if (object instanceof byte[]) { - return RequestUtils.getLiteral((byte[]) object); + return getLiteral((byte[]) object); } if (object instanceof int[]) { - return RequestUtils.getLiteral((int[]) object); + return getLiteral((int[]) object); } if (object instanceof long[]) { - return RequestUtils.getLiteral((long[]) object); + return getLiteral((long[]) object); } if (object instanceof float[]) { - return RequestUtils.getLiteral((float[]) object); + return getLiteral((float[]) object); } if (object instanceof double[]) { - return RequestUtils.getLiteral((double[]) object); + return getLiteral((double[]) object); } if (object instanceof String[]) { - return RequestUtils.getLiteral((String[]) object); + return getLiteral((String[]) object); } - return RequestUtils.getLiteral(object.toString()); + return getLiteral(object.toString()); } public static Literal getLiteral(SqlLiteral node) { diff --git a/pinot-common/src/main/proto/expressions.proto b/pinot-common/src/main/proto/expressions.proto index ebc164a2ad6e..17cf4ac11508 100644 --- a/pinot-common/src/main/proto/expressions.proto +++ b/pinot-common/src/main/proto/expressions.proto @@ -58,9 +58,34 @@ message Literal { double double = 6; string string = 7; bytes bytes = 8; + IntArray intArray = 9; + LongArray longArray = 10; + FloatArray floatArray = 11; + DoubleArray doubleArray = 12; + StringArray stringArray = 13; } } +message IntArray { + repeated int32 values = 1; +} + +message LongArray { + repeated int64 values = 1; +} + +message FloatArray { + repeated float values = 1; +} + +message DoubleArray { + repeated double values = 1; +} + +message StringArray { + repeated string values = 1; +} + message FunctionCall { ColumnDataType dataType = 1; string functionName = 2; diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java index b2065e20d3a4..084d34bf2e0f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/ArrayLiteralTransformFunction.java @@ -56,10 +56,9 @@ public class ArrayLiteralTransformFunction implements TransformFunction { private String[][] _stringArrayResult; public ArrayLiteralTransformFunction(LiteralContext literalContext) { - List literalArray = (List) literalContext.getValue(); - Preconditions.checkNotNull(literalArray); - if (literalArray.isEmpty()) { - _dataType = DataType.UNKNOWN; + _dataType = literalContext.getType(); + Object value = literalContext.getValue(); + if (value == null) { _intArrayLiteral = new int[0]; _longArrayLiteral = new long[0]; _floatArrayLiteral = new float[0]; @@ -67,53 +66,37 @@ public ArrayLiteralTransformFunction(LiteralContext literalContext) { _stringArrayLiteral = new String[0]; return; } - _dataType = literalContext.getType(); switch (_dataType) { case INT: - _intArrayLiteral = new int[literalArray.size()]; - for (int i = 0; i < _intArrayLiteral.length; i++) { - _intArrayLiteral[i] = (int) literalArray.get(i); - } + _intArrayLiteral = (int[]) value; _longArrayLiteral = null; _floatArrayLiteral = null; _doubleArrayLiteral = null; _stringArrayLiteral = null; break; case LONG: - _longArrayLiteral = new long[literalArray.size()]; - for (int i = 0; i < _longArrayLiteral.length; i++) { - _longArrayLiteral[i] = (long) literalArray.get(i); - } + _longArrayLiteral = (long[]) value; _intArrayLiteral = null; _floatArrayLiteral = null; _doubleArrayLiteral = null; _stringArrayLiteral = null; break; case FLOAT: - _floatArrayLiteral = new float[literalArray.size()]; - for (int i = 0; i < _floatArrayLiteral.length; i++) { - _floatArrayLiteral[i] = (float) literalArray.get(i); - } + _floatArrayLiteral = (float[]) value; _intArrayLiteral = null; _longArrayLiteral = null; _doubleArrayLiteral = null; _stringArrayLiteral = null; break; case DOUBLE: - _doubleArrayLiteral = new double[literalArray.size()]; - for (int i = 0; i < _doubleArrayLiteral.length; i++) { - _doubleArrayLiteral[i] = (double) literalArray.get(i); - } + _doubleArrayLiteral = (double[]) value; _intArrayLiteral = null; _longArrayLiteral = null; _floatArrayLiteral = null; _stringArrayLiteral = null; break; case STRING: - _stringArrayLiteral = new String[literalArray.size()]; - for (int i = 0; i < _stringArrayLiteral.length; i++) { - _stringArrayLiteral[i] = (String) literalArray.get(i); - } + _stringArrayLiteral = (String[]) value; _intArrayLiteral = null; _longArrayLiteral = null; _floatArrayLiteral = null; @@ -121,8 +104,8 @@ public ArrayLiteralTransformFunction(LiteralContext literalContext) { break; default: throw new IllegalStateException( - "Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal contexts: " - + Arrays.toString(literalArray.toArray())); + "Illegal data type for ArrayLiteralTransformFunction: " + _dataType + ", literal context: " + + literalContext); } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java index d5e4d9d481d6..de7668ca26a9 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/TransformFunctionFactory.java @@ -338,12 +338,13 @@ public static TransformFunction get(ExpressionContext expression, Map arguments) { ExpressionContext arrayExpression = arguments.get(1); Preconditions.checkArgument( // ARRAY function - ((arrayExpression.getType() == ExpressionContext.Type.FUNCTION) - && (arrayExpression.getFunction().getFunctionName().equals(ARRAY_CONSTRUCTOR))) - || ((arrayExpression.getType() == ExpressionContext.Type.LITERAL) - && (arrayExpression.getLiteral().getValue() instanceof List)), + (arrayExpression.getType() == ExpressionContext.Type.FUNCTION && arrayExpression.getFunction() + .getFunctionName().equals(ARRAY_CONSTRUCTOR)) || ( + arrayExpression.getType() == ExpressionContext.Type.LITERAL && !arrayExpression.getLiteral() + .isSingleValue()), "Please use the format of `Histogram(columnName, ARRAY[1,10,100])` to specify the bin edges"); if (arrayExpression.getType() == ExpressionContext.Type.FUNCTION) { _bucketEdges = parseVector(arrayExpression.getFunction().getArguments()); } else { - _bucketEdges = parseVectorLiteral((List) arrayExpression.getLiteral().getValue()); + _bucketEdges = parseVectorLiteral(arrayExpression.getLiteral().getValue()); } _lower = _bucketEdges[0]; _upper = _bucketEdges[_bucketEdges.length - 1]; @@ -111,22 +112,35 @@ private double[] parseVector(List arrayStr) { ret[i] = arrayStr.get(i).getLiteral().getDoubleValue(); } if (i > 0) { - Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be strictly increasing"); + Preconditions.checkArgument(ret[i] > ret[i - 1], "The bin edges must be strictly increasing"); } } return ret; } - private double[] parseVectorLiteral(List arrayStr) { - int len = arrayStr.size(); - Preconditions.checkArgument(len > 1, "The number of bin edges must be greater than 1"); - double[] ret = new double[len]; - for (int i = 0; i < len; i++) { - // TODO: Represent infinity as literal instead of identifier - ret[i] = Double.parseDouble(arrayStr.get(i).toString()); - if (i > 0) { - Preconditions.checkState(ret[i] > ret[i - 1], "The bin edges must be strictly increasing"); - } + private double[] parseVectorLiteral(Object array) { + Preconditions.checkArgument(array != null, "The bin edges must not be null"); + double[] ret; + if (array instanceof int[]) { + int[] intArray = (int[]) array; + ret = new double[intArray.length]; + ArrayCopyUtils.copy(intArray, ret, intArray.length); + } else if (array instanceof long[]) { + long[] longArray = (long[]) array; + ret = new double[longArray.length]; + ArrayCopyUtils.copy(longArray, ret, longArray.length); + } else if (array instanceof float[]) { + float[] floatArray = (float[]) array; + ret = new double[floatArray.length]; + ArrayCopyUtils.copy(floatArray, ret, floatArray.length); + } else if (array instanceof double[]) { + ret = (double[]) array; + } else { + throw new IllegalArgumentException("Unsupported array type: " + array.getClass()); + } + Preconditions.checkArgument(ret.length > 1, "The number of bin edges must be greater than 1"); + for (int i = 1; i < ret.length; i++) { + Preconditions.checkArgument(ret[i] > ret[i - 1], "The bin edges must be strictly increasing"); } return ret; } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java index 19f9421a84c4..8edc7b49704f 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/query/executor/ServerQueryExecutorV1Impl.java @@ -41,6 +41,7 @@ import org.apache.pinot.common.request.context.FilterContext; import org.apache.pinot.common.request.context.FunctionContext; import org.apache.pinot.common.utils.config.QueryOptionsUtils; +import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.core.common.ExplainPlanRowData; import org.apache.pinot.core.common.ExplainPlanRows; import org.apache.pinot.core.common.Operator; @@ -72,7 +73,6 @@ import org.apache.pinot.segment.spi.MutableSegment; import org.apache.pinot.segment.spi.SegmentContext; import org.apache.pinot.segment.spi.SegmentMetadata; -import org.apache.pinot.spi.data.FieldSpec; import org.apache.pinot.spi.env.PinotConfiguration; import org.apache.pinot.spi.exception.BadQueryRequestException; import org.apache.pinot.spi.exception.QueryCancelledException; @@ -236,8 +236,8 @@ private InstanceResponseBlock executeInternal(ServerQueryRequest queryRequest, E if (indexTimeMs > 0) { minIndexTimeMs = Math.min(minIndexTimeMs, indexTimeMs); } - long ingestionTimeMs = ((RealtimeTableDataManager) - tableDataManager).getPartitionIngestionTimeMs(indexSegment.getSegmentName()); + long ingestionTimeMs = + ((RealtimeTableDataManager) tableDataManager).getPartitionIngestionTimeMs(indexSegment.getSegmentName()); if (ingestionTimeMs > 0) { minIngestionTimeMs = Math.min(minIngestionTimeMs, ingestionTimeMs); } @@ -602,8 +602,7 @@ private void handleSubquery(ExpressionContext expression, TableDataManager table result != null ? result.getClass().getSimpleName() : null); // Rewrite the expression function.setFunctionName(TransformFunctionType.IN_ID_SET.name()); - arguments.set(1, - ExpressionContext.forLiteral(FieldSpec.DataType.STRING, ((IdSet) result).toBase64String())); + arguments.set(1, ExpressionContext.forLiteral(RequestUtils.getLiteral(((IdSet) result).toBase64String()))); } else { for (ExpressionContext argument : arguments) { handleSubquery(argument, tableDataManager, indexSegments, timerContext, executorService, endTimeMs); diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java index 67992d4dfc4e..a20b2479d4f0 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/parser/CalciteRexExpressionParser.java @@ -26,9 +26,12 @@ import org.apache.pinot.common.request.Expression; import org.apache.pinot.common.request.Literal; import org.apache.pinot.common.request.PinotQuery; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.common.utils.request.RequestUtils; import org.apache.pinot.query.planner.logical.RexExpression; import org.apache.pinot.query.planner.plannode.SortNode; +import org.apache.pinot.spi.utils.BooleanUtils; +import org.apache.pinot.spi.utils.ByteArray; import org.apache.pinot.sql.parsers.ParserUtils; @@ -137,9 +140,19 @@ private static Expression inputRefToIdentifier(RexExpression.InputRef inputRef, public static Literal toLiteral(RexExpression.Literal literal) { Object value = literal.getValue(); + if (value == null) { + return RequestUtils.getNullLiteral(); + } // NOTE: Value is stored in internal format in RexExpression.Literal. - return value != null ? RequestUtils.getLiteral(literal.getDataType().toExternal(value)) - : RequestUtils.getNullLiteral(); + // Do not convert TIMESTAMP/BOOLEAN_ARRAY/TIMESTAMP_ARRAY to external format because they are not explicitly + // supported in single-stage engine Literal. + ColumnDataType dataType = literal.getDataType(); + if (dataType == ColumnDataType.BOOLEAN) { + value = BooleanUtils.isTrueInternalValue(value); + } else if (dataType == ColumnDataType.BYTES) { + value = ((ByteArray) value).getBytes(); + } + return RequestUtils.getLiteral(value); } private static Expression compileFunctionExpression(RexExpression.FunctionCall rexCall, PinotQuery pinotQuery) { diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java index b81177877fc0..d06ee0473a2c 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/logical/RexExpression.java @@ -18,6 +18,7 @@ */ package org.apache.pinot.query.planner.logical; +import java.util.Arrays; import java.util.List; import java.util.Objects; import org.apache.calcite.rex.RexNode; @@ -92,12 +93,12 @@ public boolean equals(Object o) { return false; } Literal literal = (Literal) o; - return _dataType == literal._dataType && Objects.equals(_value, literal._value); + return _dataType == literal._dataType && Objects.deepEquals(_value, literal._value); } @Override public int hashCode() { - return Objects.hash(_dataType, _value); + return Arrays.deepHashCode(new Object[]{_dataType, _value}); } } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java index 206f9dcd2ac7..e197276d751a 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/ProtoExpressionToRexExpression.java @@ -23,6 +23,7 @@ import org.apache.pinot.common.proto.Expressions; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.spi.utils.BigDecimalUtils; import org.apache.pinot.spi.utils.ByteArray; @@ -74,10 +75,67 @@ public static RexExpression.Literal convertLiteral(Expressions.Literal literal) return new RexExpression.Literal(dataType, literal.getFloat()); case DOUBLE: return new RexExpression.Literal(dataType, literal.getDouble()); + case BIG_DECIMAL: + return new RexExpression.Literal(dataType, BigDecimalUtils.deserialize(literal.getBytes().toByteArray())); case STRING: return new RexExpression.Literal(dataType, literal.getString()); case BYTES: return new RexExpression.Literal(dataType, new ByteArray(literal.getBytes().toByteArray())); + case INT_ARRAY: { + Expressions.IntArray intArray = literal.getIntArray(); + int numValues = intArray.getValuesCount(); + int[] values = new int[numValues]; + { + for (int i = 0; i < numValues; i++) { + values[i] = intArray.getValues(i); + } + } + return new RexExpression.Literal(dataType, values); + } + case LONG_ARRAY: { + Expressions.LongArray longArray = literal.getLongArray(); + int numValues = longArray.getValuesCount(); + long[] values = new long[numValues]; + { + for (int i = 0; i < numValues; i++) { + values[i] = longArray.getValues(i); + } + } + return new RexExpression.Literal(dataType, values); + } + case FLOAT_ARRAY: { + Expressions.FloatArray floatArray = literal.getFloatArray(); + int numValues = floatArray.getValuesCount(); + float[] values = new float[numValues]; + { + for (int i = 0; i < numValues; i++) { + values[i] = floatArray.getValues(i); + } + } + return new RexExpression.Literal(dataType, values); + } + case DOUBLE_ARRAY: { + Expressions.DoubleArray doubleArray = literal.getDoubleArray(); + int numValues = doubleArray.getValuesCount(); + double[] values = new double[numValues]; + { + for (int i = 0; i < numValues; i++) { + values[i] = doubleArray.getValues(i); + } + } + return new RexExpression.Literal(dataType, values); + } + case STRING_ARRAY: { + Expressions.StringArray stringArray = literal.getStringArray(); + int numValues = stringArray.getValuesCount(); + String[] values = new String[numValues]; + { + for (int i = 0; i < numValues; i++) { + values[i] = stringArray.getValues(i); + } + } + return new RexExpression.Literal(dataType, values); + } default: throw new IllegalStateException("Unsupported ColumnDataType: " + dataType); } diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java index 0ff66c0c389e..0350d8ba8ca6 100644 --- a/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java +++ b/pinot-query-planner/src/main/java/org/apache/pinot/query/planner/serde/RexExpressionToProtoExpression.java @@ -19,8 +19,13 @@ package org.apache.pinot.query.planner.serde; import com.google.protobuf.ByteString; +import it.unimi.dsi.fastutil.doubles.DoubleArrayList; +import it.unimi.dsi.fastutil.floats.FloatArrayList; +import it.unimi.dsi.fastutil.ints.IntArrayList; +import it.unimi.dsi.fastutil.longs.LongArrayList; import java.math.BigDecimal; import java.util.ArrayList; +import java.util.Arrays; import java.util.List; import org.apache.pinot.common.proto.Expressions; import org.apache.pinot.common.utils.DataSchema.ColumnDataType; @@ -94,6 +99,26 @@ public static Expressions.Literal convertLiteral(RexExpression.Literal literal) case BYTES: literalBuilder.setBytes(ByteString.copyFrom(((ByteArray) value).getBytes())); break; + case INT_ARRAY: + literalBuilder.setIntArray( + Expressions.IntArray.newBuilder().addAllValues(IntArrayList.wrap((int[]) value)).build()); + break; + case LONG_ARRAY: + literalBuilder.setLongArray( + Expressions.LongArray.newBuilder().addAllValues(LongArrayList.wrap((long[]) value)).build()); + break; + case FLOAT_ARRAY: + literalBuilder.setFloatArray( + Expressions.FloatArray.newBuilder().addAllValues(FloatArrayList.wrap((float[]) value)).build()); + break; + case DOUBLE_ARRAY: + literalBuilder.setDoubleArray( + Expressions.DoubleArray.newBuilder().addAllValues(DoubleArrayList.wrap((double[]) value)).build()); + break; + case STRING_ARRAY: + literalBuilder.setStringArray( + Expressions.StringArray.newBuilder().addAllValues(Arrays.asList((String[]) value)).build()); + break; default: throw new IllegalStateException("Unsupported ColumnDataType: " + dataType); } diff --git a/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java new file mode 100644 index 000000000000..b933f5c99080 --- /dev/null +++ b/pinot-query-planner/src/test/java/org/apache/pinot/query/planner/serde/RexExpressionSerDeTest.java @@ -0,0 +1,165 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.pinot.query.planner.serde; + +import java.math.BigDecimal; +import java.util.List; +import java.util.Random; +import org.apache.commons.lang.RandomStringUtils; +import org.apache.pinot.common.utils.DataSchema.ColumnDataType; +import org.apache.pinot.query.planner.logical.RexExpression; +import org.apache.pinot.spi.utils.BooleanUtils; +import org.apache.pinot.spi.utils.ByteArray; +import org.testng.annotations.Test; + +import static org.testng.Assert.assertEquals; + + +public class RexExpressionSerDeTest { + private static final List SUPPORTED_DATE_TYPES = + List.of(ColumnDataType.INT, ColumnDataType.LONG, ColumnDataType.FLOAT, ColumnDataType.DOUBLE, + ColumnDataType.BIG_DECIMAL, ColumnDataType.BOOLEAN, ColumnDataType.TIMESTAMP, ColumnDataType.STRING, + ColumnDataType.BYTES, ColumnDataType.INT_ARRAY, ColumnDataType.LONG_ARRAY, ColumnDataType.FLOAT_ARRAY, + ColumnDataType.DOUBLE_ARRAY, ColumnDataType.BOOLEAN_ARRAY, ColumnDataType.TIMESTAMP_ARRAY, + ColumnDataType.STRING_ARRAY, ColumnDataType.UNKNOWN); + private static final Random RANDOM = new Random(); + + @Test + public void testNullLiteral() { + for (ColumnDataType dataType : SUPPORTED_DATE_TYPES) { + verifyLiteralSerDe(new RexExpression.Literal(dataType, null)); + } + } + + @Test + public void testIntLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.INT, RANDOM.nextInt())); + } + + @Test + public void testLongLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.LONG, RANDOM.nextLong())); + } + + @Test + public void testFloatLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.FLOAT, RANDOM.nextFloat())); + } + + @Test + public void testDoubleLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.DOUBLE, RANDOM.nextDouble())); + } + + @Test + public void testBigDecimalLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BIG_DECIMAL, + RANDOM.nextBoolean() ? BigDecimal.valueOf(RANDOM.nextLong()) : BigDecimal.valueOf(RANDOM.nextDouble()))); + } + + @Test + public void testBooleanLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BOOLEAN, BooleanUtils.toInt(RANDOM.nextBoolean()))); + } + + @Test + public void testTimestampLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.TIMESTAMP, RANDOM.nextLong())); + } + + @Test + public void testStringLiteral() { + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.STRING, RandomStringUtils.random(RANDOM.nextInt(10)))); + } + + @Test + public void testBytesLiteral() { + byte[] bytes = new byte[RANDOM.nextInt(10)]; + RANDOM.nextBytes(bytes); + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BYTES, new ByteArray(bytes))); + } + + @Test + public void testIntArrayLiteral() { + int[] values = new int[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RANDOM.nextInt(); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.INT_ARRAY, values)); + } + + @Test + public void testLongArrayLiteral() { + long[] values = new long[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RANDOM.nextLong(); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.LONG_ARRAY, values)); + } + + @Test + public void testFloatArrayLiteral() { + float[] values = new float[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RANDOM.nextFloat(); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.FLOAT_ARRAY, values)); + } + + @Test + public void testDoubleArrayLiteral() { + double[] values = new double[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RANDOM.nextDouble(); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.DOUBLE_ARRAY, values)); + } + + @Test + public void testBooleanArrayLiteral() { + int[] values = new int[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = BooleanUtils.toInt(RANDOM.nextBoolean()); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.BOOLEAN_ARRAY, values)); + } + + @Test + public void testTimestampArrayLiteral() { + long[] values = new long[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RANDOM.nextLong(); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.TIMESTAMP_ARRAY, values)); + } + + @Test + public void testStringArrayLiteral() { + String[] values = new String[RANDOM.nextInt(10)]; + for (int i = 0; i < values.length; i++) { + values[i] = RandomStringUtils.random(RANDOM.nextInt(10)); + } + verifyLiteralSerDe(new RexExpression.Literal(ColumnDataType.STRING_ARRAY, values)); + } + + private void verifyLiteralSerDe(RexExpression.Literal literal) { + assertEquals(literal, + ProtoExpressionToRexExpression.convertLiteral(RexExpressionToProtoExpression.convertLiteral(literal))); + } +}