Skip to content

Commit

Permalink
[Multi-stage] Fix literal handling (apache#13344)
Browse files Browse the repository at this point in the history
  • Loading branch information
Jackie-Jiang authored Jun 8, 2024
1 parent 61aa6ce commit 1229add
Show file tree
Hide file tree
Showing 20 changed files with 376 additions and 344 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@

import java.util.Objects;
import java.util.Set;
import javax.annotation.Nullable;
import org.apache.pinot.common.request.Literal;
import org.apache.pinot.spi.data.FieldSpec;
import org.apache.pinot.spi.data.FieldSpec.DataType;


/**
Expand All @@ -42,12 +43,16 @@ public enum Type {
// Only set when the _type is LITERAL
private final LiteralContext _literal;

public static ExpressionContext forLiteralContext(Literal literal) {
return new ExpressionContext(Type.LITERAL, null, null, new LiteralContext(literal));
public static ExpressionContext forLiteral(LiteralContext literal) {
return new ExpressionContext(Type.LITERAL, null, null, literal);
}

public static ExpressionContext forLiteralContext(FieldSpec.DataType type, Object val) {
return new ExpressionContext(Type.LITERAL, null, null, new LiteralContext(type, val));
public static ExpressionContext forLiteral(Literal literal) {
return forLiteral(new LiteralContext(literal));
}

public static ExpressionContext forLiteral(DataType type, @Nullable Object value) {
return forLiteral(new LiteralContext(type, value));
}

public static ExpressionContext forIdentifier(String identifier) {
Expand All @@ -70,7 +75,7 @@ public Type getType() {
}

// Please check the _type of this context is Literal before calling get, otherwise it may return null.
public LiteralContext getLiteral(){
public LiteralContext getLiteral() {
return _literal;
}

Expand Down Expand Up @@ -104,7 +109,8 @@ public boolean equals(Object o) {
return false;
}
ExpressionContext that = (ExpressionContext) o;
return _type == that._type && Objects.equals(_identifier, that._identifier) && Objects.equals(_function, that._function) && Objects.equals(_literal, that._literal);
return _type == that._type && Objects.equals(_identifier, that._identifier) && Objects.equals(_function,
that._function) && Objects.equals(_literal, that._literal);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,7 @@ public static ExpressionContext getExpression(String expression) {
public static ExpressionContext getExpression(Expression thriftExpression) {
switch (thriftExpression.getType()) {
case LITERAL:
return ExpressionContext.forLiteralContext(thriftExpression.getLiteral());
return ExpressionContext.forLiteral(thriftExpression.getLiteral());
case IDENTIFIER:
return ExpressionContext.forIdentifier(thriftExpression.getIdentifier().getName());
case FUNCTION:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,18 @@
import com.google.common.base.Splitter;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import it.unimi.dsi.fastutil.doubles.DoubleArrayList;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.longs.LongArrayList;
import java.math.BigDecimal;
import java.sql.Timestamp;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import javax.annotation.Nullable;
import org.apache.calcite.sql.SqlLiteral;
import org.apache.calcite.sql.SqlNumericLiteral;
Expand Down Expand Up @@ -102,8 +105,116 @@ public static Expression getIdentifierExpression(String identifier) {
return expression;
}

public static Expression getLiteralExpression(SqlLiteral node) {
Expression expression = new Expression(ExpressionType.LITERAL);
public static Literal getNullLiteral() {
return Literal.nullValue(true);
}

public static Literal getLiteral(boolean value) {
return Literal.boolValue(value);
}

public static Literal getLiteral(int value) {
return Literal.intValue(value);
}

public static Literal getLiteral(long value) {
return Literal.longValue(value);
}

public static Literal getLiteral(float value) {
return Literal.floatValue(Float.floatToRawIntBits(value));
}

public static Literal getLiteral(double value) {
return Literal.doubleValue(value);
}

public static Literal getLiteral(BigDecimal value) {
return Literal.bigDecimalValue(BigDecimalUtils.serialize(value));
}

public static Literal getLiteral(String value) {
return Literal.stringValue(value);
}

public static Literal getLiteral(byte[] value) {
return Literal.binaryValue(value);
}

public static Literal getLiteral(int[] value) {
return Literal.intArrayValue(IntArrayList.wrap(value));
}

public static Literal getLiteral(long[] value) {
return Literal.longArrayValue(LongArrayList.wrap(value));
}

public static Literal getLiteral(float[] value) {
IntArrayList intBitsList = new IntArrayList(value.length);
for (float floatValue : value) {
intBitsList.add(Float.floatToRawIntBits(floatValue));
}
return Literal.floatArrayValue(intBitsList);
}

public static Literal getLiteral(double[] value) {
return Literal.doubleArrayValue(DoubleArrayList.wrap(value));
}

public static Literal getLiteral(String[] value) {
return Literal.stringArrayValue(Arrays.asList(value));
}

public static Literal getLiteral(@Nullable Object object) {
if (object == null) {
return getNullLiteral();
}
if (object instanceof Boolean) {
return RequestUtils.getLiteral((boolean) object);
}
if (object instanceof Integer) {
return RequestUtils.getLiteral((int) object);
}
if (object instanceof Long) {
return RequestUtils.getLiteral((long) object);
}
if (object instanceof Float) {
return RequestUtils.getLiteral((float) object);
}
if (object instanceof Double) {
return RequestUtils.getLiteral((double) object);
}
if (object instanceof BigDecimal) {
return RequestUtils.getLiteral((BigDecimal) object);
}
if (object instanceof Timestamp) {
return RequestUtils.getLiteral(((Timestamp) object).getTime());
}
if (object instanceof String) {
return RequestUtils.getLiteral((String) object);
}
if (object instanceof byte[]) {
return RequestUtils.getLiteral((byte[]) object);
}
if (object instanceof int[]) {
return RequestUtils.getLiteral((int[]) object);
}
if (object instanceof long[]) {
return RequestUtils.getLiteral((long[]) object);
}
if (object instanceof float[]) {
return RequestUtils.getLiteral((float[]) object);
}
if (object instanceof double[]) {
return RequestUtils.getLiteral((double[]) object);
}
if (object instanceof String[]) {
return RequestUtils.getLiteral((String[]) object);
}
return RequestUtils.getLiteral(object.toString());
}

public static Literal getLiteral(SqlLiteral node) {
Literal literal = new Literal();
if (node instanceof SqlNumericLiteral) {
BigDecimal bigDecimalValue = node.bigDecimalValue();
Expand Down Expand Up @@ -133,146 +244,77 @@ public static Expression getLiteralExpression(SqlLiteral node) {
break;
}
}
expression.setLiteral(literal);
return expression;
return literal;
}

public static Expression createNewLiteralExpression() {
public static Expression getLiteralExpression(Literal literal) {
Expression expression = new Expression(ExpressionType.LITERAL);
Literal literal = new Literal();
expression.setLiteral(literal);
return expression;
}

public static Expression getNullLiteralExpression() {
return getLiteralExpression(getNullLiteral());
}

public static Expression getLiteralExpression(boolean value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setBoolValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(int value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setIntValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(long value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setLongValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(float value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setFloatValue(Float.floatToRawIntBits(value));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(double value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setDoubleValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(BigDecimal value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setBigDecimalValue(BigDecimalUtils.serialize(value));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(String value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setStringValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(byte[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setBinaryValue(value);
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(int[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setIntArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList()));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(long[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setLongArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList()));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(float[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setFloatArrayValue(
IntStream.range(0, value.length).mapToObj(i -> Float.floatToRawIntBits(value[i])).collect(Collectors.toList()));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(double[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setDoubleArrayValue(Arrays.stream(value).boxed().collect(Collectors.toList()));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getLiteralExpression(String[] value) {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setStringArrayValue(Arrays.asList(value));
return expression;
return getLiteralExpression(getLiteral(value));
}

public static Expression getNullLiteralExpression() {
Expression expression = createNewLiteralExpression();
expression.getLiteral().setNullValue(true);
return expression;
public static Expression getLiteralExpression(SqlLiteral node) {
return getLiteralExpression(getLiteral(node));
}

public static Expression getLiteralExpression(@Nullable Object object) {
if (object == null) {
return getNullLiteralExpression();
}
if (object instanceof Boolean) {
return RequestUtils.getLiteralExpression((boolean) object);
}
if (object instanceof Integer) {
return RequestUtils.getLiteralExpression((int) object);
}
if (object instanceof Long) {
return RequestUtils.getLiteralExpression((long) object);
}
if (object instanceof Float) {
return RequestUtils.getLiteralExpression((float) object);
}
if (object instanceof Double) {
return RequestUtils.getLiteralExpression((double) object);
}
if (object instanceof BigDecimal) {
return RequestUtils.getLiteralExpression((BigDecimal) object);
}
if (object instanceof String) {
return RequestUtils.getLiteralExpression((String) object);
}
if (object instanceof byte[]) {
return RequestUtils.getLiteralExpression((byte[]) object);
}
if (object instanceof int[]) {
return RequestUtils.getLiteralExpression((int[]) object);
}
if (object instanceof long[]) {
return RequestUtils.getLiteralExpression((long[]) object);
}
if (object instanceof float[]) {
return RequestUtils.getLiteralExpression((float[]) object);
}
if (object instanceof double[]) {
return RequestUtils.getLiteralExpression((double[]) object);
}
if (object instanceof String[]) {
return RequestUtils.getLiteralExpression((String[]) object);
}
return RequestUtils.getLiteralExpression(object.toString());
return getLiteralExpression(getLiteral(object));
}

/**
Expand Down
18 changes: 8 additions & 10 deletions pinot-common/src/main/proto/expressions.proto
Original file line number Diff line number Diff line change
Expand Up @@ -50,16 +50,14 @@ message InputRef {

message Literal {
ColumnDataType dataType = 1;
bool isValueNull = 2;
oneof literalField {
bool boolField = 101;
int32 intField = 102;
int64 longField = 103;
float floatField = 104;
double doubleField = 105;
string stringField = 106;
bytes bytesField = 107;
bytes serializedField = 108;
oneof value {
bool null = 2;
int32 int = 3;
int64 long = 4;
float float = 5;
double double = 6;
string string = 7;
bytes bytes = 8;
}
}

Expand Down
Loading

0 comments on commit 1229add

Please sign in to comment.