From ff991bf8742545f72443bd8fa96d914bd775eba8 Mon Sep 17 00:00:00 2001 From: Xiang Fu Date: Fri, 6 Dec 2024 11:33:31 -0800 Subject: [PATCH] use dictionary for BinaryOperatorTransformFunction when possible --- .../context/predicate/RangePredicate.java | 16 ++ .../BinaryOperatorTransformFunction.java | 219 +++++++++++++++--- 2 files changed, 207 insertions(+), 28 deletions(-) diff --git a/pinot-common/src/main/java/org/apache/pinot/common/request/context/predicate/RangePredicate.java b/pinot-common/src/main/java/org/apache/pinot/common/request/context/predicate/RangePredicate.java index 61b5ebbe9191..0e08d00e3726 100644 --- a/pinot-common/src/main/java/org/apache/pinot/common/request/context/predicate/RangePredicate.java +++ b/pinot-common/src/main/java/org/apache/pinot/common/request/context/predicate/RangePredicate.java @@ -140,4 +140,20 @@ public String toString() { return "(" + _lhs + (_lowerInclusive ? " >= '" : " > '") + _lowerBound + "' AND " + _lhs + (_upperInclusive ? " <= '" : " < '") + _upperBound + "')"; } + + public static String getGreatRange(String value) { + return LOWER_EXCLUSIVE + value + DELIMITER + UNBOUNDED + UPPER_EXCLUSIVE; + } + + public static String getLessRange(String value) { + return LOWER_EXCLUSIVE + UNBOUNDED + DELIMITER + value + UPPER_EXCLUSIVE; + } + + public static String getGreatEqualRange(String value) { + return LOWER_INCLUSIVE + value + DELIMITER + UNBOUNDED + UPPER_EXCLUSIVE; + } + + public static String getLessEqualRange(String value) { + return LOWER_EXCLUSIVE + UNBOUNDED + DELIMITER + value + UPPER_INCLUSIVE; + } } diff --git a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java index 24850868e9df..24161147a170 100644 --- a/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java +++ b/pinot-core/src/main/java/org/apache/pinot/core/operator/transform/function/BinaryOperatorTransformFunction.java @@ -23,11 +23,22 @@ import java.util.List; import java.util.Map; import org.apache.pinot.common.function.TransformFunctionType; +import org.apache.pinot.common.request.context.ExpressionContext; +import org.apache.pinot.common.request.context.predicate.EqPredicate; +import org.apache.pinot.common.request.context.predicate.NotEqPredicate; +import org.apache.pinot.common.request.context.predicate.RangePredicate; import org.apache.pinot.core.operator.ColumnContext; import org.apache.pinot.core.operator.blocks.ValueBlock; +import org.apache.pinot.core.operator.filter.predicate.EqualsPredicateEvaluatorFactory; +import org.apache.pinot.core.operator.filter.predicate.NotEqualsPredicateEvaluatorFactory; +import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator; +import org.apache.pinot.core.operator.filter.predicate.RangePredicateEvaluatorFactory; import org.apache.pinot.core.operator.transform.TransformResultMetadata; import org.apache.pinot.spi.data.FieldSpec.DataType; import org.apache.pinot.spi.utils.ByteArray; +import org.apache.pinot.spi.utils.BytesUtils; +import org.apache.pinot.spi.utils.CommonConstants; +import org.roaringbitmap.RoaringBitmap; /** @@ -48,6 +59,8 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct protected TransformFunction _rightTransformFunction; protected DataType _leftStoredType; protected DataType _rightStoredType; + protected PredicateEvaluator _predicateEvaluator; + protected boolean _isNull; protected BinaryOperatorTransformFunction(TransformFunctionType transformFunctionType) { // translate to integer in [0, 5] for guaranteed tableswitch @@ -91,6 +104,91 @@ public void init(List arguments, Map c _rightTransformFunction = arguments.get(1); _leftStoredType = _leftTransformFunction.getResultMetadata().getDataType().getStoredType(); _rightStoredType = _rightTransformFunction.getResultMetadata().getDataType().getStoredType(); + + if (_leftTransformFunction instanceof IdentifierTransformFunction + && _rightTransformFunction instanceof LiteralTransformFunction) { + IdentifierTransformFunction leftTransformFunction = (IdentifierTransformFunction) _leftTransformFunction; + if (leftTransformFunction.getDictionary() != null) { + LiteralTransformFunction rightTransformFunction = (LiteralTransformFunction) _rightTransformFunction; + if (rightTransformFunction.isNull()) { + _isNull = true; + } + String rightLiteralStr; + switch (_leftStoredType) { + case INT: + rightLiteralStr = Integer.toString(_isNull ? CommonConstants.NullValuePlaceHolder.INT + : new BigDecimal(rightTransformFunction.getStringLiteral()).intValue()); + break; + case LONG: + rightLiteralStr = Long.toString(_isNull ? CommonConstants.NullValuePlaceHolder.LONG + : new BigDecimal(rightTransformFunction.getStringLiteral()).longValue()); + break; + case FLOAT: + rightLiteralStr = Float.toString(_isNull ? CommonConstants.NullValuePlaceHolder.FLOAT + : new BigDecimal(rightTransformFunction.getStringLiteral()).floatValue()); + break; + case DOUBLE: + rightLiteralStr = + Double.toString(_isNull ? CommonConstants.NullValuePlaceHolder.DOUBLE + : new BigDecimal(rightTransformFunction.getStringLiteral()).doubleValue()); + break; + case BIG_DECIMAL: + rightLiteralStr = _isNull ? CommonConstants.NullValuePlaceHolder.BIG_DECIMAL.toString() + : rightTransformFunction.getStringLiteral(); + break; + case STRING: + rightLiteralStr = + _isNull ? CommonConstants.NullValuePlaceHolder.STRING : rightTransformFunction.getStringLiteral(); + break; + case BYTES: + rightLiteralStr = _isNull ? BytesUtils.toHexString(CommonConstants.NullValuePlaceHolder.BYTES) + : rightTransformFunction.getStringLiteral(); + break; + default: + throw new IllegalStateException( + "Unsupported data type for dictionary based predicate: " + _leftStoredType); + } + switch (_op) { + case EQUALS: + _predicateEvaluator = EqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new EqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType); + break; + case NOT_EQUAL: + _predicateEvaluator = NotEqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new NotEqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType); + break; + case GREATER_THAN_OR_EQUAL: + _predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + RangePredicate.getGreatEqualRange(rightLiteralStr)), leftTransformFunction.getDictionary(), + _leftStoredType); + break; + case GREATER_THAN: + _predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + RangePredicate.getGreatRange(rightLiteralStr)), leftTransformFunction.getDictionary(), + _leftStoredType); + break; + case LESS_THAN: + _predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + RangePredicate.getLessRange(rightLiteralStr)), leftTransformFunction.getDictionary(), + _leftStoredType); + break; + case LESS_THAN_OR_EQUAL: + _predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator( + new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()), + RangePredicate.getLessEqualRange(rightLiteralStr)), leftTransformFunction.getDictionary(), + _leftStoredType); + break; + default: + throw new IllegalStateException("Unsupported operation for dictionary based predicate: " + _op); + } + } + } + // Data type check: left and right types should be compatible. if (_leftStoredType == DataType.BYTES || _rightStoredType == DataType.BYTES) { Preconditions.checkState(_leftStoredType == _rightStoredType, String.format( @@ -114,39 +212,68 @@ public int[] transformToIntValuesSV(ValueBlock valueBlock) { private void fillResultArray(ValueBlock valueBlock) { int length = valueBlock.getNumDocs(); initIntValuesSV(length); - switch (_leftStoredType) { - case INT: - fillResultInt(valueBlock, length); - break; - case LONG: - fillResultLong(valueBlock, length); - break; - case FLOAT: - fillResultFloat(valueBlock, length); - break; - case DOUBLE: - fillResultDouble(valueBlock, length); - break; - case BIG_DECIMAL: - fillResultBigDecimal(valueBlock, length); - break; - case STRING: - fillResultString(valueBlock, length); - break; - case BYTES: - fillResultBytes(valueBlock, length); - break; - case UNKNOWN: - fillResultUnknown(length); - break; - // NOTE: Multi-value columns are not comparable, so we should not reach here - default: - throw illegalState(); + if (_isNull) { + // If nullBitMap exists, then use it to fill the result + RoaringBitmap nullBitmap = _leftTransformFunction.getNullBitmap(valueBlock); + if (nullBitmap != null) { + if (_op == EQUALS) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = nullBitmap.contains(i) ? 1 : 0; + } + } else { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = nullBitmap.contains(i) ? 0 : 1; + } + } + } + return; + } + if (_leftTransformFunction.getDictionary() != null && _predicateEvaluator != null) { + int[] dictIds = _leftTransformFunction.transformToDictIdsSV(valueBlock); + for (int i = 0; i < dictIds.length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(dictIds[i]) ? 1 : 0; + } + } else { + switch (_leftStoredType) { + case INT: + fillResultInt(valueBlock, length); + break; + case LONG: + fillResultLong(valueBlock, length); + break; + case FLOAT: + fillResultFloat(valueBlock, length); + break; + case DOUBLE: + fillResultDouble(valueBlock, length); + break; + case BIG_DECIMAL: + fillResultBigDecimal(valueBlock, length); + break; + case STRING: + fillResultString(valueBlock, length); + break; + case BYTES: + fillResultBytes(valueBlock, length); + break; + case UNKNOWN: + fillResultUnknown(length); + break; + // NOTE: Multi-value columns are not comparable, so we should not reach here + default: + throw illegalState(); + } } } private void fillResultInt(ValueBlock valueBlock, int length) { int[] leftIntValues = _leftTransformFunction.transformToIntValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftIntValues[i]) ? 1 : 0; + } + return; + } switch (_rightStoredType) { case INT: fillIntResultArray(valueBlock, leftIntValues, length); @@ -176,6 +303,12 @@ private void fillResultInt(ValueBlock valueBlock, int length) { private void fillResultLong(ValueBlock valueBlock, int length) { long[] leftLongValues = _leftTransformFunction.transformToLongValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftLongValues[i]) ? 1 : 0; + } + return; + } switch (_rightStoredType) { case INT: fillIntResultArray(valueBlock, leftLongValues, length); @@ -205,6 +338,12 @@ private void fillResultLong(ValueBlock valueBlock, int length) { private void fillResultFloat(ValueBlock valueBlock, int length) { float[] leftFloatValues = _leftTransformFunction.transformToFloatValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftFloatValues[i]) ? 1 : 0; + } + return; + } switch (_rightStoredType) { case INT: fillIntResultArray(valueBlock, leftFloatValues, length); @@ -234,6 +373,12 @@ private void fillResultFloat(ValueBlock valueBlock, int length) { private void fillResultDouble(ValueBlock valueBlock, int length) { double[] leftDoubleValues = _leftTransformFunction.transformToDoubleValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftDoubleValues[i]) ? 1 : 0; + } + return; + } switch (_rightStoredType) { case INT: fillIntResultArray(valueBlock, leftDoubleValues, length); @@ -263,6 +408,12 @@ private void fillResultDouble(ValueBlock valueBlock, int length) { private void fillResultBigDecimal(ValueBlock valueBlock, int length) { BigDecimal[] leftBigDecimalValues = _leftTransformFunction.transformToBigDecimalValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftBigDecimalValues[i]) ? 1 : 0; + } + return; + } switch (_rightStoredType) { case INT: fillIntResultArray(valueBlock, leftBigDecimalValues, length); @@ -299,6 +450,12 @@ private IllegalStateException illegalState() { private void fillResultString(ValueBlock valueBlock, int length) { String[] leftStringValues = _leftTransformFunction.transformToStringValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftStringValues[i]) ? 1 : 0; + } + return; + } String[] rightStringValues = _rightTransformFunction.transformToStringValuesSV(valueBlock); for (int i = 0; i < length; i++) { _intValuesSV[i] = getIntResult(leftStringValues[i].compareTo(rightStringValues[i])); @@ -307,6 +464,12 @@ private void fillResultString(ValueBlock valueBlock, int length) { private void fillResultBytes(ValueBlock valueBlock, int length) { byte[][] leftBytesValues = _leftTransformFunction.transformToBytesValuesSV(valueBlock); + if (_predicateEvaluator != null) { + for (int i = 0; i < length; i++) { + _intValuesSV[i] = _predicateEvaluator.applySV(leftBytesValues[i]) ? 1 : 0; + } + return; + } byte[][] rightBytesValues = _rightTransformFunction.transformToBytesValuesSV(valueBlock); for (int i = 0; i < length; i++) { _intValuesSV[i] = getIntResult((ByteArray.compare(leftBytesValues[i], rightBytesValues[i])));