Skip to content

Commit

Permalink
use dictionary for BinaryOperatorTransformFunction when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Dec 13, 2024
1 parent 607bed2 commit ff991bf
Show file tree
Hide file tree
Showing 2 changed files with 207 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -140,4 +140,20 @@ public String toString() {
return "(" + _lhs + (_lowerInclusive ? " >= '" : " > '") + _lowerBound + "' AND " + _lhs + (_upperInclusive
? " <= '" : " < '") + _upperBound + "')";
}

public static String getGreatRange(String value) {
return LOWER_EXCLUSIVE + value + DELIMITER + UNBOUNDED + UPPER_EXCLUSIVE;
}

public static String getLessRange(String value) {
return LOWER_EXCLUSIVE + UNBOUNDED + DELIMITER + value + UPPER_EXCLUSIVE;
}

public static String getGreatEqualRange(String value) {
return LOWER_INCLUSIVE + value + DELIMITER + UNBOUNDED + UPPER_EXCLUSIVE;
}

public static String getLessEqualRange(String value) {
return LOWER_EXCLUSIVE + UNBOUNDED + DELIMITER + value + UPPER_INCLUSIVE;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,22 @@
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.predicate.EqPredicate;
import org.apache.pinot.common.request.context.predicate.NotEqPredicate;
import org.apache.pinot.common.request.context.predicate.RangePredicate;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.filter.predicate.EqualsPredicateEvaluatorFactory;
import org.apache.pinot.core.operator.filter.predicate.NotEqualsPredicateEvaluatorFactory;
import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
import org.apache.pinot.core.operator.filter.predicate.RangePredicateEvaluatorFactory;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.ByteArray;
import org.apache.pinot.spi.utils.BytesUtils;
import org.apache.pinot.spi.utils.CommonConstants;
import org.roaringbitmap.RoaringBitmap;


/**
Expand All @@ -48,6 +59,8 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct
protected TransformFunction _rightTransformFunction;
protected DataType _leftStoredType;
protected DataType _rightStoredType;
protected PredicateEvaluator _predicateEvaluator;
protected boolean _isNull;

protected BinaryOperatorTransformFunction(TransformFunctionType transformFunctionType) {
// translate to integer in [0, 5] for guaranteed tableswitch
Expand Down Expand Up @@ -91,6 +104,91 @@ public void init(List<TransformFunction> arguments, Map<String, ColumnContext> c
_rightTransformFunction = arguments.get(1);
_leftStoredType = _leftTransformFunction.getResultMetadata().getDataType().getStoredType();
_rightStoredType = _rightTransformFunction.getResultMetadata().getDataType().getStoredType();

if (_leftTransformFunction instanceof IdentifierTransformFunction
&& _rightTransformFunction instanceof LiteralTransformFunction) {
IdentifierTransformFunction leftTransformFunction = (IdentifierTransformFunction) _leftTransformFunction;
if (leftTransformFunction.getDictionary() != null) {
LiteralTransformFunction rightTransformFunction = (LiteralTransformFunction) _rightTransformFunction;
if (rightTransformFunction.isNull()) {
_isNull = true;
}
String rightLiteralStr;
switch (_leftStoredType) {
case INT:
rightLiteralStr = Integer.toString(_isNull ? CommonConstants.NullValuePlaceHolder.INT
: new BigDecimal(rightTransformFunction.getStringLiteral()).intValue());
break;
case LONG:
rightLiteralStr = Long.toString(_isNull ? CommonConstants.NullValuePlaceHolder.LONG
: new BigDecimal(rightTransformFunction.getStringLiteral()).longValue());
break;
case FLOAT:
rightLiteralStr = Float.toString(_isNull ? CommonConstants.NullValuePlaceHolder.FLOAT
: new BigDecimal(rightTransformFunction.getStringLiteral()).floatValue());
break;
case DOUBLE:
rightLiteralStr =
Double.toString(_isNull ? CommonConstants.NullValuePlaceHolder.DOUBLE
: new BigDecimal(rightTransformFunction.getStringLiteral()).doubleValue());
break;
case BIG_DECIMAL:
rightLiteralStr = _isNull ? CommonConstants.NullValuePlaceHolder.BIG_DECIMAL.toString()
: rightTransformFunction.getStringLiteral();
break;
case STRING:
rightLiteralStr =
_isNull ? CommonConstants.NullValuePlaceHolder.STRING : rightTransformFunction.getStringLiteral();
break;
case BYTES:
rightLiteralStr = _isNull ? BytesUtils.toHexString(CommonConstants.NullValuePlaceHolder.BYTES)
: rightTransformFunction.getStringLiteral();
break;
default:
throw new IllegalStateException(
"Unsupported data type for dictionary based predicate: " + _leftStoredType);
}
switch (_op) {
case EQUALS:
_predicateEvaluator = EqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new EqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case NOT_EQUAL:
_predicateEvaluator = NotEqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new NotEqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case GREATER_THAN_OR_EQUAL:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
RangePredicate.getGreatEqualRange(rightLiteralStr)), leftTransformFunction.getDictionary(),
_leftStoredType);
break;
case GREATER_THAN:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
RangePredicate.getGreatRange(rightLiteralStr)), leftTransformFunction.getDictionary(),
_leftStoredType);
break;
case LESS_THAN:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
RangePredicate.getLessRange(rightLiteralStr)), leftTransformFunction.getDictionary(),
_leftStoredType);
break;
case LESS_THAN_OR_EQUAL:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
RangePredicate.getLessEqualRange(rightLiteralStr)), leftTransformFunction.getDictionary(),
_leftStoredType);
break;
default:
throw new IllegalStateException("Unsupported operation for dictionary based predicate: " + _op);
}
}
}

// Data type check: left and right types should be compatible.
if (_leftStoredType == DataType.BYTES || _rightStoredType == DataType.BYTES) {
Preconditions.checkState(_leftStoredType == _rightStoredType, String.format(
Expand All @@ -114,39 +212,68 @@ public int[] transformToIntValuesSV(ValueBlock valueBlock) {
private void fillResultArray(ValueBlock valueBlock) {
int length = valueBlock.getNumDocs();
initIntValuesSV(length);
switch (_leftStoredType) {
case INT:
fillResultInt(valueBlock, length);
break;
case LONG:
fillResultLong(valueBlock, length);
break;
case FLOAT:
fillResultFloat(valueBlock, length);
break;
case DOUBLE:
fillResultDouble(valueBlock, length);
break;
case BIG_DECIMAL:
fillResultBigDecimal(valueBlock, length);
break;
case STRING:
fillResultString(valueBlock, length);
break;
case BYTES:
fillResultBytes(valueBlock, length);
break;
case UNKNOWN:
fillResultUnknown(length);
break;
// NOTE: Multi-value columns are not comparable, so we should not reach here
default:
throw illegalState();
if (_isNull) {
// If nullBitMap exists, then use it to fill the result
RoaringBitmap nullBitmap = _leftTransformFunction.getNullBitmap(valueBlock);
if (nullBitmap != null) {
if (_op == EQUALS) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = nullBitmap.contains(i) ? 1 : 0;
}
} else {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = nullBitmap.contains(i) ? 0 : 1;
}
}
}
return;
}
if (_leftTransformFunction.getDictionary() != null && _predicateEvaluator != null) {
int[] dictIds = _leftTransformFunction.transformToDictIdsSV(valueBlock);
for (int i = 0; i < dictIds.length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(dictIds[i]) ? 1 : 0;
}
} else {
switch (_leftStoredType) {
case INT:
fillResultInt(valueBlock, length);
break;
case LONG:
fillResultLong(valueBlock, length);
break;
case FLOAT:
fillResultFloat(valueBlock, length);
break;
case DOUBLE:
fillResultDouble(valueBlock, length);
break;
case BIG_DECIMAL:
fillResultBigDecimal(valueBlock, length);
break;
case STRING:
fillResultString(valueBlock, length);
break;
case BYTES:
fillResultBytes(valueBlock, length);
break;
case UNKNOWN:
fillResultUnknown(length);
break;
// NOTE: Multi-value columns are not comparable, so we should not reach here
default:
throw illegalState();
}
}
}

private void fillResultInt(ValueBlock valueBlock, int length) {
int[] leftIntValues = _leftTransformFunction.transformToIntValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftIntValues[i]) ? 1 : 0;
}
return;
}
switch (_rightStoredType) {
case INT:
fillIntResultArray(valueBlock, leftIntValues, length);
Expand Down Expand Up @@ -176,6 +303,12 @@ private void fillResultInt(ValueBlock valueBlock, int length) {

private void fillResultLong(ValueBlock valueBlock, int length) {
long[] leftLongValues = _leftTransformFunction.transformToLongValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftLongValues[i]) ? 1 : 0;
}
return;
}
switch (_rightStoredType) {
case INT:
fillIntResultArray(valueBlock, leftLongValues, length);
Expand Down Expand Up @@ -205,6 +338,12 @@ private void fillResultLong(ValueBlock valueBlock, int length) {

private void fillResultFloat(ValueBlock valueBlock, int length) {
float[] leftFloatValues = _leftTransformFunction.transformToFloatValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftFloatValues[i]) ? 1 : 0;
}
return;
}
switch (_rightStoredType) {
case INT:
fillIntResultArray(valueBlock, leftFloatValues, length);
Expand Down Expand Up @@ -234,6 +373,12 @@ private void fillResultFloat(ValueBlock valueBlock, int length) {

private void fillResultDouble(ValueBlock valueBlock, int length) {
double[] leftDoubleValues = _leftTransformFunction.transformToDoubleValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftDoubleValues[i]) ? 1 : 0;
}
return;
}
switch (_rightStoredType) {
case INT:
fillIntResultArray(valueBlock, leftDoubleValues, length);
Expand Down Expand Up @@ -263,6 +408,12 @@ private void fillResultDouble(ValueBlock valueBlock, int length) {

private void fillResultBigDecimal(ValueBlock valueBlock, int length) {
BigDecimal[] leftBigDecimalValues = _leftTransformFunction.transformToBigDecimalValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftBigDecimalValues[i]) ? 1 : 0;
}
return;
}
switch (_rightStoredType) {
case INT:
fillIntResultArray(valueBlock, leftBigDecimalValues, length);
Expand Down Expand Up @@ -299,6 +450,12 @@ private IllegalStateException illegalState() {

private void fillResultString(ValueBlock valueBlock, int length) {
String[] leftStringValues = _leftTransformFunction.transformToStringValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftStringValues[i]) ? 1 : 0;
}
return;
}
String[] rightStringValues = _rightTransformFunction.transformToStringValuesSV(valueBlock);
for (int i = 0; i < length; i++) {
_intValuesSV[i] = getIntResult(leftStringValues[i].compareTo(rightStringValues[i]));
Expand All @@ -307,6 +464,12 @@ private void fillResultString(ValueBlock valueBlock, int length) {

private void fillResultBytes(ValueBlock valueBlock, int length) {
byte[][] leftBytesValues = _leftTransformFunction.transformToBytesValuesSV(valueBlock);
if (_predicateEvaluator != null) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(leftBytesValues[i]) ? 1 : 0;
}
return;
}
byte[][] rightBytesValues = _rightTransformFunction.transformToBytesValuesSV(valueBlock);
for (int i = 0; i < length; i++) {
_intValuesSV[i] = getIntResult((ByteArray.compare(leftBytesValues[i], rightBytesValues[i])));
Expand Down

0 comments on commit ff991bf

Please sign in to comment.