Skip to content

Commit

Permalink
use dictionary for BinaryOperatorTransformFunction when possible
Browse files Browse the repository at this point in the history
  • Loading branch information
xiangfu0 committed Dec 11, 2024
1 parent 740c110 commit 9778fbb
Showing 1 changed file with 146 additions and 28 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,22 @@
import java.util.List;
import java.util.Map;
import org.apache.pinot.common.function.TransformFunctionType;
import org.apache.pinot.common.request.context.ExpressionContext;
import org.apache.pinot.common.request.context.predicate.EqPredicate;
import org.apache.pinot.common.request.context.predicate.NotEqPredicate;
import org.apache.pinot.common.request.context.predicate.RangePredicate;
import org.apache.pinot.core.operator.ColumnContext;
import org.apache.pinot.core.operator.blocks.ValueBlock;
import org.apache.pinot.core.operator.filter.predicate.EqualsPredicateEvaluatorFactory;
import org.apache.pinot.core.operator.filter.predicate.NotEqualsPredicateEvaluatorFactory;
import org.apache.pinot.core.operator.filter.predicate.PredicateEvaluator;
import org.apache.pinot.core.operator.filter.predicate.RangePredicateEvaluatorFactory;
import org.apache.pinot.core.operator.transform.TransformResultMetadata;
import org.apache.pinot.spi.data.FieldSpec.DataType;
import org.apache.pinot.spi.utils.ByteArray;
import org.apache.pinot.spi.utils.BytesUtils;
import org.apache.pinot.spi.utils.CommonConstants;
import org.roaringbitmap.RoaringBitmap;


