SubqueryToApply.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.analysis;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.TrySimplifyPredicateWithMarkJoinSlot;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.NotNullableAggregateFunction;
import org.apache.doris.nereids.trees.expressions.functions.scalar.AssertTrue;
import org.apache.doris.nereids.trees.expressions.functions.scalar.Nvl;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.VarcharLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalApply;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.Utils;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;
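/**
* SubqueryToApply: translates subquery expressions into LogicalApply nodes.
* The translation is done in two steps:
* 1. Replace each subquery expression inside the enclosing expression (a filter conjunct,
* a project expression or a join conjunct) with a plain expression that no longer contains a subquery.
* 2. Convert the subquery itself into a LogicalApply node placed below the enclosing plan node.
*/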
public class SubqueryToApply implements AnalysisRuleFactory {
/**
* Builds the four analysis rules that turn subquery expressions into LogicalApply nodes:
* one for the conjuncts of a LogicalFilter, one for the project list of a LogicalProject,
* one for the project list of a LogicalOneRowRelation, and one for the other-join conjuncts
* of a LogicalJoin that has no hash-join conjuncts.
*
* @return an immutable list with the rules in the order described above
*/
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
// Rule 1: subqueries inside the conjuncts of a LogicalFilter.
RuleType.FILTER_SUBQUERY_TO_APPLY.build(
logicalFilter().thenApply(ctx -> {
LogicalFilter<Plan> filter = ctx.root;
Set<Expression> conjuncts = filter.getConjuncts();
CollectSubquerys collectSubquerys = collectSubquerys(conjuncts);
// nothing to rewrite when no conjunct contains a subquery
if (!collectSubquerys.hasSubquery) {
return filter;
}
// per conjunct: true when the conjunct contains a subquery but is not itself the subquery
List<Boolean> shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot(conjuncts);
List<Expression> oldConjuncts = Utils.fastToImmutableList(conjuncts);
ImmutableSet.Builder<Expression> newConjuncts = new ImmutableSet.Builder<>();
LogicalPlan applyPlan = null;
// tmpPlan is the plan built so far; the apply nodes of each conjunct are stacked on top of it
LogicalPlan tmpPlan = (LogicalPlan) filter.child();
List<Set<SubqueryExpr>> subqueryExprsList = collectSubquerys.subqueies;
// Traverse the subqueries one AND-conjunct at a time.
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.isEmpty()) {
// conjunct without subquery is kept unchanged
newConjuncts.add(oldConjuncts.get(i));
continue;
}
// first step: Replace the subquery of predicate in LogicalFilter
// second step: Replace subquery with LogicalApply
ReplaceSubquery replaceSubquery = new ReplaceSubquery(
ctx.statementContext, shouldOutputMarkJoinSlot.get(i));
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(oldConjuncts.get(i), context);
/*
* The idea is to replace each mark join slot with the null and false literals
* and then run the FoldConstant rule. If the evaluated results are:
* 1. all true, or
* 2. all null and false (in LogicalFilter, both null and false values are discarded)
* then the mark slot can be a non-nullable boolean.
* This information is passed to LogicalApply. In the InApplyToJoin rule,
* if it is a semi join with a non-null mark slot,
* the mark conjunct can safely be changed to a hash conjunct.
*/
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct,
rewriteContext), rewriteContext)
: false;
Pair<LogicalPlan, Optional<Expression>> result = subqueryToApply(subqueryExprs.stream()
.collect(ImmutableList.toImmutableList()), tmpPlan,
context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext,
Optional.of(conjunct), false, isMarkSlotNotNull);
applyPlan = result.first;
tmpPlan = applyPlan;
// prefer the conjunct returned by subqueryToApply (it may have been rewritten by addApply)
newConjuncts.add(result.second.isPresent() ? result.second.get() : conjunct);
}
Plan newFilter = new LogicalFilter<>(newConjuncts.build(), applyPlan);
// Project back to the original filter output: the projects created by addApply may
// carry extra slots (mark join slot, scalar subquery output).
return new LogicalProject<>(filter.getOutput().stream().collect(ImmutableList.toImmutableList()),
newFilter);
})
),
// Rule 2: subqueries inside the project list of a LogicalProject.
RuleType.PROJECT_SUBQUERY_TO_APPLY.build(logicalProject().thenApply(ctx -> {
LogicalProject<Plan> project = ctx.root;
List<NamedExpression> projects = project.getProjects();
CollectSubquerys collectSubquerys = collectSubquerys(projects);
// nothing to rewrite when no project expression contains a subquery
if (!collectSubquerys.hasSubquery) {
return project;
}
List<Set<SubqueryExpr>> subqueryExprsList = collectSubquerys.subqueies;
List<NamedExpression> oldProjects = ImmutableList.copyOf(projects);
ImmutableList.Builder<NamedExpression> newProjects = new ImmutableList.Builder<>();
// childPlan is the plan built so far; every project expression stacks its apply nodes on it
LogicalPlan childPlan = (LogicalPlan) project.child();
LogicalPlan applyPlan;
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.isEmpty()) {
// project expression without subquery is kept unchanged
newProjects.add(oldProjects.get(i));
continue;
}
// first step: Replace the subquery in LogicalProject's project list
// second step: Replace subquery with LogicalApply
// in a project list the mark join slot is always requested (second argument is true)
ReplaceSubquery replaceSubquery =
new ReplaceSubquery(ctx.statementContext, true);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression newProject =
replaceSubquery.replace(oldProjects.get(i), context);
Pair<LogicalPlan, Optional<Expression>> result =
subqueryToApply(Utils.fastToImmutableList(subqueryExprs), childPlan,
context.getSubqueryToMarkJoinSlot(), ctx.cascadesContext,
Optional.of(newProject), true, false);
applyPlan = result.first;
childPlan = applyPlan;
newProjects.add(
result.second.isPresent() ? (NamedExpression) result.second.get()
: (NamedExpression) newProject);
}
return project.withProjectsAndChild(newProjects.build(), childPlan);
})),
// Rule 3: subqueries inside the project list of a LogicalOneRowRelation.
RuleType.ONE_ROW_RELATION_SUBQUERY_TO_APPLY.build(logicalOneRowRelation()
.when(ctx -> ctx.getProjects().stream()
.anyMatch(project -> project.containsType(SubqueryExpr.class)))
.thenApply(ctx -> {
LogicalOneRowRelation oneRowRelation = ctx.root;
// create a LogicalProject node with the same project lists above LogicalOneRowRelation
// create a LogicalOneRowRelation with a dummy output column
// so PROJECT_SUBQUERY_TO_APPLY rule can handle the subquery unnest thing
return new LogicalProject<Plan>(oneRowRelation.getProjects(),
oneRowRelation.withProjects(
ImmutableList.of(new Alias(BooleanLiteral.of(true),
ctx.statementContext.generateColumnName()))));
})),
// Rule 4: subqueries inside the other-join conjuncts of a LogicalJoin.
// Only joins with no hash-join conjuncts and at least one other-join conjunct are matched.
RuleType.JOIN_SUBQUERY_TO_APPLY
.build(logicalJoin()
.when(join -> join.getHashJoinConjuncts().isEmpty() && !join.getOtherJoinConjuncts().isEmpty())
.thenApply(ctx -> {
LogicalJoin<Plan, Plan> join = ctx.root;
// split the other-join conjuncts: key true = contains a subquery, key false = does not
Map<Boolean, List<Expression>> joinConjuncts = join.getOtherJoinConjuncts().stream()
.collect(Collectors.groupingBy(conjunct -> conjunct.containsType(SubqueryExpr.class),
Collectors.toList()));
List<Expression> subqueryConjuncts = joinConjuncts.get(true);
// leave the join untouched when there is no subquery conjunct, or when any subquery
// conjunct does not contain exactly one subquery expression
if (subqueryConjuncts == null || subqueryConjuncts.stream()
.anyMatch(expr -> !isValidSubqueryConjunct(expr))) {
return join;
}
List<RelatedInfo> relatedInfoList = collectRelatedInfo(
subqueryConjuncts, join.left(), join.right());
// a subquery that cannot be attached to exactly one side (or to none) is not supported
if (relatedInfoList.stream().anyMatch(info -> info == RelatedInfo.UnSupported)) {
return join;
}
ImmutableList<Set<SubqueryExpr>> subqueryExprsList = subqueryConjuncts.stream()
.<Set<SubqueryExpr>>map(e -> e.collect(SubqueryExpr.class::isInstance))
.collect(ImmutableList.toImmutableList());
ImmutableList.Builder<Expression> newConjuncts = new ImmutableList.Builder<>();
LogicalPlan applyPlan;
LogicalPlan leftChildPlan = (LogicalPlan) join.left();
LogicalPlan rightChildPlan = (LogicalPlan) join.right();
// Traverse the subqueries one AND-conjunct at a time.
for (int i = 0; i < subqueryExprsList.size(); ++i) {
Set<SubqueryExpr> subqueryExprs = subqueryExprsList.get(i);
if (subqueryExprs.size() > 1) {
// only support the conjunct contains one subquery expr
return join;
}
// first step: Replace the subquery in the join conjunct
// second step: Replace subquery with LogicalApply
ReplaceSubquery replaceSubquery = new ReplaceSubquery(ctx.statementContext, true);
SubqueryContext context = new SubqueryContext(subqueryExprs);
Expression conjunct = replaceSubquery.replace(subqueryConjuncts.get(i), context);
/*
* The idea is to replace each mark join slot with the null and false literals
* and then run the FoldConstant rule. If the evaluated results are:
* 1. all true, or
* 2. all null and false (in LogicalFilter, both null and false values are discarded)
* then the mark slot can be a non-nullable boolean.
* This information is passed to LogicalApply. In the InApplyToJoin rule,
* if it is a semi join with a non-null mark slot,
* the mark conjunct can safely be changed to a hash conjunct.
*/
ExpressionRewriteContext rewriteContext = new ExpressionRewriteContext(ctx.cascadesContext);
boolean isMarkSlotNotNull = conjunct.containsType(MarkJoinSlotReference.class)
? ExpressionUtils.canInferNotNullForMarkSlot(
TrySimplifyPredicateWithMarkJoinSlot.INSTANCE.rewrite(conjunct, rewriteContext),
rewriteContext)
: false;
// the apply is attached to the left child only for RelatedToLeft; every other
// remaining kind (RelatedToRight and Unrelated) is attached to the right child
Pair<LogicalPlan, Optional<Expression>> result = subqueryToApply(
subqueryExprs.stream().collect(ImmutableList.toImmutableList()),
relatedInfoList.get(i) == RelatedInfo.RelatedToLeft ? leftChildPlan : rightChildPlan,
context.getSubqueryToMarkJoinSlot(),
ctx.cascadesContext, Optional.of(conjunct), false, isMarkSlotNotNull);
applyPlan = result.first;
if (relatedInfoList.get(i) == RelatedInfo.RelatedToLeft) {
leftChildPlan = applyPlan;
} else {
rightChildPlan = applyPlan;
}
newConjuncts.add(result.second.isPresent() ? result.second.get() : conjunct);
}
// conjuncts without subquery are appended unchanged after the rewritten ones
List<Expression> simpleConjuncts = joinConjuncts.get(false);
if (simpleConjuncts != null) {
newConjuncts.addAll(simpleConjuncts);
}
// NOTE(review): the meaning of the trailing null argument is defined by
// LogicalJoin.withConjunctsChildren, which is not visible here -- confirm.
Plan newJoin = join.withConjunctsChildren(join.getHashJoinConjuncts(),
newConjuncts.build(), leftChildPlan, rightChildPlan, null);
return newJoin;
}))
);
}
/**
 * Tells whether a join conjunct can be handled by the join rule.
 * A conjunct is accepted only when exactly one subquery expression is found in it,
 * so shapes such as {@code subquery1 or subquery2} are rejected.
 *
 * @param expression the join conjunct to inspect
 * @return true if exactly one subquery expression is collected from the conjunct
 */
private static boolean isValidSubqueryConjunct(Expression expression) {
    List<SubqueryExpr> nestedSubqueries = expression.collectToList(SubqueryExpr.class::isInstance);
    return nestedSubqueries.size() == 1;
}
/**
* Describes which child of a join a subquery conjunct depends on.
* Produced per conjunct by collectRelatedInfo.
*/
private enum RelatedInfo {
// neither the subquery nor its output is related to any child, e.g. (select sum(t.a) from t) > 1
Unrelated,
// the subquery or its output is related only to the left child, as in the examples below:
// tableLeft.a in (select t.a from t)
// 3 in (select t.b from t where t.a = tableLeft.a)
// tableLeft.a > (select sum(t.a) from t where tableLeft.b = t.b)
RelatedToLeft,
// same as above, but related only to the right child
RelatedToRight,
// a subquery related to both the left and the right child is not supported:
// tableLeft.a > (select sum(t.a) from t where t.b = tableRight.b)
UnSupported
}
/**
 * Classifies every subquery conjunct of a join by the child it depends on.
 *
 * @param subqueryConjuncts join conjuncts that contain subquery expressions
 * @param leftChild left child of the join
 * @param rightChild right child of the join
 * @return one RelatedInfo per conjunct, in the same order as the input list;
 *         a conjunct that does not contain exactly one subquery expression is UnSupported
 */
private ImmutableList<RelatedInfo> collectRelatedInfo(List<Expression> subqueryConjuncts,
        Plan leftChild, Plan rightChild) {
    ImmutableList.Builder<RelatedInfo> infos = new ImmutableList.Builder<>();
    Set<Slot> leftSlots = leftChild.getOutputSet();
    Set<Slot> rightSlots = rightChild.getOutputSet();
    for (Expression conjunct : subqueryConjuncts) {
        List<SubqueryExpr> nested = conjunct.collectToList(SubqueryExpr.class::isInstance);
        if (nested.size() == 1) {
            infos.add(classifySingleSubquery(nested.get(0), leftSlots, rightSlots));
        } else {
            infos.add(RelatedInfo.UnSupported);
        }
    }
    return infos.build();
}

/**
 * Classifies one subquery expression against the output slots of both join children.
 * Any combination not matched by one of the explicit cases is UnSupported.
 */
private static RelatedInfo classifySingleSubquery(SubqueryExpr subqueryExpr,
        Set<Slot> leftSlots, Set<Slot> rightSlots) {
    List<Slot> correlated = subqueryExpr.getCorrelateSlots();

    if (subqueryExpr instanceof ScalarSubquery) {
        Set<Slot> inputs = subqueryExpr.getInputSlots();
        if (correlated.isEmpty() && inputs.isEmpty()) {
            return RelatedInfo.Unrelated;
        }
        if (leftSlots.containsAll(inputs) && leftSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToLeft;
        }
        if (rightSlots.containsAll(inputs) && rightSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToRight;
        }
        return RelatedInfo.UnSupported;
    }

    if (subqueryExpr instanceof InSubquery) {
        Set<Slot> compared = ((InSubquery) subqueryExpr).getCompareExpr().getInputSlots();
        // a compare expression without any input slot cannot be attached to a side
        if (compared.isEmpty()) {
            return RelatedInfo.UnSupported;
        }
        if (leftSlots.containsAll(compared) && leftSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToLeft;
        }
        if (rightSlots.containsAll(compared) && rightSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToRight;
        }
        return RelatedInfo.UnSupported;
    }

    if (subqueryExpr instanceof Exists) {
        if (correlated.isEmpty()) {
            return RelatedInfo.Unrelated;
        }
        if (leftSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToLeft;
        }
        if (rightSlots.containsAll(correlated)) {
            return RelatedInfo.RelatedToRight;
        }
    }
    return RelatedInfo.UnSupported;
}
/**
 * Stacks one apply (via addApply) per subquery expression on top of {@code childPlan}.
 * A subquery is skipped when it is an Exists whose query plan has a top level scalar
 * aggregate, or when the context reports it as already analyzed.
 *
 * @param subqueryExprs the subquery expressions of one conjunct / project expression
 * @param childPlan the plan the first apply is placed on
 * @param subqueryToMarkJoinSlot mark join slot recorded for each subquery
 * @param ctx cascades context used to query and record the analyzed state
 * @param conjunct the expression the subqueries were taken from, already without subqueries
 * @param isProject forwarded to addApply
 * @param isMarkJoinSlotNotNull forwarded to addApply
 * @return the resulting plan together with the (possibly rewritten) conjunct;
 *         the inputs are returned as-is when every subquery is skipped
 */
private Pair<LogicalPlan, Optional<Expression>> subqueryToApply(
        List<SubqueryExpr> subqueryExprs, LogicalPlan childPlan,
        Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
        CascadesContext ctx, Optional<Expression> conjunct, boolean isProject,
        boolean isMarkJoinSlotNotNull) {
    Pair<LogicalPlan, Optional<Expression>> current = Pair.of(childPlan, conjunct);
    for (SubqueryExpr candidate : subqueryExprs) {
        // A top level scalar agg always returns a value or null (for empty input), so
        // Exists / Not Exists always evaluate to the True / False literals respectively
        // and no apply node is created for them.
        boolean constantExists = candidate instanceof Exists
                && hasTopLevelScalarAgg(candidate.getQueryPlan());
        if (constantExists || ctx.subqueryIsAnalyzed(candidate)) {
            continue;
        }
        current = addApply(candidate, current.first, subqueryToMarkJoinSlot, ctx,
                current.second, isProject, subqueryExprs.size() == 1, isMarkJoinSlotNotNull);
    }
    return current;
}
/**
 * Checks whether the plan is topped by an aggregate without group-by expressions,
 * looking through any LogicalProject and LogicalSort nodes above it.
 *
 * @param plan root of the plan to inspect
 * @return true if the first node below the leading projects/sorts is a LogicalAggregate
 *         whose group-by list is empty
 */
private static boolean hasTopLevelScalarAgg(Plan plan) {
    Plan current = plan;
    // descend through the wrappers; an aggregate is never one of them, so it stops the walk
    while (current instanceof LogicalProject || current instanceof LogicalSort) {
        current = current.child(0);
    }
    return current instanceof LogicalAggregate
            && ((LogicalAggregate<?>) current).getGroupByExpressions().isEmpty();
}
/**
* Creates one LogicalApply for the given subquery on top of childPlan and wraps it in a
* LogicalProject (two nested projects when a row-count assertion is needed).
* The subquery is marked as analyzed in ctx before anything else is done.
*
* @param subquery the subquery expression to convert
* @param childPlan the plan that becomes the left child of the new apply
* @param subqueryToMarkJoinSlot mark join slot recorded for each subquery
* @param ctx cascades context used to record the analyzed state
* @param conjunct the enclosing expression with subqueries already replaced; may be rewritten here
* @param isProject passed to isConjunctContainsScalarSubqueryOutput and to the LogicalApply constructor
* @param singleSubquery passed to isConjunctContainsScalarSubqueryOutput
* @param isMarkJoinSlotNotNull passed to the LogicalApply constructor
* @return the new project-over-apply plan and the conjunct to use from now on
* @throws AnalysisException if the subquery is not an InSubquery, Exists or ScalarSubquery
*/
private Pair<LogicalPlan, Optional<Expression>> addApply(SubqueryExpr subquery,
LogicalPlan childPlan,
Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot,
CascadesContext ctx, Optional<Expression> conjunct, boolean isProject,
boolean singleSubquery, boolean isMarkJoinSlotNotNull) {
// subqueryToApply checks this flag, so a subquery is only converted once
ctx.setSubqueryExprIsAnalyzed(subquery, true);
// looked up with the original subquery object, before it may be replaced below
Optional<MarkJoinSlotReference> markJoinSlot = subqueryToMarkJoinSlot.get(subquery);
boolean needAddScalarSubqueryOutputToProjects = isConjunctContainsScalarSubqueryOutput(
subquery, conjunct, isProject, singleSubquery);
// For a scalar subquery we need to ensure it outputs at most 1 row.
// To do that, an aggregate function any_value() is added to the project list.
// needRuntimeAnyValue indicates whether any_value() is needed;
// if it is true, any_value() is added to the project list below.
boolean needRuntimeAnyValue = false;
NamedExpression oldSubqueryOutput = subquery.getQueryPlan().getOutput().get(0);
// countSlot stays null when no count(*) is added; anyValueSlot is set only in the any_value() branch
Slot countSlot = null;
Slot anyValueSlot = null;
Optional<Expression> newConjunct = conjunct;
// only correlated scalar subqueries whose output is needed are adjusted here
if (needAddScalarSubqueryOutputToProjects && subquery instanceof ScalarSubquery
&& !subquery.getCorrelateSlots().isEmpty()) {
if (((ScalarSubquery) subquery).hasTopLevelScalarAgg()) {
// consider sql: SELECT * FROM t1 WHERE t1.a <= (SELECT COUNT(t2.a) FROM t2 WHERE (t1.b = t2.b));
// when unnest correlated subquery, we create a left join node.
// outer query is left table and subquery is right one
// if there is no match, the row from right table is filled with nulls
// but COUNT function is always not nullable.
// so wrap COUNT with Nvl to ensure its result is 0 instead of null to get the correct result
if (conjunct.isPresent()) {
Map<Expression, Expression> replaceMap = new HashMap<>();
NamedExpression agg = ((ScalarSubquery) subquery).getTopLevelScalarAggFunction().get();
if (agg instanceof Alias) {
if (((Alias) agg).child() instanceof NotNullableAggregateFunction) {
NotNullableAggregateFunction notNullableAggFunc =
(NotNullableAggregateFunction) ((Alias) agg).child();
if (subquery.getQueryPlan() instanceof LogicalProject) {
// the subquery plan is topped by a project: rewrite that project's single
// expression so the aggregate slot is wrapped in Nvl, use the rewritten
// expression as the subquery output, and drop the project from the subquery plan
LogicalProject logicalProject =
(LogicalProject) subquery.getQueryPlan();
Preconditions.checkState(logicalProject.getOutputs().size() == 1,
"Scalar subuqery's should only output 1 column");
Slot aggSlot = agg.toSlot();
replaceMap.put(aggSlot, new Alias(new Nvl(aggSlot,
notNullableAggFunc.resultForEmptyInput())));
NamedExpression newOutput = (NamedExpression) ExpressionUtils
.replace((NamedExpression) logicalProject.getProjects().get(0), replaceMap);
// the map is reused: from here on it maps the old output to the new output slot
replaceMap.clear();
replaceMap.put(oldSubqueryOutput, newOutput.toSlot());
oldSubqueryOutput = newOutput;
subquery = subquery.withSubquery((LogicalPlan) logicalProject.child());
} else {
// no project on top: wrap the subquery output directly in the conjunct
replaceMap.put(oldSubqueryOutput, new Nvl(oldSubqueryOutput,
notNullableAggFunc.resultForEmptyInput()));
}
}
if (!replaceMap.isEmpty()) {
newConjunct = Optional.of(ExpressionUtils.replace(conjunct.get(), replaceMap));
}
}
}
} else {
// if scalar subquery doesn't have top level scalar agg we will create one, for example
// select (select t2.c1 from t2 where t2.c2 = t1.c2) from t1;
// the original output of the correlate subquery is t2.c1, after adding a scalar agg, it will be
// select (select count(*), any_value(t2.c1) from t2 where t2.c2 = t1.c2) from t1;
Alias anyValueAlias = new Alias(new AnyValue(oldSubqueryOutput));
LogicalAggregate<Plan> aggregate;
if (((ScalarSubquery) subquery).limitOneIsEliminated()) {
// no count(*) is added in this case, so no row-count assertion is generated below
aggregate = new LogicalAggregate<>(ImmutableList.of(),
ImmutableList.of(anyValueAlias), subquery.getQueryPlan());
} else {
Alias countAlias = new Alias(new Count());
countSlot = countAlias.toSlot();
aggregate = new LogicalAggregate<>(ImmutableList.of(),
ImmutableList.of(countAlias, anyValueAlias), subquery.getQueryPlan());
}
anyValueSlot = anyValueAlias.toSlot();
subquery = subquery.withSubquery(aggregate);
if (conjunct.isPresent()) {
// the conjunct must now read the any_value() slot instead of the old output
Map<Expression, Expression> replaceMap = new HashMap<>();
replaceMap.put(oldSubqueryOutput, anyValueSlot);
newConjunct = Optional.of(ExpressionUtils.replace(conjunct.get(), replaceMap));
}
needRuntimeAnyValue = true;
}
}
// derive the apply type, the negation flag and the IN compare expression from the subquery kind
LogicalApply.SubQueryType subQueryType;
boolean isNot = false;
Optional<Expression> compareExpr = Optional.empty();
if (subquery instanceof InSubquery) {
subQueryType = LogicalApply.SubQueryType.IN_SUBQUERY;
isNot = ((InSubquery) subquery).isNot();
compareExpr = Optional.of(((InSubquery) subquery).getCompareExpr());
} else if (subquery instanceof Exists) {
subQueryType = LogicalApply.SubQueryType.EXITS_SUBQUERY;
isNot = ((Exists) subquery).isNot();
} else if (subquery instanceof ScalarSubquery) {
subQueryType = LogicalApply.SubQueryType.SCALAR_SUBQUERY;
} else {
throw new AnalysisException(String.format("Unsupported subquery : %s", subquery.toString()));
}
LogicalApply newApply = new LogicalApply(
subquery.getCorrelateSlots(),
subQueryType, isNot, compareExpr, subquery.getTypeCoercionExpr(), Optional.empty(),
markJoinSlot,
needAddScalarSubqueryOutputToProjects, isProject, isMarkJoinSlotNotNull,
childPlan, subquery.getQueryPlan());
// expected size: the left child output plus up to three extra expressions
// (mark join slot, subquery output / any_value slot, assert_true alias)
ImmutableList.Builder<NamedExpression> projects =
ImmutableList.builderWithExpectedSize(childPlan.getOutput().size() + 3);
// left child
projects.addAll(childPlan.getOutput());
// markJoinSlotReference
markJoinSlot.map(projects::add);
LogicalProject logicalProject;
if (needAddScalarSubqueryOutputToProjects) {
if (needRuntimeAnyValue) {
// if we create a new subquery in previous step, we need add the any_value() and assert_true()
// into the project list. So BE will use assert_true to check if the subquery return only 1 row
projects.add(anyValueSlot);
if (countSlot != null) {
// snapshot taken before the assert_true alias is added: the inner project evaluates
// the assertion, the outer project outputs only the snapshot and so drops that column
List<NamedExpression> upperProjects = new ArrayList<>();
upperProjects.addAll(projects.build());
projects.add(new Alias(new AssertTrue(
ExpressionUtils.or(new IsNull(countSlot),
new LessThanEqual(countSlot, new IntegerLiteral(1))),
new VarcharLiteral("correlate scalar subquery must return only 1 row"))));
logicalProject = new LogicalProject(projects.build(), newApply);
logicalProject = new LogicalProject(upperProjects, logicalProject);
} else {
logicalProject = new LogicalProject(projects.build(), newApply);
}
} else {
// no aggregate was created here: expose the (possibly Nvl-wrapped) subquery output
projects.add(oldSubqueryOutput);
logicalProject = new LogicalProject(projects.build(), newApply);
}
} else {
logicalProject = new LogicalProject(projects.build(), newApply);
}
return Pair.of(logicalProject, newConjunct);
}
/**
 * Decides whether the output of a scalar subquery has to be added to the project list
 * created above the apply node.
 *
 * @param subqueryExpr the subquery being converted
 * @param conjunct the enclosing expression with subqueries already replaced, if any
 * @param isProject true when the subquery comes from a project list
 * @param singleSubquery not used by this method; kept so existing callers stay unchanged
 * @return true only for a ScalarSubquery when either the conjunct references the first
 *         output of the subquery's query plan, or {@code isProject} is true
 */
private boolean isConjunctContainsScalarSubqueryOutput(
        SubqueryExpr subqueryExpr, Optional<Expression> conjunct, boolean isProject, boolean singleSubquery) {
    if (!(subqueryExpr instanceof ScalarSubquery)) {
        return false;
    }
    if (conjunct.isPresent()) {
        // Typed Set instead of a raw (ImmutableSet) cast: membership is all that is needed,
        // and the result no longer depends on the concrete Set implementation.
        Set<SlotReference> referencedSlots = conjunct.get().collect(SlotReference.class::isInstance);
        if (referencedSlots.contains(subqueryExpr.getQueryPlan().getOutput().get(0))) {
            return true;
        }
    }
    return isProject;
}
/**
 * Expression rewriter that removes every subquery expression from an expression, because
 * the subquery itself is turned into a LogicalApply node.
 *
 * Replacement rules:
 * 1. scalar subquery: replaced by the subquery's output,
 *    e.g. filter(t1.a = scalarSubquery(output b)) becomes filter(t1.a = b);
 * 2. in-subquery: replaced by a new MarkJoinSlotReference when a mark join slot is needed,
 *    otherwise by the TRUE literal;
 * 3. exists: replaced by a constant when the subquery has a top level scalar aggregate,
 *    by nvl(MarkJoinSlotReference, FALSE) when a mark join slot is needed,
 *    otherwise by the TRUE literal.
 *
 * A mark join slot is needed when the caller asked for it in the constructor, or once an
 * OR predicate with a subquery below one of its children has been visited.
 */
private static class ReplaceSubquery extends DefaultExpressionRewriter<SubqueryContext> {
    private final StatementContext statementContext;
    private final boolean shouldOutputMarkJoinSlot;
    // set while visiting; never reset once an OR over a subquery has been seen
    private boolean isMarkJoin;

    public ReplaceSubquery(StatementContext statementContext,
            boolean shouldOutputMarkJoinSlot) {
        this.statementContext = Objects.requireNonNull(statementContext, "statementContext can't be null");
        this.shouldOutputMarkJoinSlot = shouldOutputMarkJoinSlot;
    }

    /** Rewrites {@code expression}, recording created mark join slots in {@code subqueryContext}. */
    public Expression replace(Expression expression, SubqueryContext subqueryContext) {
        return expression.accept(this, subqueryContext);
    }

    private boolean needsMarkJoinSlot() {
        return isMarkJoin || shouldOutputMarkJoinSlot;
    }

    @Override
    public Expression visitExistsSubquery(Exists exists, SubqueryContext context) {
        /*
         A top level scalar agg always returns a value or null for empty input, so
         Exists and Not Exists are always evaluated to the True and False literals respectively:
            SELECT *
            FROM t1
            WHERE EXISTS (
                SELECT SUM(a)
                FROM t2
                WHERE t1.a = t2.b and t1.a = 1;
            );
        */
        if (hasTopLevelScalarAgg(exists.getQueryPlan())) {
            return exists.isNot() ? BooleanLiteral.FALSE : BooleanLiteral.TRUE;
        }
        if (!needsMarkJoinSlot()) {
            return BooleanLiteral.TRUE;
        }
        MarkJoinSlotReference markSlot =
                new MarkJoinSlotReference(statementContext.generateColumnName());
        context.setSubqueryToMarkJoinSlot(exists, Optional.of(markSlot));
        return new Nvl(markSlot, BooleanLiteral.FALSE);
    }

    @Override
    public Expression visitInSubquery(InSubquery in, SubqueryContext context) {
        // the slot (and its generated column name) is created before the check, even if unused
        MarkJoinSlotReference markSlot =
                new MarkJoinSlotReference(statementContext.generateColumnName());
        if (!needsMarkJoinSlot()) {
            return BooleanLiteral.TRUE;
        }
        context.setSubqueryToMarkJoinSlot(in, Optional.of(markSlot));
        return markSlot;
    }

    @Override
    public Expression visitScalarSubquery(ScalarSubquery scalar, SubqueryContext context) {
        return scalar.getSubqueryOutput();
    }

    @Override
    public Expression visitCompoundPredicate(CompoundPredicate compound, SubqueryContext context) {
        // update the isMarkJoin flag before the children are rewritten
        if (compound instanceof Or) {
            boolean subqueryBelowOr = compound.children().stream()
                    .anyMatch(operand -> operand.anyMatch(SubqueryExpr.class::isInstance));
            if (subqueryBelowOr) {
                isMarkJoin = true;
            }
        }
        List<Expression> rewrittenChildren = new ArrayList<>();
        for (Expression child : compound.children()) {
            rewrittenChildren.add(replace(child, context));
        }
        return compound.withChildren(rewrittenChildren);
    }
}
/**
 * Per-expression state shared with ReplaceSubquery: the mark join slot recorded for each
 * subquery expression.
 * Every subquery starts with an empty Optional; ReplaceSubquery stores a slot when it
 * replaces an in-subquery or exists with a MarkJoinSlotReference,
 * e.g. logicalFilter(predicate=exists) ---> logicalFilter(predicate=$c$1).
 * Scalar subqueries are replaced by their output slot and keep the empty Optional,
 * e.g. logicalFilter(predicate=k1 > scalarSubquery) ---> logicalFilter(predicate=k1 > $c$1).
 */
private static class SubqueryContext {
    private final Map<SubqueryExpr, Optional<MarkJoinSlotReference>> subqueryToMarkJoinSlot;

    public SubqueryContext(Set<SubqueryExpr> subqueryExprs) {
        // insertion-ordered map, seeded with "no mark join slot" for every subquery
        Map<SubqueryExpr, Optional<MarkJoinSlotReference>> seeded =
                new LinkedHashMap<>(subqueryExprs.size());
        for (SubqueryExpr subqueryExpr : subqueryExprs) {
            seeded.put(subqueryExpr, Optional.empty());
        }
        this.subqueryToMarkJoinSlot = seeded;
    }

    private Map<SubqueryExpr, Optional<MarkJoinSlotReference>> getSubqueryToMarkJoinSlot() {
        return subqueryToMarkJoinSlot;
    }

    private void setSubqueryToMarkJoinSlot(SubqueryExpr subquery,
            Optional<MarkJoinSlotReference> markJoinSlotReference) {
        subqueryToMarkJoinSlot.put(subquery, markJoinSlotReference);
    }
}
// NOTE(review): this enum is not referenced anywhere else in the visible part of this file;
// it looks like a leftover from an earlier implementation -- confirm before removing.
private enum SearchState {
SearchNot,
SearchAnd,
SearchExistsOrInSubquery
}
/**
 * Computes, for every conjunct, whether its subqueries must produce a mark join slot.
 * That is the case when the conjunct contains a subquery expression somewhere inside it
 * but is not itself a subquery expression.
 *
 * @param conjuncts the conjuncts to inspect
 * @return an immutable list with one flag per conjunct, in iteration order
 */
private List<Boolean> shouldOutputMarkJoinSlot(Collection<Expression> conjuncts) {
    return conjuncts.stream()
            .<Boolean>map(conjunct -> !(conjunct instanceof SubqueryExpr)
                    && conjunct.containsType(SubqueryExpr.class))
            .collect(ImmutableList.toImmutableList());
}
/**
 * Collects the subquery expressions of every expression in {@code exprs}.
 *
 * @param exprs the expressions to scan
 * @return one set per input expression (possibly empty), in iteration order, plus a flag
 *         telling whether any of the sets is non-empty
 */
private CollectSubquerys collectSubquerys(Collection<? extends Expression> exprs) {
    ImmutableList.Builder<Set<SubqueryExpr>> perExpression = ImmutableList.builder();
    boolean foundAny = false;
    for (Expression candidate : exprs) {
        Set<SubqueryExpr> nested = candidate.collect(SubqueryExpr.class::isInstance);
        if (!nested.isEmpty()) {
            foundAny = true;
        }
        perExpression.add(nested);
    }
    return new CollectSubquerys(perExpression.build(), foundAny);
}
/**
 * Result holder of collectSubquerys: the subquery expressions found in each inspected
 * expression, and whether at least one subquery was found at all.
 */
private static class CollectSubquerys {
    // one set per inspected expression, in the order the expressions were visited
    final List<Set<SubqueryExpr>> subqueies;
    // true when at least one of the sets above is non-empty
    final boolean hasSubquery;

    public CollectSubquerys(List<Set<SubqueryExpr>> subqueriesPerExpression, boolean anySubqueryFound) {
        this.subqueies = subqueriesPerExpression;
        this.hasSubquery = anySubqueryFound;
    }
}
}