Skip to content

Commit

Permalink
HIVE-27690: Handle casting NULL literal to complex type (Krisztian Kasa, reviewed by Laszlo Vegh)
Browse files Browse the repository at this point in the history
  • Loading branch information
kasakrisz authored Dec 18, 2023
1 parent a823bab commit fd92b39
Show file tree
Hide file tree
Showing 19 changed files with 875 additions and 195 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,7 @@ castExpression
LPAREN
expression
KW_AS
toType=primitiveType
toType=type
(fmt=KW_FORMAT StringLiteral)?
RPAREN
// simple cast
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -486,6 +486,9 @@ public final class FunctionRegistry {
system.registerGenericUDF("!", GenericUDFOPNot.class);
system.registerGenericUDF("between", GenericUDFBetween.class);
system.registerGenericUDF("in_bloom_filter", GenericUDFInBloomFilter.class);
system.registerGenericUDF("toMap", GenericUDFToMap.class);
system.registerGenericUDF("toArray", GenericUDFToArray.class);
system.registerGenericUDF("toStruct", GenericUDFToStruct.class);

// Utility UDFs
system.registerUDF("version", UDFVersion.class, false);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -180,9 +180,7 @@ public static ASTNode emptyPlan(RelDataType dataType) {
ASTBuilder select = ASTBuilder.construct(HiveParser.TOK_SELECT, "TOK_SELECT");
for (int i = 0; i < dataType.getFieldCount(); ++i) {
RelDataTypeField fieldType = dataType.getFieldList().get(i);
select.add(ASTBuilder.selectExpr(
createNullField(fieldType.getType()),
fieldType.getName()));
select.add(ASTBuilder.selectExpr(createNullField(fieldType.getType()), fieldType.getName()));
}

ASTNode insert = ASTBuilder.
Expand All @@ -203,53 +201,52 @@ private static ASTNode createNullField(RelDataType fieldType) {
return ASTBuilder.construct(HiveParser.TOK_NULL, "TOK_NULL").node();
}

ASTNode astNode = convertType(fieldType);
return ASTBuilder.construct(HiveParser.TOK_FUNCTION, "TOK_FUNCTION")
.add(astNode)
.add(HiveParser.TOK_NULL, "TOK_NULL")
.node();
}

static ASTNode convertType(RelDataType fieldType) {
if (fieldType.getSqlTypeName() == SqlTypeName.NULL) {
return ASTBuilder.construct(HiveParser.TOK_NULL, "TOK_NULL").node();
}

if (fieldType.getSqlTypeName() == SqlTypeName.ROW) {
ASTBuilder namedStructCallNode = ASTBuilder.construct(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
namedStructCallNode.add(HiveParser.Identifier, "named_struct");
ASTBuilder columnListNode = ASTBuilder.construct(HiveParser.TOK_TABCOLLIST, "TOK_TABCOLLIST");
for (RelDataTypeField structFieldType : fieldType.getFieldList()) {
namedStructCallNode.add(HiveParser.Identifier, structFieldType.getName());
namedStructCallNode.add(createNullField(structFieldType.getType()));
ASTNode colNode = ASTBuilder.construct(HiveParser.TOK_TABCOL, "TOK_TABCOL")
.add(HiveParser.Identifier, structFieldType.getName())
.add(convertType(structFieldType.getType()))
.node();
columnListNode.add(colNode);
}
return namedStructCallNode.node();
return ASTBuilder.construct(HiveParser.TOK_STRUCT, "TOK_STRUCT").add(columnListNode).node();
}

if (fieldType.getSqlTypeName() == SqlTypeName.MAP) {
ASTBuilder mapCallNode = ASTBuilder.construct(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
mapCallNode.add(HiveParser.Identifier, "map");
mapCallNode.add(createNullField(fieldType.getKeyType()));
mapCallNode.add(createNullField(fieldType.getValueType()));
ASTBuilder mapCallNode = ASTBuilder.construct(HiveParser.TOK_MAP, "TOK_MAP");
mapCallNode.add(convertType(fieldType.getKeyType()));
mapCallNode.add(convertType(fieldType.getValueType()));
return mapCallNode.node();
}

if (fieldType.getSqlTypeName() == SqlTypeName.ARRAY) {
ASTBuilder arrayCallNode = ASTBuilder.construct(HiveParser.TOK_FUNCTION, "TOK_FUNCTION");
arrayCallNode.add(HiveParser.Identifier, "array");
arrayCallNode.add(createNullField(fieldType.getComponentType()));
ASTBuilder arrayCallNode = ASTBuilder.construct(HiveParser.TOK_LIST, "TOK_LIST");
arrayCallNode.add(convertType(fieldType.getComponentType()));
return arrayCallNode.node();
}

return createCastNull(fieldType);
}

private static ASTNode createCastNull(RelDataType fieldType) {
HiveToken ht = TypeConverter.hiveToken(fieldType);
ASTNode typeNode;
if (ht == null) {
typeNode = ASTBuilder.construct(
HiveParser.Identifier, fieldType.getSqlTypeName().getName().toLowerCase()).node();
} else {
ASTBuilder typeNodeBuilder = ASTBuilder.construct(ht.type, ht.text);
if (ht.args != null) {
for (String castArg : ht.args) {
typeNodeBuilder.add(HiveParser.Identifier, castArg);
}
ASTBuilder astBldr = ASTBuilder.construct(ht.type, ht.text);
if (ht.args != null) {
for (String castArg : ht.args) {
astBldr.add(HiveParser.Identifier, castArg);
}
typeNode = typeNodeBuilder.node();
}
return ASTBuilder.construct(HiveParser.TOK_FUNCTION, "TOK_FUNCTION")
.add(typeNode)
.add(HiveParser.TOK_NULL, "TOK_NULL")
.node();

return astBldr.node();
}

private ASTNode convert() throws CalciteSemanticException {
Expand Down Expand Up @@ -1042,22 +1039,7 @@ public ASTNode visitCall(RexCall call) {
Collections.singletonList(SqlFunctionConverter.buildAST(SqlStdOperatorTable.IS_NOT_DISTINCT_FROM, astNodeLst, call.getType())), call.getType());
case CAST:
assert(call.getOperands().size() == 1);
if (call.getType().isStruct() ||
SqlTypeName.MAP.equals(call.getType().getSqlTypeName()) ||
SqlTypeName.ARRAY.equals(call.getType().getSqlTypeName())) {
// cast for complex types can be ignored safely because explicit casting on such
// types are not possible, implicit casting e.g. CAST(ROW__ID as <...>) can be ignored
return call.getOperands().get(0).accept(this);
}

HiveToken ht = TypeConverter.hiveToken(call.getType());
ASTBuilder astBldr = ASTBuilder.construct(ht.type, ht.text);
if (ht.args != null) {
for (String castArg : ht.args) {
astBldr.add(HiveParser.Identifier, castArg);
}
}
astNodeLst.add(astBldr.node());
astNodeLst.add(convertType(call.getType()));
astNodeLst.add(call.getOperands().get(0).accept(this));
break;
case EXTRACT:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,14 @@
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFCase;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFIn;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFTimestamp;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToArray;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToBinary;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToChar;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDate;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToDecimal;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToMap;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToString;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToStruct;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToTimestampLocalTZ;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToUnixTimeStamp;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDFToVarchar;
Expand Down Expand Up @@ -334,7 +337,10 @@ public static RexNode handleExplicitCast(GenericUDF udf, RelDataType returnType,
|| (udf instanceof GenericUDFToString)
|| (udf instanceof GenericUDFToDecimal) || (udf instanceof GenericUDFToDate)
|| (udf instanceof GenericUDFTimestamp) || (udf instanceof GenericUDFToTimestampLocalTZ)
|| (udf instanceof GenericUDFToBinary) || castExprUsingUDFBridge(udf)) {
|| (udf instanceof GenericUDFToBinary) || castExprUsingUDFBridge(udf)
|| (udf instanceof GenericUDFToMap)
|| (udf instanceof GenericUDFToArray)
|| (udf instanceof GenericUDFToStruct)) {
castExpr = rexBuilder.makeAbstractCast(returnType, childRexNodeLst.get(0));
}
}
Expand Down
51 changes: 50 additions & 1 deletion ql/src/java/org/apache/hadoop/hive/ql/parse/ParseUtils.java
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Queue;
Expand Down Expand Up @@ -60,13 +61,20 @@
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorConverters;
import org.apache.hadoop.hive.serde2.typeinfo.CharTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.DecimalTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.ListTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.MapTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.PrimitiveTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.StructTypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoFactory;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.apache.hadoop.hive.serde2.typeinfo.VarcharTypeInfo;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import static org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.getTypeStringFromAST;
import static org.apache.hadoop.hive.ql.parse.BaseSemanticAnalyzer.unescapeIdentifier;


/**
* Library of utility functions used in the parse code.
Expand Down Expand Up @@ -204,7 +212,7 @@ static int checkJoinFilterRefersOneAlias(String[] tabAliases, ASTNode filterCond

switch(filterCondn.getType()) {
case HiveParser.TOK_TABLE_OR_COL:
String tableOrCol = SemanticAnalyzer.unescapeIdentifier(filterCondn.getChild(0).getText()
String tableOrCol = unescapeIdentifier(filterCondn.getChild(0).getText()
.toLowerCase());
return getIndex(tabAliases, tableOrCol);
case HiveParser.Identifier:
Expand Down Expand Up @@ -725,4 +733,45 @@ public static final class ReparseResult {
}
}

/**
 * Translates a type AST node into the corresponding {@link TypeInfo}.
 *
 * <ul>
 *   <li>{@code TOK_LIST}: a {@link ListTypeInfo} whose element type is resolved recursively
 *       from child 0.</li>
 *   <li>{@code TOK_MAP}: a {@link MapTypeInfo}; the key (child 0) is looked up as a primitive
 *       type from its type string, the value (child 1) is resolved recursively.</li>
 *   <li>{@code TOK_STRUCT}: a {@link StructTypeInfo} filled with the attribute names and types
 *       gathered by {@code collectStructFieldNames}, in the iteration order of the returned map.</li>
 *   <li>any other token: the node's type string is looked up as a primitive type.</li>
 * </ul>
 *
 * @param typeNode AST node describing the type
 * @return the type information built for {@code typeNode}
 * @throws SemanticException propagated from the helpers that read the AST
 */
public static TypeInfo getComplexTypeTypeInfo(ASTNode typeNode) throws SemanticException {
  final int tokenType = typeNode.getType();

  if (tokenType == HiveParser.TOK_LIST) {
    ListTypeInfo listInfo = new ListTypeInfo();
    ASTNode elementNode = (ASTNode) typeNode.getChild(0);
    listInfo.setListElementTypeInfo(getComplexTypeTypeInfo(elementNode));
    return listInfo;
  }

  if (tokenType == HiveParser.TOK_MAP) {
    MapTypeInfo mapInfo = new MapTypeInfo();
    // The key is resolved through the primitive type factory, not recursively.
    String keyTypeName = getTypeStringFromAST((ASTNode) typeNode.getChild(0));
    mapInfo.setMapKeyTypeInfo(TypeInfoFactory.getPrimitiveTypeInfo(keyTypeName));
    ASTNode valueNode = (ASTNode) typeNode.getChild(1);
    mapInfo.setMapValueTypeInfo(getComplexTypeTypeInfo(valueNode));
    return mapInfo;
  }

  if (tokenType == HiveParser.TOK_STRUCT) {
    StructTypeInfo structInfo = new StructTypeInfo();
    Map<String, TypeInfo> attributes = collectStructFieldNames(typeNode);
    // Names and types are copied from the same map, so both lists share one ordering.
    structInfo.setAllStructFieldNames(new ArrayList<>(attributes.keySet()));
    structInfo.setAllStructFieldTypeInfos(new ArrayList<>(attributes.values()));
    return structInfo;
  }

  // Not a list, map or struct token: resolve the type string as a primitive type.
  return TypeInfoFactory.getPrimitiveTypeInfo(getTypeStringFromAST(typeNode));
}

/**
 * Gathers the attributes declared under a struct type node as a name-to-type map.
 *
 * <p>Child 0 of {@code structTypeNode} is read as the attribute list. For every entry the
 * unescaped text of its child 0 becomes the attribute name and its child 1 is resolved with
 * {@code getComplexTypeTypeInfo}. The result is a {@link LinkedHashMap}, so iteration follows
 * the order in which the attributes appear in the list.
 *
 * @param structTypeNode the struct type AST node
 * @return attribute names mapped to their type information
 * @throws SemanticException with {@code ErrorMsg.AMBIGUOUS_STRUCT_ATTRIBUTE} when the same
 *         attribute name occurs more than once, or propagated from resolving an attribute type
 */
private static Map<String, TypeInfo> collectStructFieldNames(ASTNode structTypeNode) throws SemanticException {
  ASTNode columnList = (ASTNode) structTypeNode.getChild(0);
  assert columnList.getType() == HiveParser.TOK_TABCOLLIST;

  final int attributeCount = columnList.getChildCount();
  Map<String, TypeInfo> attributes = new LinkedHashMap<>(attributeCount);
  for (int idx = 0; idx < attributeCount; idx++) {
    ASTNode attribute = (ASTNode) columnList.getChild(idx);
    String name = unescapeIdentifier(attribute.getChild(0).getText());

    // Reject the duplicate before its type is resolved.
    if (attributes.containsKey(name)) {
      throw new SemanticException(ErrorMsg.AMBIGUOUS_STRUCT_ATTRIBUTE, name);
    }
    attributes.put(name, getComplexTypeTypeInfo((ASTNode) attribute.getChild(1)));
  }
  return attributes;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -113,30 +113,40 @@ protected boolean isExprInstance(Object o) {
protected ExprNodeDesc toExpr(ColumnInfo colInfo, RowResolver rowResolver, int offset)
throws SemanticException {
ObjectInspector inspector = colInfo.getObjectInspector();
if (inspector instanceof ConstantObjectInspector && inspector instanceof PrimitiveObjectInspector) {
return toPrimitiveConstDesc(colInfo, inspector);
}
if (inspector instanceof ConstantObjectInspector && inspector instanceof ListObjectInspector) {
ObjectInspector listElementOI = ((ListObjectInspector)inspector).getListElementObjectInspector();
if (listElementOI instanceof PrimitiveObjectInspector) {
return toListConstDesc(colInfo, inspector, listElementOI);
if (inspector instanceof ConstantObjectInspector) {
if (inspector instanceof PrimitiveObjectInspector) {
return toPrimitiveConstDesc(colInfo, inspector);
}
}
if (inspector instanceof ConstantObjectInspector && inspector instanceof MapObjectInspector) {
ObjectInspector keyOI = ((MapObjectInspector)inspector).getMapKeyObjectInspector();
ObjectInspector valueOI = ((MapObjectInspector)inspector).getMapValueObjectInspector();
if (keyOI instanceof PrimitiveObjectInspector && valueOI instanceof PrimitiveObjectInspector) {
return toMapConstDesc(colInfo, inspector, keyOI, valueOI);

Object inputConstantValue = ((ConstantObjectInspector) inspector).getWritableConstantValue();
if (inputConstantValue == null) {
return createExprNodeConstantDesc(colInfo, null);
}
}
if (inspector instanceof ConstantObjectInspector && inspector instanceof StructObjectInspector) {
boolean allPrimitive = true;
List<? extends StructField> fields = ((StructObjectInspector)inspector).getAllStructFieldRefs();
for (StructField field : fields) {
allPrimitive &= field.getFieldObjectInspector() instanceof PrimitiveObjectInspector;

if (inspector instanceof ListObjectInspector) {
ObjectInspector listElementOI = ((ListObjectInspector) inspector).getListElementObjectInspector();
if (listElementOI instanceof PrimitiveObjectInspector) {
PrimitiveObjectInspector poi = (PrimitiveObjectInspector) listElementOI;
return createExprNodeConstantDesc(colInfo, toListConstant((List<?>) inputConstantValue, poi));
}
}
if (allPrimitive) {
return toStructConstDesc(colInfo, inspector, fields);
if (inspector instanceof MapObjectInspector) {
ObjectInspector keyOI = ((MapObjectInspector)inspector).getMapKeyObjectInspector();
ObjectInspector valueOI = ((MapObjectInspector)inspector).getMapValueObjectInspector();
if (keyOI instanceof PrimitiveObjectInspector && valueOI instanceof PrimitiveObjectInspector) {
return createExprNodeConstantDesc(colInfo, toMapConstant((Map<?, ?>) inputConstantValue, keyOI, valueOI));
}
}
if (inspector instanceof StructObjectInspector) {
boolean allPrimitive = true;
List<? extends StructField> fields = ((StructObjectInspector)inspector).getAllStructFieldRefs();
for (StructField field : fields) {
allPrimitive &= field.getFieldObjectInspector() instanceof PrimitiveObjectInspector;
}
if (allPrimitive) {
return createExprNodeConstantDesc(colInfo, toStructConstDesc(
(List<?>) ((ConstantObjectInspector) inspector).getWritableConstantValue(), fields));
}
}
}
// non-constant or non-primitive constants
Expand All @@ -145,6 +155,13 @@ protected ExprNodeDesc toExpr(ColumnInfo colInfo, RowResolver rowResolver, int o
return column;
}

/**
 * Creates a constant expression of the column's type holding {@code constantValue}, and records
 * the column's internal name and table alias on it as the "folded from" column and table.
 *
 * @param colInfo column supplying the type, internal name and table alias
 * @param constantValue value to store in the constant expression; passed through unchanged
 * @return the newly created constant expression
 */
private static ExprNodeConstantDesc createExprNodeConstantDesc(ColumnInfo colInfo, Object constantValue) {
  final ExprNodeConstantDesc folded =
      new ExprNodeConstantDesc(colInfo.getType(), constantValue);
  folded.setFoldedFromCol(colInfo.getInternalName());
  folded.setFoldedFromTab(colInfo.getTabAlias());
  return folded;
}

private static ExprNodeConstantDesc toPrimitiveConstDesc(ColumnInfo colInfo, ObjectInspector inspector) {
PrimitiveObjectInspector poi = (PrimitiveObjectInspector) inspector;
Object constant = ((ConstantObjectInspector) inspector).getWritableConstantValue();
Expand All @@ -155,50 +172,33 @@ private static ExprNodeConstantDesc toPrimitiveConstDesc(ColumnInfo colInfo, Obj
return constantExpr;
}

private static ExprNodeConstantDesc toListConstDesc(ColumnInfo colInfo, ObjectInspector inspector,
ObjectInspector listElementOI) {
PrimitiveObjectInspector poi = (PrimitiveObjectInspector)listElementOI;
List<?> values = (List<?>)((ConstantObjectInspector) inspector).getWritableConstantValue();
List<Object> constant = new ArrayList<Object>();
for (Object o : values) {
private static List<Object> toListConstant(List<?> constantValue, PrimitiveObjectInspector poi) {
List<Object> constant = new ArrayList<>(constantValue.size());
for (Object o : constantValue) {
constant.add(poi.getPrimitiveJavaObject(o));
}

ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant);
constantExpr.setFoldedFromCol(colInfo.getInternalName());
constantExpr.setFoldedFromTab(colInfo.getTabAlias());
return constantExpr;
return constant;
}

private static ExprNodeConstantDesc toMapConstDesc(ColumnInfo colInfo, ObjectInspector inspector,
ObjectInspector keyOI, ObjectInspector valueOI) {
PrimitiveObjectInspector keyPoi = (PrimitiveObjectInspector)keyOI;
PrimitiveObjectInspector valuePoi = (PrimitiveObjectInspector)valueOI;
Map<?, ?> values = (Map<?, ?>)((ConstantObjectInspector) inspector).getWritableConstantValue();
Map<Object, Object> constant = new LinkedHashMap<Object, Object>();
for (Map.Entry<?, ?> e : values.entrySet()) {
private static Map<Object, Object> toMapConstant(
Map<?, ?> constantValue, ObjectInspector keyOI, ObjectInspector valueOI) {
PrimitiveObjectInspector keyPoi = (PrimitiveObjectInspector) keyOI;
PrimitiveObjectInspector valuePoi = (PrimitiveObjectInspector) valueOI;
Map<Object, Object> constant = new LinkedHashMap<>(constantValue.size());
for (Map.Entry<?, ?> e : constantValue.entrySet()) {
constant.put(keyPoi.getPrimitiveJavaObject(e.getKey()), valuePoi.getPrimitiveJavaObject(e.getValue()));
}

ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant);
constantExpr.setFoldedFromCol(colInfo.getInternalName());
constantExpr.setFoldedFromTab(colInfo.getTabAlias());
return constantExpr;
return constant;
}

private static ExprNodeConstantDesc toStructConstDesc(ColumnInfo colInfo, ObjectInspector inspector,
List<? extends StructField> fields) {
List<?> values = (List<?>)((ConstantObjectInspector) inspector).getWritableConstantValue();
List<Object> constant = new ArrayList<Object>();
for (int i = 0; i < values.size(); i++) {
Object value = values.get(i);
private static List<Object> toStructConstDesc(List<?> constantValue, List<? extends StructField> fields) {
List<Object> constant = new ArrayList<>(constantValue.size());
for (int i = 0; i < constantValue.size(); i++) {
Object value = constantValue.get(i);
PrimitiveObjectInspector fieldPoi = (PrimitiveObjectInspector) fields.get(i).getFieldObjectInspector();
constant.add(fieldPoi.getPrimitiveJavaObject(value));
}
ExprNodeConstantDesc constantExpr = new ExprNodeConstantDesc(colInfo.getType(), constant);
constantExpr.setFoldedFromCol(colInfo.getInternalName());
constantExpr.setFoldedFromTab(colInfo.getTabAlias());
return constantExpr;
return constant;
}

/**
Expand Down
Loading

0 comments on commit fd92b39

Please sign in to comment.