/**
Expand All @@ -48,6 +59,9 @@ public abstract class BinaryOperatorTransformFunction extends BaseTransformFunct
protected TransformFunction _rightTransformFunction;
protected DataType _leftStoredType;
protected DataType _rightStoredType;
protected boolean _useDictionary = false;
protected PredicateEvaluator _predicateEvaluator;
protected boolean _isNull;

protected BinaryOperatorTransformFunction(TransformFunctionType transformFunctionType) {
// translate to integer in [0, 5] for guaranteed tableswitch
Expand Down Expand Up @@ -91,6 +105,87 @@ public void init(List<TransformFunction> arguments, Map<String, ColumnContext> c
_rightTransformFunction = arguments.get(1);
_leftStoredType = _leftTransformFunction.getResultMetadata().getDataType().getStoredType();
_rightStoredType = _rightTransformFunction.getResultMetadata().getDataType().getStoredType();

if (_leftTransformFunction instanceof IdentifierTransformFunction
&& _rightTransformFunction instanceof LiteralTransformFunction) {
IdentifierTransformFunction leftTransformFunction = (IdentifierTransformFunction) _leftTransformFunction;
if (leftTransformFunction.getDictionary() != null) {
LiteralTransformFunction rightTransformFunction = (LiteralTransformFunction) _rightTransformFunction;
if (rightTransformFunction.isNull()) {
_isNull = true;
}
String rightLiteralStr;
switch (_leftStoredType) {
case INT:
rightLiteralStr = Integer.toString(_isNull ? CommonConstants.NullValuePlaceHolder.INT
: new BigDecimal(rightTransformFunction.getStringLiteral()).intValue());
break;
case LONG:
rightLiteralStr = Long.toString(_isNull ? CommonConstants.NullValuePlaceHolder.LONG
: new BigDecimal(rightTransformFunction.getStringLiteral()).longValue());
break;
case FLOAT:
rightLiteralStr = Float.toString(_isNull ? CommonConstants.NullValuePlaceHolder.FLOAT
: new BigDecimal(rightTransformFunction.getStringLiteral()).floatValue());
break;
case DOUBLE:
rightLiteralStr =
Double.toString(_isNull ? CommonConstants.NullValuePlaceHolder.DOUBLE
: new BigDecimal(rightTransformFunction.getStringLiteral()).doubleValue());
break;
case BIG_DECIMAL:
rightLiteralStr = _isNull ? CommonConstants.NullValuePlaceHolder.BIG_DECIMAL.toString()
: rightTransformFunction.getStringLiteral();
break;
case STRING:
rightLiteralStr =
_isNull ? CommonConstants.NullValuePlaceHolder.STRING : rightTransformFunction.getStringLiteral();
break;
case BYTES:
rightLiteralStr = _isNull ? BytesUtils.toHexString(CommonConstants.NullValuePlaceHolder.BYTES)
: rightTransformFunction.getStringLiteral();
break;
default:
throw new IllegalStateException(
"Unsupported data type for dictionary based predicate: " + _leftStoredType);
}
switch (_op) {
case EQUALS:
_predicateEvaluator = EqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new EqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case NOT_EQUAL:
_predicateEvaluator = NotEqualsPredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new NotEqPredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
rightLiteralStr), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case GREATER_THAN_OR_EQUAL:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
"[" + rightLiteralStr + "\0*)"), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case GREATER_THAN:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
"(" + rightLiteralStr + "\0*)"), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case LESS_THAN:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
"(*\0" + rightLiteralStr + "]"), leftTransformFunction.getDictionary(), _leftStoredType);
break;
case LESS_THAN_OR_EQUAL:
_predicateEvaluator = RangePredicateEvaluatorFactory.newDictionaryBasedEvaluator(
new RangePredicate(ExpressionContext.forIdentifier(leftTransformFunction.getColumnName()),
"(*\0" + rightLiteralStr + ")"), leftTransformFunction.getDictionary(), _leftStoredType);
break;
default:
throw new IllegalStateException("Unsupported operation for dictionary based predicate: " + _op);
}
}
}

// Data type check: left and right types should be compatible.
if (_leftStoredType == DataType.BYTES || _rightStoredType == DataType.BYTES) {
Preconditions.checkState(_leftStoredType == _rightStoredType, String.format(
Expand All @@ -114,34 +209,57 @@ public int[] transformToIntValuesSV(ValueBlock valueBlock) {
private void fillResultArray(ValueBlock valueBlock) {
int length = valueBlock.getNumDocs();
initIntValuesSV(length);
switch (_leftStoredType) {
case INT:
fillResultInt(valueBlock, length);
break;
case LONG:
fillResultLong(valueBlock, length);
break;
case FLOAT:
fillResultFloat(valueBlock, length);
break;
case DOUBLE:
fillResultDouble(valueBlock, length);
break;
case BIG_DECIMAL:
fillResultBigDecimal(valueBlock, length);
break;
case STRING:
fillResultString(valueBlock, length);
break;
case BYTES:
fillResultBytes(valueBlock, length);
break;
case UNKNOWN:
fillResultUnknown(length);
break;
// NOTE: Multi-value columns are not comparable, so we should not reach here
default:
throw illegalState();
if (_isNull) {
// If nullBitMap exists, then use it to fill the result
RoaringBitmap nullBitmap = _leftTransformFunction.getNullBitmap(valueBlock);
if (nullBitmap != null) {
if (_op == EQUALS) {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = nullBitmap.contains(i) ? 1 : 0;
}
} else {
for (int i = 0; i < length; i++) {
_intValuesSV[i] = nullBitmap.contains(i) ? 0 : 1;
}
}
}
return;
}
if (_useDictionary) {
int[] dictIds = _leftTransformFunction.transformToDictIdsSV(valueBlock);
for (int i = 0; i < dictIds.length; i++) {
_intValuesSV[i] = _predicateEvaluator.applySV(dictIds[i]) ? 1 : 0;
}
} else {
switch (_leftStoredType) {
case INT:
fillResultInt(valueBlock, length);
break;
case LONG:
fillResultLong(valueBlock, length);
break;
case FLOAT:
fillResultFloat(valueBlock, length);
break;
case DOUBLE:
fillResultDouble(valueBlock, length);
break;
case BIG_DECIMAL:
fillResultBigDecimal(valueBlock, length);
break;
case STRING:
fillResultString(valueBlock, length);
break;
case BYTES:
fillResultBytes(valueBlock, length);
break;
case UNKNOWN:
fillResultUnknown(length);
break;
// NOTE: Multi-value columns are not comparable, so we should not reach here
default:
throw illegalState();
}
}
}

Expand Down

0 comments on commit 9778fbb

Please sign in to comment.