diff --git a/fe/be-java-extensions/jdbc-scanner/src/main/java/org/apache/doris/jdbc/JdbcDataSourceConfig.java b/fe/be-java-extensions/jdbc-scanner/src/main/java/org/apache/doris/jdbc/JdbcDataSourceConfig.java index dcf576986fef53..5fdbc211ab0161 100644 --- a/fe/be-java-extensions/jdbc-scanner/src/main/java/org/apache/doris/jdbc/JdbcDataSourceConfig.java +++ b/fe/be-java-extensions/jdbc-scanner/src/main/java/org/apache/doris/jdbc/JdbcDataSourceConfig.java @@ -30,11 +30,11 @@ public class JdbcDataSourceConfig { private int batchSize; private TJdbcOperation op; private TOdbcTableType tableType; - private int connectionPoolMinSize; - private int connectionPoolMaxSize; - private int connectionPoolMaxWaitTime; - private int connectionPoolMaxLifeTime; - private boolean connectionPoolKeepAlive; + private int connectionPoolMinSize = 1; + private int connectionPoolMaxSize = 10; + private int connectionPoolMaxWaitTime = 5000; + private int connectionPoolMaxLifeTime = 1800000; + private boolean connectionPoolKeepAlive = false; public String createCacheKey() { return catalogId + jdbcUrl + jdbcUser + jdbcPassword + jdbcDriverUrl + jdbcDriverClass diff --git a/fe/fe-core/src/main/java/org/apache/doris/backup/BackupJob.java b/fe/fe-core/src/main/java/org/apache/doris/backup/BackupJob.java index 597dacc3c071d4..5404fcca064dde 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/backup/BackupJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/backup/BackupJob.java @@ -910,6 +910,9 @@ private void cancelInternal() { } } + // meta info and job info not need save in log when cancel, we need to clean them here + backupMeta = null; + jobInfo = null; releaseSnapshots(); snapshotInfos.clear(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java index 5552b3d1d5c7e5..56016ad071a5ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/catalog/FunctionRegistry.java @@ -53,14 +53,14 @@ public class FunctionRegistry { // to record the global alias function and other udf. private static final String GLOBAL_FUNCTION = "__GLOBAL_FUNCTION__"; - private final Map> name2InternalBuiltinBuilders; + private final Map> name2BuiltinBuilders; private final Map>> name2UdfBuilders; public FunctionRegistry() { - name2InternalBuiltinBuilders = new ConcurrentHashMap<>(); + name2BuiltinBuilders = new ConcurrentHashMap<>(); name2UdfBuilders = new ConcurrentHashMap<>(); - registerBuiltinFunctions(name2InternalBuiltinBuilders); - afterRegisterBuiltinFunctions(name2InternalBuiltinBuilders); + registerBuiltinFunctions(name2BuiltinBuilders); + afterRegisterBuiltinFunctions(name2BuiltinBuilders); } // this function is used to test. @@ -78,12 +78,33 @@ public FunctionBuilder findFunctionBuilder(String name, Object argument) { } public Optional> tryGetBuiltinBuilders(String name) { - List builders = name2InternalBuiltinBuilders.get(name); - return name2InternalBuiltinBuilders.get(name) == null + List builders = name2BuiltinBuilders.get(name); + return name2BuiltinBuilders.get(name) == null ? Optional.empty() : Optional.of(ImmutableList.copyOf(builders)); } + public boolean isAggregateFunction(String dbName, String name) { + name = name.toLowerCase(); + Class aggClass = org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction.class; + if (StringUtils.isEmpty(dbName)) { + List functionBuilders = name2BuiltinBuilders.get(name); + for (FunctionBuilder functionBuilder : functionBuilders) { + if (aggClass.isAssignableFrom(functionBuilder.functionClass())) { + return true; + } + } + } + + List udfBuilders = findUdfBuilder(dbName, name); + for (FunctionBuilder udfBuilder : udfBuilders) { + if (aggClass.isAssignableFrom(udfBuilder.functionClass())) { + return true; + } + } + return false; + } + // currently we only find function by name and arity and args' types. 
public FunctionBuilder findFunctionBuilder(String dbName, String name, List arguments) { List functionBuilders = null; @@ -92,11 +113,11 @@ public FunctionBuilder findFunctionBuilder(String dbName, String name, List a if (StringUtils.isEmpty(dbName)) { // search internal function only if dbName is empty - functionBuilders = name2InternalBuiltinBuilders.get(name.toLowerCase()); + functionBuilders = name2BuiltinBuilders.get(name.toLowerCase()); if (CollectionUtils.isEmpty(functionBuilders) && AggCombinerFunctionBuilder.isAggStateCombinator(name)) { String nestedName = AggCombinerFunctionBuilder.getNestedName(name); String combinatorSuffix = AggCombinerFunctionBuilder.getCombinatorSuffix(name); - functionBuilders = name2InternalBuiltinBuilders.get(nestedName.toLowerCase()); + functionBuilders = name2BuiltinBuilders.get(nestedName.toLowerCase()); if (functionBuilders != null) { List candidateBuilders = Lists.newArrayListWithCapacity(functionBuilders.size()); for (FunctionBuilder functionBuilder : functionBuilders) { @@ -199,8 +220,8 @@ public void dropUdf(String dbName, String name, List argTypes) { } synchronized (name2UdfBuilders) { Map> builders = name2UdfBuilders.getOrDefault(dbName, ImmutableMap.of()); - builders.getOrDefault(name, Lists.newArrayList()).removeIf(builder -> ((UdfBuilder) builder).getArgTypes() - .equals(argTypes)); + builders.getOrDefault(name, Lists.newArrayList()) + .removeIf(builder -> ((UdfBuilder) builder).getArgTypes().equals(argTypes)); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/load/routineload/RoutineLoadJob.java b/fe/fe-core/src/main/java/org/apache/doris/load/routineload/RoutineLoadJob.java index 68f0e6345e1281..77663be058cd77 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/load/routineload/RoutineLoadJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/load/routineload/RoutineLoadJob.java @@ -1166,7 +1166,8 @@ public void afterVisible(TransactionState txnState, boolean txnOperated) { return; } 
RoutineLoadTaskInfo routineLoadTaskInfo = routineLoadTaskInfoOptional.get(); - if (routineLoadTaskInfo.getTxnStatus() != TransactionStatus.COMMITTED) { + if (routineLoadTaskInfo.getTxnStatus() != TransactionStatus.COMMITTED + && routineLoadTaskInfo.getTxnStatus() != TransactionStatus.VISIBLE) { // TODO(cmy): Normally, this should not happen. But for safe reason, just pause the job String msg = String.format( "should not happen, we find that task %s is not COMMITTED when handling afterVisble." diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java index a561507fb947a8..04ba599aa60d5b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java @@ -19,11 +19,13 @@ import org.apache.doris.catalog.Env; import org.apache.doris.catalog.FunctionRegistry; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.CascadesContext; import org.apache.doris.nereids.NereidsPlanner; import org.apache.doris.nereids.StatementContext; import org.apache.doris.nereids.analyzer.MappingSlot; import org.apache.doris.nereids.analyzer.Scope; +import org.apache.doris.nereids.analyzer.UnboundFunction; import org.apache.doris.nereids.analyzer.UnboundOneRowRelation; import org.apache.doris.nereids.analyzer.UnboundResultSink; import org.apache.doris.nereids.analyzer.UnboundSlot; @@ -351,12 +353,12 @@ private LogicalHaving bindHaving(MatchingContext> ctx) CascadesContext cascadesContext = ctx.cascadesContext; // bind slot by child.output first - Scope defaultScope = toScope(cascadesContext, childPlan.getOutput()); + Scope childOutput = toScope(cascadesContext, childPlan.getOutput()); // then bind slot by child.children.output - Supplier backupScope = Suppliers.memoize(() -> + Supplier childChildrenOutput = Suppliers.memoize(() -> 
toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(childPlan.children())) ); - return bindHavingByScopes(having, cascadesContext, defaultScope, backupScope); + return bindHavingByScopes(having, cascadesContext, childOutput, childChildrenOutput); } private LogicalHaving bindHavingAggregate( @@ -365,13 +367,115 @@ private LogicalHaving bindHavingAggregate( Aggregate aggregate = having.child(); CascadesContext cascadesContext = ctx.cascadesContext; - // having(aggregate) should bind slot by aggregate.child.output first - Scope defaultScope = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children())); - // then bind slot by aggregate.output - Supplier backupScope = Suppliers.memoize(() -> - toScope(cascadesContext, aggregate.getOutput()) - ); - return bindHavingByScopes(ctx.root, ctx.cascadesContext, defaultScope, backupScope); + // keep same behavior as mysql + Supplier bindByAggChild = Suppliers.memoize(() -> { + Scope aggChildOutputScope + = toScope(cascadesContext, PlanUtils.fastGetChildrenOutputs(aggregate.children())); + return (analyzer, unboundSlot) -> analyzer.bindSlotByScope(unboundSlot, aggChildOutputScope); + }); + + Scope aggOutputScope = toScope(cascadesContext, aggregate.getOutput()); + Supplier bindByGroupByThenAggOutputThenAggChild = Suppliers.memoize(() -> { + List groupByExprs = aggregate.getGroupByExpressions(); + ImmutableList.Builder groupBySlots + = ImmutableList.builderWithExpectedSize(groupByExprs.size()); + for (Expression groupBy : groupByExprs) { + if (groupBy instanceof Slot) { + groupBySlots.add((Slot) groupBy); + } + } + Scope groupBySlotsScope = toScope(cascadesContext, groupBySlots.build()); + + Supplier> separateAggOutputScopes = Suppliers.memoize(() -> { + ImmutableList.Builder groupByOutputs = ImmutableList.builderWithExpectedSize( + aggregate.getOutputExpressions().size()); + ImmutableList.Builder aggFunOutputs = ImmutableList.builderWithExpectedSize( + aggregate.getOutputExpressions().size()); + for 
(NamedExpression outputExpression : aggregate.getOutputExpressions()) { + if (outputExpression.anyMatch(AggregateFunction.class::isInstance)) { + aggFunOutputs.add(outputExpression.toSlot()); + } else { + groupByOutputs.add(outputExpression.toSlot()); + } + } + Scope nonAggFunSlotsScope = toScope(cascadesContext, groupByOutputs.build()); + Scope aggFuncSlotsScope = toScope(cascadesContext, aggFunOutputs.build()); + return Pair.of(nonAggFunSlotsScope, aggFuncSlotsScope); + }); + + return (analyzer, unboundSlot) -> { + List boundInGroupBy = analyzer.bindSlotByScope(unboundSlot, groupBySlotsScope); + if (boundInGroupBy.size() == 1) { + return boundInGroupBy; + } + + Pair separateAggOutputScope = separateAggOutputScopes.get(); + List boundInNonAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.first); + if (boundInNonAggFuncs.size() == 1) { + return boundInNonAggFuncs; + } + + List boundInAggFuncs = analyzer.bindSlotByScope(unboundSlot, separateAggOutputScope.second); + if (boundInAggFuncs.size() == 1) { + return boundInAggFuncs; + } + + return bindByAggChild.get().bindSlot(analyzer, unboundSlot); + }; + }); + + FunctionRegistry functionRegistry = cascadesContext.getConnectContext().getEnv().getFunctionRegistry(); + ExpressionAnalyzer havingAnalyzer = new ExpressionAnalyzer(having, aggOutputScope, cascadesContext, + false, true) { + private boolean currentIsInAggregateFunction; + + @Override + public Expression visitAggregateFunction(AggregateFunction aggregateFunction, + ExpressionRewriteContext context) { + if (!currentIsInAggregateFunction) { + currentIsInAggregateFunction = true; + try { + return super.visitAggregateFunction(aggregateFunction, context); + } finally { + currentIsInAggregateFunction = false; + } + } else { + return super.visitAggregateFunction(aggregateFunction, context); + } + } + + @Override + public Expression visitUnboundFunction(UnboundFunction unboundFunction, ExpressionRewriteContext context) { + if 
(!currentIsInAggregateFunction && isAggregateFunction(unboundFunction, functionRegistry)) { + currentIsInAggregateFunction = true; + try { + return super.visitUnboundFunction(unboundFunction, context); + } finally { + currentIsInAggregateFunction = false; + } + } else { + return super.visitUnboundFunction(unboundFunction, context); + } + } + + @Override + protected List bindSlotByThisScope(UnboundSlot unboundSlot) { + if (currentIsInAggregateFunction) { + return bindByAggChild.get().bindSlot(this, unboundSlot); + } else { + return bindByGroupByThenAggOutputThenAggChild.get().bindSlot(this, unboundSlot); + } + } + }; + + Set havingExprs = having.getConjuncts(); + ImmutableSet.Builder analyzedHaving = ImmutableSet.builderWithExpectedSize(havingExprs.size()); + ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); + for (Expression expression : havingExprs) { + analyzedHaving.add(havingAnalyzer.analyze(expression, rewriteContext)); + } + + return new LogicalHaving<>(analyzedHaving.build(), having.child()); } private LogicalHaving bindHavingByScopes( @@ -764,6 +868,11 @@ private void checkIfOutputAliasNameDuplicatedForGroupBy(Collection e } } + private boolean isAggregateFunction(UnboundFunction unboundFunction, FunctionRegistry functionRegistry) { + return functionRegistry.isAggregateFunction( + unboundFunction.getDbName(), unboundFunction.getName()); + } + private E checkBoundExceptLambda(E expression, Plan plan) { if (expression instanceof Lambda) { return expression; @@ -797,6 +906,12 @@ private SimpleExprAnalyzer buildSimpleExprAnalyzer( boolean enableExactMatch, boolean bindSlotInOuterScope) { List childrenOutputs = PlanUtils.fastGetChildrenOutputs(children); Scope scope = toScope(cascadesContext, childrenOutputs); + return buildSimpleExprAnalyzer(currentPlan, cascadesContext, scope, enableExactMatch, bindSlotInOuterScope); + } + + private SimpleExprAnalyzer buildSimpleExprAnalyzer( + Plan currentPlan, CascadesContext 
cascadesContext, Scope scope, + boolean enableExactMatch, boolean bindSlotInOuterScope) { ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(cascadesContext); ExpressionAnalyzer expressionAnalyzer = new ExpressionAnalyzer(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java index aa2917ae1e1d9f..1e7d4dbb09d201 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckPolicy.java @@ -24,14 +24,16 @@ import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy; +import org.apache.doris.nereids.trees.plans.logical.LogicalCheckPolicy.RelatedPolicy; import org.apache.doris.nereids.trees.plans.logical.LogicalFileScan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalRelation; import org.apache.doris.nereids.util.ExpressionUtils; import com.google.common.collect.ImmutableList; -import java.util.HashSet; +import java.util.LinkedHashSet; import java.util.List; import java.util.Set; @@ -60,7 +62,7 @@ public List buildRules() { return ctx.root.child(); } LogicalRelation relation = (LogicalRelation) child; - Set combineFilter = new HashSet<>(); + Set combineFilter = new LinkedHashSet<>(); // replace incremental params as AND expression if (relation instanceof LogicalFileScan) { @@ -72,18 +74,20 @@ public List buildRules() { } } - // row policy - checkPolicy.getFilter(relation, ctx.connectContext) - .ifPresent(expression -> combineFilter.addAll( + RelatedPolicy relatedPolicy = 
checkPolicy.findPolicy(relation, ctx.cascadesContext); + relatedPolicy.rowPolicyFilter.ifPresent(expression -> combineFilter.addAll( ExpressionUtils.extractConjunctionToSet(expression))); - - if (combineFilter.isEmpty()) { - return ctx.root.child(); - } + Plan result = relation; if (upperFilter != null) { combineFilter.addAll(upperFilter.getConjuncts()); } - return new LogicalFilter<>(combineFilter, relation); + if (!combineFilter.isEmpty()) { + result = new LogicalFilter<>(combineFilter, relation); + } + if (relatedPolicy.dataMaskProjects.isPresent()) { + result = new LogicalProject<>(relatedPolicy.dataMaskProjects.get(), result); + } + return result; }) ) ); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java index 3c514475eedde2..4ef642890cae15 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/AggCombinerFunctionBuilder.java @@ -56,6 +56,11 @@ public AggCombinerFunctionBuilder(String combinatorSuffix, FunctionBuilder neste this.nestedBuilder = Objects.requireNonNull(nestedBuilder, "nestedBuilder can not be null"); } + @Override + public Class functionClass() { + return nestedBuilder.functionClass(); + } + @Override public boolean canApply(List arguments) { if (combinatorSuffix.equals(STATE) || combinatorSuffix.equals(FOREACH)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java index 74c4a918cf0aa0..e2dab713332fd6 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java +++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/BuiltinFunctionBuilder.java @@ -42,13 +42,21 @@ public class BuiltinFunctionBuilder extends FunctionBuilder { // Concrete BoundFunction's constructor private final Constructor builderMethod; + private final Class functionClass; - public BuiltinFunctionBuilder(Constructor builderMethod) { + public BuiltinFunctionBuilder( + Class functionClass, Constructor builderMethod) { + this.functionClass = Objects.requireNonNull(functionClass, "functionClass can not be null"); this.builderMethod = Objects.requireNonNull(builderMethod, "builderMethod can not be null"); this.arity = builderMethod.getParameterCount(); this.isVariableLength = arity > 0 && builderMethod.getParameterTypes()[arity - 1].isArray(); } + @Override + public Class functionClass() { + return functionClass; + } + @Override public boolean canApply(List arguments) { if (isVariableLength && arity > arguments.size() + 1) { @@ -133,7 +141,7 @@ public static List resolve(Class funct + functionClass.getSimpleName()); return Arrays.stream(functionClass.getConstructors()) .filter(constructor -> Modifier.isPublic(constructor.getModifiers())) - .map(constructor -> new BuiltinFunctionBuilder((Constructor) constructor)) + .map(constructor -> new BuiltinFunctionBuilder(functionClass, (Constructor) constructor)) .collect(ImmutableList.toImmutableList()); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java index 0b3a50a239b8e0..d1e69d3e307d6f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/FunctionBuilder.java @@ -27,6 +27,8 @@ * This class used to build BoundFunction(Builtin or Combinator) by a list of Expressions. 
*/ public abstract class FunctionBuilder { + public abstract Class functionClass(); + /** check whether arguments can apply to the constructor */ public abstract boolean canApply(List arguments); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java index 197cb8b396df5d..733bd5fcae1164 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/AliasUdfBuilder.java @@ -50,6 +50,11 @@ public List getArgTypes() { return aliasUdf.getArgTypes(); } + @Override + public Class functionClass() { + return AliasUdf.class; + } + @Override public boolean canApply(List arguments) { if (arguments.size() != aliasUdf.arity()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdafBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdafBuilder.java index d0c0b067e944d2..89073b4eb77e4d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdafBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdafBuilder.java @@ -49,6 +49,11 @@ public List getArgTypes() { .collect(Collectors.toList())).get(); } + @Override + public Class functionClass() { + return JavaUdaf.class; + } + @Override public boolean canApply(List arguments) { if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdfBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdfBuilder.java index efbcbf9f4835f3..a78f8cfe4b46ae 100644 --- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdfBuilder.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/functions/udf/JavaUdfBuilder.java @@ -51,6 +51,11 @@ public List getArgTypes() { .collect(Collectors.toList())).get(); } + @Override + public Class functionClass() { + return JavaUdf.class; + } + @Override public boolean canApply(List arguments) { if ((isVarArgs && arity > arguments.size() + 1) || (!isVarArgs && arguments.size() != arity)) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java index 54f38034761527..bda3b1f49d570d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalCheckPolicy.java @@ -19,13 +19,20 @@ import org.apache.doris.analysis.UserIdentity; import org.apache.doris.mysql.privilege.AccessControllerManager; +import org.apache.doris.mysql.privilege.DataMaskPolicy; import org.apache.doris.mysql.privilege.RowFilterPolicy; +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.analyzer.UnboundAlias; import org.apache.doris.nereids.exceptions.AnalysisException; import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator; import org.apache.doris.nereids.trees.plans.Plan; import 
org.apache.doris.nereids.trees.plans.PlanType; import org.apache.doris.nereids.trees.plans.PropagateFuncDeps; @@ -37,6 +44,7 @@ import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; +import org.apache.commons.collections.CollectionUtils; import java.util.ArrayList; import java.util.List; @@ -113,32 +121,58 @@ public Plan withChildren(List children) { } /** - * get wherePredicate of policy for logicalRelation. + * find related policy for logicalRelation. * * @param logicalRelation include tableName and dbName - * @param connectContext include information about user and policy + * @param cascadesContext include information about user and policy */ - public Optional getFilter(LogicalRelation logicalRelation, ConnectContext connectContext) { + public RelatedPolicy findPolicy(LogicalRelation logicalRelation, CascadesContext cascadesContext) { if (!(logicalRelation instanceof CatalogRelation)) { - return Optional.empty(); + return RelatedPolicy.NO_POLICY; } + ConnectContext connectContext = cascadesContext.getConnectContext(); AccessControllerManager accessManager = connectContext.getEnv().getAccessManager(); UserIdentity currentUserIdentity = connectContext.getCurrentUserIdentity(); if (currentUserIdentity.isRootUser() || currentUserIdentity.isAdminUser()) { - return Optional.empty(); + return RelatedPolicy.NO_POLICY; } CatalogRelation catalogRelation = (CatalogRelation) logicalRelation; String ctlName = catalogRelation.getDatabase().getCatalog().getName(); String dbName = catalogRelation.getDatabase().getFullName(); String tableName = catalogRelation.getTable().getName(); - List policies = accessManager.evalRowFilterPolicies(currentUserIdentity, ctlName, - dbName, tableName); - if (policies.isEmpty()) { - return Optional.empty(); + + NereidsParser nereidsParser = new NereidsParser(); + ImmutableList.Builder dataMasks + = ImmutableList.builderWithExpectedSize(logicalRelation.getOutput().size()); + + boolean hasDataMask = false; + 
for (Slot slot : logicalRelation.getOutput()) { + Optional dataMaskPolicy = accessManager.evalDataMaskPolicy( + currentUserIdentity, ctlName, dbName, tableName, slot.getName()); + if (dataMaskPolicy.isPresent()) { + Expression unboundExpr = nereidsParser.parseExpression(dataMaskPolicy.get().getMaskTypeDef()); + Expression childOfAlias + = unboundExpr instanceof UnboundAlias ? unboundExpr.child(0) : unboundExpr; + Alias alias = new Alias( + StatementScopeIdGenerator.newExprId(), + ImmutableList.of(childOfAlias), + slot.getName(), slot.getQualifier(), false + ); + dataMasks.add(alias); + hasDataMask = true; + } else { + dataMasks.add(slot); + } } - return Optional.ofNullable(mergeRowPolicy(policies)); + + List policies = accessManager.evalRowFilterPolicies( + currentUserIdentity, ctlName, dbName, tableName); + return new RelatedPolicy( + Optional.ofNullable(CollectionUtils.isEmpty(policies) ? null : mergeRowPolicy(policies)), + hasDataMask ? Optional.of(dataMasks.build()) : Optional.empty() + ); } private Expression mergeRowPolicy(List policies) { @@ -172,4 +206,17 @@ private Expression mergeRowPolicy(List policies) { return null; } } + + /** RelatedPolicy */ + public static class RelatedPolicy { + public static final RelatedPolicy NO_POLICY = new RelatedPolicy(Optional.empty(), Optional.empty()); + + public final Optional rowPolicyFilter; + public final Optional> dataMaskProjects; + + public RelatedPolicy(Optional rowPolicyFilter, Optional> dataMaskProjects) { + this.rowPolicyFilter = rowPolicyFilter; + this.dataMaskProjects = dataMaskProjects; + } + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java index 33e5ff72d912ca..a20e40cdc67585 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/ConnectContext.java @@ -1085,18 +1085,25 @@ public String getCloudCluster(boolean updateErr) { } String 
cluster = null; + String choseWay = null; if (!Strings.isNullOrEmpty(this.cloudCluster)) { cluster = this.cloudCluster; + choseWay = "use @cluster"; + LOG.debug("finally set context cluster name {} for user {} with chose way '{}'", + cloudCluster, getCurrentUserIdentity(), choseWay); + return cluster; } String defaultCluster = getDefaultCloudCluster(); if (!Strings.isNullOrEmpty(defaultCluster)) { cluster = defaultCluster; - } - - String authorizedCluster = getAuthorizedCloudCluster(); - if (!Strings.isNullOrEmpty(authorizedCluster)) { - cluster = authorizedCluster; + choseWay = "default cluster"; + } else { + String authorizedCluster = getAuthorizedCloudCluster(); + if (!Strings.isNullOrEmpty(authorizedCluster)) { + cluster = authorizedCluster; + choseWay = "authorized cluster"; + } } if (Strings.isNullOrEmpty(cluster)) { @@ -1107,7 +1114,8 @@ public String getCloudCluster(boolean updateErr) { } } else { this.cloudCluster = cluster; - LOG.info("finally set context cluster name {} for user {}", cloudCluster, getCurrentUserIdentity()); + LOG.info("finally set context cluster name {} for user {} with chose way '{}'", + cloudCluster, getCurrentUserIdentity(), choseWay); } return cluster; diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java index 9b88ab57ea14e2..5ad41d7e6b3be0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/privileges/TestCheckPrivileges.java @@ -33,7 +33,18 @@ import org.apache.doris.mysql.privilege.PrivPredicate; import org.apache.doris.mysql.privilege.RowFilterPolicy; import org.apache.doris.nereids.exceptions.AnalysisException; +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.pattern.GeneratedMemoPatterns; +import org.apache.doris.nereids.rules.RulePromise; +import 
org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.scalar.Concat; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.policy.FilterType; import org.apache.doris.utframe.TestWithFeService; import com.google.common.collect.ImmutableList; @@ -50,8 +61,10 @@ import java.util.Map; import java.util.Optional; import java.util.Set; +import java.util.function.Function; +import java.util.stream.Collectors; -public class TestCheckPrivileges extends TestWithFeService { +public class TestCheckPrivileges extends TestWithFeService implements GeneratedMemoPatterns { private static final Map>> CATALOG_META = ImmutableMap.of( "test_db", ImmutableMap.of( "test_tbl1", ImmutableList.of( @@ -65,12 +78,16 @@ public class TestCheckPrivileges extends TestWithFeService { "test_tbl3", ImmutableList.of( new Column("id", PrimitiveType.INT), new Column("name", PrimitiveType.VARCHAR) + ), + "test_tbl4", ImmutableList.of( + new Column("id", PrimitiveType.INT), + new Column("name", PrimitiveType.VARCHAR) ) ) ); @Test - public void testColumnPrivileges() throws Exception { + public void testPrivilegesAndPolicies() throws Exception { FeConstants.runningUnitTest = true; String catalogProvider = "org.apache.doris.nereids.privileges.TestCheckPrivileges$CustomCatalogProvider"; @@ -90,6 +107,7 @@ public void testColumnPrivileges() throws Exception { String table1 = "test_tbl1"; String table2 = "test_tbl2"; String table3 = "test_tbl3"; + String table4 = "test_tbl4"; String view1 = "query_tbl2_view1"; createView("create view " + internalDb + "." 
@@ -119,7 +137,12 @@ public void testColumnPrivileges() throws Exception { .allowSelectColumns(user, ImmutableSet.of("name")), MakePrivileges.table("internal", internalDb, view4) - .allowSelectColumns(user, ImmutableSet.of("id")) + .allowSelectColumns(user, ImmutableSet.of("id")), + + // data masking and row policy + MakePrivileges.table(catalog, db, table4).allowSelectTable(user) + .addRowPolicy(user, "id = 1") + .addDataMasking(user, "id", "concat(id, '_****_', id)") /* table4: rows filtered by id = 1, column id replaced by the concat(...) mask */ ); AccessControllerManager accessManager = Env.getCurrentEnv().getAccessManager(); @@ -182,6 +205,64 @@ public void testColumnPrivileges() throws Exception { query("select name from " + internalDb + "." + view4) ); } + + // test row policy with data masking + { + Function checkId = (NamedExpression ne) -> { /* true only for an Alias named id whose first child is a Concat, i.e. the masked form of the id column */ + if (!(ne instanceof Alias) || !ne.getName().equals("id")) { + return false; + } + return ne.child(0) instanceof Concat; + }; + PlanChecker.from(connectContext) /* case 1: id referenced unqualified, by table, by db.table, by catalog.db.table and through the star */ + .parse("select id," + + " test_tbl4.id," + + " test_db.test_tbl4.id, " + + " custom_catalog.test_db.test_tbl4.id, " + + " * " + + "from custom_catalog.test_db.test_tbl4") + .analyze() + .rewrite() + .matches(logicalProject( + logicalFilter( + logicalTestScan() + ).when(f -> { /* the row policy must surface as the filter predicate slot = 1 */ + EqualTo predicate = (EqualTo) f.getPredicate(); + return predicate.left() instanceof Slot + && predicate.right().equals(new IntegerLiteral((byte) 1)); + }) + ).when(p -> { /* projections 0-4 are all id and must be masked; projection 5 is name and must stay a plain Slot */ + List projects = p.getProjects(); + if (!checkId.apply(projects.get(0)) || !checkId.apply(projects.get(1)) + || !checkId.apply(projects.get(2)) || !checkId.apply(projects.get(3)) + || !checkId.apply(projects.get(4))) { + return false; + } + return projects.get(5) instanceof Slot && projects.get(5).getName().equals("name"); + })); + + PlanChecker.from(connectContext) /* case 2: same checks with the table referenced through alias t */ + .parse("select id, t.id, * " /* trailing space keeps the star separated from the from keyword, matching the first query */ + + "from custom_catalog.test_db.test_tbl4 t") + .analyze() + .rewrite() + .matches(logicalProject( + logicalFilter( + logicalTestScan() + ).when(f -> { + EqualTo predicate = (EqualTo) 
f.getPredicate(); + return predicate.left() instanceof Slot + && predicate.right().equals(new IntegerLiteral((byte) 1)); + }) + ).when(p -> { /* projections 0-2 are id and must be masked; projection 3 is name and must stay a plain Slot */ + List projects = p.getProjects(); + if (!checkId.apply(projects.get(0)) || !checkId.apply(projects.get(1)) + || !checkId.apply(projects.get(2))) { + return false; + } + return projects.get(3) instanceof Slot && projects.get(3).getName().equals("name"); + })); + } }); } @@ -195,23 +276,36 @@ private void query(String sql) { private void withPrivileges(List privileges, Runnable task) { List tablePrivileges = Lists.newArrayList(); List columnPrivileges = Lists.newArrayList(); + List rowPolicies = Lists.newArrayList(); + List dataMaskingPolicies = Lists.newArrayList(); /* accumulators for the row policies and data-masking policies declared by every privilege builder passed in */ for (MakeTablePrivileges privilege : privileges) { tablePrivileges.addAll(privilege.tablePrivileges); columnPrivileges.addAll(privilege.columnPrivileges); + rowPolicies.addAll(privilege.rowPolicies); + dataMaskingPolicies.addAll(privilege.dataMaskingPolicies); } SimpleCatalogAccessController.tablePrivileges.set(tablePrivileges); SimpleCatalogAccessController.columnPrivileges.set(columnPrivileges); + SimpleCatalogAccessController.rowPolicies.set(rowPolicies); + SimpleCatalogAccessController.dataMaskings.set(dataMaskingPolicies); /* both lists are published through the controller's static ThreadLocals before the task runs */ try { task.run(); } finally { + SimpleCatalogAccessController.rowPolicies.remove(); + SimpleCatalogAccessController.dataMaskings.remove(); /* removed in finally, so the ThreadLocals are cleared even when the task throws */ SimpleCatalogAccessController.tablePrivileges.remove(); SimpleCatalogAccessController.columnPrivileges.remove(); } } + @Override + public RulePromise defaultPromise() { + return RulePromise.REWRITE; /* always REWRITE; presumably the promise attached to patterns built through GeneratedMemoPatterns - confirm */ + } + public static class CustomCatalogProvider implements TestCatalogProvider { @Override @@ -230,6 +324,8 @@ public CatalogAccessController createAccessController(Map prop) public static class SimpleCatalogAccessController implements CatalogAccessController { private static ThreadLocal> tablePrivileges = new ThreadLocal<>(); private static ThreadLocal> columnPrivileges = new ThreadLocal<>(); + private 
static ThreadLocal> rowPolicies = new ThreadLocal<>(); + private static ThreadLocal> dataMaskings = new ThreadLocal<>(); /* per-thread policy lists; set and removed by withPrivileges around each test task */ @Override public boolean checkGlobalPriv(UserIdentity currentUser, PrivPredicate wanted) { @@ -312,13 +408,40 @@ public boolean checkCloudPriv(UserIdentity currentUser, String resourceName, Pri @Override public Optional evalDataMaskPolicy(UserIdentity currentUser, String ctl, String db, String tbl, String col) { + List dataMaskingPolicies = dataMaskings.get(); /* null when nothing was installed on this thread; handled as no masking */ + if (dataMaskingPolicies == null) { + return Optional.empty(); + } + + for (CustomDataMaskingPolicy dataMaskingPolicy : dataMaskingPolicies) { + if (dataMaskingPolicy.column.equalsIgnoreCase(col)) { /* first policy whose column matches case-insensitively wins; NOTE(review): currentUser, ctl, db and tbl are not consulted here */ + return Optional.of(dataMaskingPolicy); + } + } return Optional.empty(); } @Override public List evalRowFilterPolicies(UserIdentity currentUser, String ctl, String db, String tbl) { - return Lists.newArrayList(); + List customRowPolicies = rowPolicies.get(); /* null when nothing was installed on this thread; handled as no row filter */ + if (customRowPolicies == null) { + return ImmutableList.of(); + } + NereidsParser nereidsParser = new NereidsParser(); /* one parser instance shared by all wrappers returned from this call */ + return customRowPolicies.stream() + .map(p -> new RowFilterPolicy() { /* every installed policy is wrapped and returned; NOTE(review): currentUser, ctl, db and tbl are not used to select policies, and the wrapper overrides only the two methods below */ + @Override + public Expression getFilterExpression() { + return nereidsParser.parseExpression(p.filter); /* the filter text is parsed again on each call */ + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + p.filter; + } + }) + .collect(Collectors.toList()); } } @@ -335,6 +458,8 @@ private static class MakeTablePrivileges { private List tablePrivileges; private List columnPrivileges; + private List rowPolicies; + private List dataMaskingPolicies; /* filled by addRowPolicy / addDataMasking and read by withPrivileges */ public MakeTablePrivileges(String catalog, String db, String table) { this.catalog = catalog; @@ -342,6 +467,8 @@ public MakeTablePrivileges(String catalog, String db, String table) { this.table = table; this.tablePrivileges = Lists.newArrayList(); this.columnPrivileges = Lists.newArrayList(); + this.rowPolicies = Lists.newArrayList(); + this.dataMaskingPolicies = Lists.newArrayList(); } public MakeTablePrivileges 
allowSelectTable(String user) { @@ -353,6 +480,16 @@ public MakeTablePrivileges allowSelectColumns(String user, Set allowColu columnPrivileges.add(new ColumnPrivilege(catalog, db, table, user, allowColumns)); return this; } + + public MakeTablePrivileges addRowPolicy(String user, String filter) { /* records a row policy with the given filter text for user; returns this for chaining */ + rowPolicies.add(new CustomRowPolicy(user, filter)); + return this; + } + + public MakeTablePrivileges addDataMasking(String user, String column, String project) { /* records a masking policy that maps column to the project expression text; returns this for chaining */ + dataMaskingPolicies.add(new CustomDataMaskingPolicy(user, column, project)); + return this; + } } private static class TablePrivilege { @@ -409,4 +546,59 @@ public boolean isSameTable(String catalog, String db, String tbl) { && StringUtils.equals(this.table, tbl); } } + + private static class CustomRowPolicy implements RowFilterPolicy { /* test-only row policy: holds a user name and a filter expression kept as text */ + private final String user; + private final String filter; + + public CustomRowPolicy(String user, String filter) { + this.user = user; + this.filter = filter; + } + + public String getUser() { + return user; + } + + @Override + public Expression getFilterExpression() { + return new NereidsParser().parseExpression(filter); /* a new parser is created and the text re-parsed on every call */ + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + filter; + } + + @Override + public FilterType getFilterType() { + return FilterType.PERMISSIVE; /* always reported as PERMISSIVE */ + } + } + + private static class CustomDataMaskingPolicy implements DataMaskPolicy { /* test-only masking policy: holds a user name, the masked column name and the replacement expression text */ + private final String user; + private final String column; + private final String project; + + public CustomDataMaskingPolicy(String user, String name, String project) { + this.user = user; + this.column = name; /* the constructor parameter is called name but is stored as column */ + this.project = project; + } + + public String getUser() { + return user; + } + + @Override + public String getMaskTypeDef() { + return project; /* the mask definition is the replacement expression text itself */ + } + + @Override + public String getPolicyIdent() { + return "custom policy: " + project; + } + } } diff --git a/regression-test/data/nereids_syntax_p0/bind_priority.out b/regression-test/data/nereids_syntax_p0/bind_priority.out index 
7cd1e3ea5696e4..eb4002d960a676 100644 --- a/regression-test/data/nereids_syntax_p0/bind_priority.out +++ b/regression-test/data/nereids_syntax_p0/bind_priority.out @@ -36,3 +36,43 @@ all 2 4 5 6 6 +-- !having_bind_child -- +1 10 + +-- !having_bind_child2 -- +2 10 + +-- !having_bind_child3 -- +2 10 + +-- !having_bind_project -- +2 10 + +-- !having_bind_project2 -- + +-- !having_bind_project3 -- + +-- !having_bind_project4 -- +2 11 + +-- !having_bind_child4 -- +2 11 + +-- !having_bind_child5 -- +2 11 + +-- !having_bind_agg_fun -- + +-- !having_bind_agg_fun -- +2 4 +3 3 + +-- !having_bind_group_by -- +7 3 + +-- !having_bind_group_by -- +7 3 + +-- !having_bind_group_by -- +7 3 + diff --git a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy index 072587ed28eee4..4e1740061b6375 100644 --- a/regression-test/suites/nereids_syntax_p0/bind_priority.groovy +++ b/regression-test/suites/nereids_syntax_p0/bind_priority.groovy @@ -31,7 +31,7 @@ suite("bind_priority") { sql """ insert into bind_priority_tbl values(1, 2),(3, 4) """ - + sql "SET enable_nereids_planner=true" sql "SET enable_fallback_to_original_planner=false" @@ -100,17 +100,17 @@ suite("bind_priority") { ); """ sql "insert into bind_priority_tbl2 values(3,5),(2, 6),(1,4);" - + qt_bind_order_to_project_alias """ select bind_priority_tbl.b b, bind_priority_tbl2.b - from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a + from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a order by b; """ qt_bind_order_to_project_alias """ select bind_priority_tbl.b, bind_priority_tbl2.b b - from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a + from bind_priority_tbl join bind_priority_tbl2 on bind_priority_tbl.a=bind_priority_tbl2.a order by b; """ @@ -148,11 +148,144 @@ suite("bind_priority") { ) a ), tb2 as ( - select * from tb1 + select * from 
tb1 ) select * from tb2 order by id; """) result([[1], [2], [3]]) } + + def testBindHaving = { /* Checks which scope a column in HAVING binds to; the trailing SQL comment on each query names the expected source (group by, project, or table s). The closure is invoked immediately by the () after its closing brace. */ + sql "drop table if exists test_bind_having_slots" + + sql """create table test_bind_having_slots + (id int, age int) + distributed by hash(id) + properties('replication_num'='1'); + """ + sql "insert into test_bind_having_slots values(1, 10), (2, 20), (3, 30);" + + order_qt_having_bind_child """ + select id, sum(age) + from test_bind_having_slots s + group by id + having id = 1; -- bind id from group by + """ + + order_qt_having_bind_child2 """ + select id + 1 as id, sum(age) + from test_bind_having_slots s + group by id + having id = 1; -- bind id from group by + """ + + order_qt_having_bind_child3 """ + select id + 1 as id, sum(age) + from test_bind_having_slots s + group by id + having id + 1 = 2; -- bind id from group by + """ + + order_qt_having_bind_project """ + select id + 1 as id, sum(age) + from test_bind_having_slots s + group by id + 1 + having id = 2; -- bind id from project + """ + + order_qt_having_bind_project2 """ + select id + 1 as id, sum(age) + from test_bind_having_slots s + group by id + 1 + having id + 1 = 2; -- bind id from project + """ + + order_qt_having_bind_project3 """ + select id + 1 as id, sum(age + 1) as age + from test_bind_having_slots s + group by id + having age = 10; -- bind age from project + """ + + order_qt_having_bind_project4 """ + select id + 1 as id, sum(age + 1) as age + from test_bind_having_slots s + group by id + having age = 11; -- bind age from project + """ + + order_qt_having_bind_child4 """ + select id + 1 as id, sum(age + 1) as age + from test_bind_having_slots s + group by id + having sum(age) = 10; -- bind age from s + """ + + order_qt_having_bind_child5 """ + select id + 1 as id, sum(age + 1) as age + from test_bind_having_slots s + group by id + having sum(age + 1) = 11 -- bind age from s + """ + + + + + sql "drop table if exists test_bind_having_slots2" + sql """create table test_bind_having_slots2 + (id int) + 
distributed by hash(id) + properties('replication_num'='1'); + """ + sql "insert into test_bind_having_slots2 values(1), (2), (3), (2);" /* NOTE(review): the next two checks share the tag having_bind_agg_fun and the last three share having_bind_group_by - confirm that reusing a tag within one suite is intended */ + + order_qt_having_bind_agg_fun """ + select id, abs(sum(id)) as id + from test_bind_having_slots2 + group by id + having sum(id) + id >= 7 + """ + + order_qt_having_bind_agg_fun """ + select id, abs(sum(id)) as id + from test_bind_having_slots2 + group by id + having sum(id) + id >= 6 + """ + + + + + + sql "drop table if exists test_bind_having_slots3" + + sql """CREATE TABLE `test_bind_having_slots3`(pk int, pk2 int) + DUPLICATE KEY(`pk`) + DISTRIBUTED BY HASH(`pk`) BUCKETS 10 + properties('replication_num'='1'); + """ + sql "insert into test_bind_having_slots3 values(1, 1), (2, 2), (2, 2), (3, 3), (3, 3), (3, 3);" /* pk values 1, 2 and 3 occur one, two and three times respectively */ + + order_qt_having_bind_group_by """ + SELECT pk + 6 as ps, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + HAVING pk = 1 + """ + + order_qt_having_bind_group_by """ + SELECT pk + 6 as pk, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + 6 + HAVING pk = 7 + """ + + order_qt_having_bind_group_by """ + SELECT pk + 6, COUNT(pk ) * 3 as pk + FROM test_bind_having_slots3 tbl_alias1 + GROUP by pk + 6 + HAVING pk = 3 + """ + }() }