SubExprAnalyzer.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.analysis;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Exists;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InSubquery;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.ScalarSubquery;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SubqueryExpr;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.PlanType;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;

import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Objects;
import java.util.Set;

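/**
 * Expression rewriter that analyzes the subquery expressions (EXISTS, IN and scalar subqueries)
 * it encounters while visiting an expression tree.
 */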
class SubExprAnalyzer<T> extends DefaultExpressionRewriter<T> {
    private final Scope scope;
    private final CascadesContext cascadesContext;

    /**
     * Creates an analyzer bound to the given scope and planner context.
     *
     * @param scope the scope of the enclosing query; its outer scope, slots and asterisk slots
     *              are used to build the scope handed to each analyzed subquery
     * @param cascadesContext the parent context used to create a child context per subquery;
     *                        analyzing a subquery throws {@link IllegalStateException} when this is null
     */
    public SubExprAnalyzer(Scope scope, CascadesContext cascadesContext) {
        this.scope = scope;
        this.cascadesContext = cascadesContext;
    }

    /**
     * Handles a NOT expression. When the operand is directly an EXISTS or IN subquery, the
     * subquery expression is rebuilt with {@code true} as its last constructor argument
     * (presumably the negation flag - confirm against Exists/InSubquery) and analyzed through
     * the matching subquery visitor instead of keeping the NOT node. Any other operand goes
     * through the default {@code visit}.
     *
     * @param not the NOT expression being visited
     * @param context visitor context, passed through unchanged
     * @return the analyzed subquery expression, or the result of the default visit
     */
    @Override
    public Expression visitNot(Not not, T context) {
        Expression operand = not.child();
        if (operand instanceof Exists) {
            Exists exists = (Exists) operand;
            return visitExistsSubquery(new Exists(exists.getQueryPlan(), true), context);
        }
        if (operand instanceof InSubquery) {
            InSubquery inSubquery = (InSubquery) operand;
            return visitInSubquery(
                    new InSubquery(inSubquery.getCompareExpr(), inSubquery.getQueryPlan(), true), context);
        }
        return visit(not, context);
    }

    /**
     * Analyzes an EXISTS subquery.
     *
     * <p>A DISTINCT flag on a root project is cleared before analysis. If the analyzed plan's
     * root is a limit with value 0, the expression is replaced by a boolean literal equal to
     * {@code exists.isNot()}. A correlated subquery whose root is a limit with a non-zero
     * offset is rejected.
     *
     * @param exists the EXISTS expression to analyze
     * @param context visitor context (unused here)
     * @return a boolean literal or a new {@link Exists} carrying the analyzed plan and correlated slots
     * @throws AnalysisException for a correlated subquery with a root limit whose offset is not 0
     */
    @Override
    public Expression visitExistsSubquery(Exists exists, T context) {
        LogicalPlan subqueryPlan = exists.getQueryPlan();
        if (subqueryPlan instanceof LogicalProject) {
            LogicalProject topProject = (LogicalProject) subqueryPlan;
            if (topProject.isDistinct()) {
                // distinct is useless for an existence check, strip it
                exists = exists.withSubquery(topProject.withDistinct(false));
            }
        }

        AnalyzedResult result = analyzeSubquery(exists);
        if (result.rootIsLimitZero()) {
            return BooleanLiteral.of(exists.isNot());
        }

        boolean correlatedWithOffset = result.isCorrelated() && result.rootIsLimitWithOffset();
        if (correlatedWithOffset) {
            throw new AnalysisException("Unsupported correlated subquery with a LIMIT clause with offset > 0 "
                    + result.getLogicalPlan());
        }
        return new Exists(result.getLogicalPlan(), result.getCorrelatedSlots(), exists.isNot());
    }

    /**
     * Analyzes an IN subquery.
     *
     * <p>A DISTINCT flag on a root project is cleared before analysis. The analyzed plan must
     * output exactly one column, must not have correlated slots under an aggregate, and, when
     * correlated, must not have a limit at its root. The compare expression is rewritten with
     * this visitor only after those checks pass.
     *
     * @param expr the IN subquery expression to analyze
     * @param context visitor context, forwarded when rewriting the compare expression
     * @return a new {@link InSubquery} carrying the analyzed plan and its correlated slots
     * @throws AnalysisException when one of the checks above fails
     */
    @Override
    public Expression visitInSubquery(InSubquery expr, T context) {
        LogicalPlan subqueryPlan = expr.getQueryPlan();
        if (subqueryPlan instanceof LogicalProject) {
            LogicalProject topProject = (LogicalProject) subqueryPlan;
            if (topProject.isDistinct()) {
                // distinct is useless for a membership test, strip it
                expr = expr.withSubquery(topProject.withDistinct(false));
            }
        }

        AnalyzedResult result = analyzeSubquery(expr);
        checkOutputColumn(result.getLogicalPlan());
        checkNoCorrelatedSlotsUnderAgg(result);
        checkRootIsLimit(result);

        Expression rewrittenCompareExpr = expr.getCompareExpr().accept(this, context);
        return new InSubquery(rewrittenCompareExpr, result.getLogicalPlan(),
                result.getCorrelatedSlots(), expr.isNot());
    }

    /**
     * Analyzes a scalar subquery.
     *
     * <p>Steps, in order:
     * <ol>
     *   <li>analyze the subquery plan and require exactly one output column;</li>
     *   <li>for a correlated subquery: drop a root {@code LIMIT 1} (offset 0) and remember that
     *       in {@code limitOneIsEliminated}, reject any other root limit, drop a root sort,
     *       then validate where the correlated slots appear via {@code validateSubquery};</li>
     *   <li>if the resulting root is a project over a one-row relation with a single constant
     *       alias, return the aliased expression directly;</li>
     *   <li>otherwise, for a correlated subquery with a root project, reject projects whose
     *       input slots include a correlated slot.</li>
     * </ol>
     *
     * @param scalar the scalar subquery expression to analyze
     * @param context visitor context (unused here)
     * @return the constant expression for a trivial constant subquery, otherwise a new
     *         {@link ScalarSubquery} with the analyzed plan, correlated slots and the
     *         {@code limitOneIsEliminated} flag
     * @throws AnalysisException when the subquery uses an unsupported shape
     */
    @Override
    public Expression visitScalarSubquery(ScalarSubquery scalar, T context) {
        AnalyzedResult analyzedResult = analyzeSubquery(scalar);
        boolean isCorrelated = analyzedResult.isCorrelated();
        // snapshot of the analyzed root; it is NOT updated when analyzedResult is replaced below
        LogicalPlan analyzedSubqueryPlan = analyzedResult.logicalPlan;
        checkOutputColumn(analyzedSubqueryPlan);
        // use limitOneIsEliminated to indicate if subquery has limit 1 clause
        // because limit 1 clause will ensure subquery output at most 1 row
        // we eliminate limit 1 clause and pass this info to later SubqueryToApply rule
        // so when creating LogicalApply node, we don't need to add AssertTrue function
        boolean limitOneIsEliminated = false;
        if (isCorrelated) {
            if (analyzedSubqueryPlan instanceof LogicalLimit) {
                LogicalLimit limit = (LogicalLimit) analyzedSubqueryPlan;
                if (limit.getOffset() == 0 && limit.getLimit() == 1) {
                    // skip useless limit node
                    analyzedResult = new AnalyzedResult((LogicalPlan) analyzedSubqueryPlan.child(0),
                            analyzedResult.correlatedSlots);
                    limitOneIsEliminated = true;
                } else {
                    throw new AnalysisException("limit is not supported in correlated subquery "
                            + analyzedResult.getLogicalPlan());
                }
            }
            // NOTE(review): this tests the original root (analyzedSubqueryPlan), not the plan held by
            // analyzedResult after a LIMIT 1 was removed above, so a sort sitting directly under an
            // eliminated limit is kept here - confirm whether that is intended
            if (analyzedSubqueryPlan instanceof LogicalSort) {
                // skip useless sort node
                analyzedResult = new AnalyzedResult((LogicalPlan) analyzedSubqueryPlan.child(0),
                        analyzedResult.correlatedSlots);
            }
            CorrelatedSlotsValidator validator =
                    new CorrelatedSlotsValidator(ImmutableSet.copyOf(analyzedResult.correlatedSlots));
            List<PlanNodeCorrelatedInfo> nodeInfoList = new ArrayList<>(16);
            // filled by validateSubquery; not read again in this method
            Set<LogicalAggregate> topAgg = new HashSet<>();
            // throws AnalysisException if correlated slots are used in an unsupported position
            validateSubquery(analyzedResult.logicalPlan, validator, nodeInfoList, topAgg);
        }

        if (analyzedResult.getLogicalPlan() instanceof LogicalProject) {
            LogicalProject project = (LogicalProject) analyzedResult.getLogicalPlan();
            if (project.child() instanceof LogicalOneRowRelation
                    && project.getProjects().size() == 1
                    && project.getProjects().get(0) instanceof Alias) {
                // if scalar subquery is like select '2024-02-02 00:00:00'
                // we can just return the constant expr '2024-02-02 00:00:00'
                Alias alias = (Alias) project.getProjects().get(0);
                if (alias.isConstant()) {
                    return alias.child();
                }
            } else if (isCorrelated) {
                // the root project's expressions must not read any correlated (outer) slot
                Set<Slot> correlatedSlots = new HashSet<>(analyzedResult.getCorrelatedSlots());
                if (!Sets.intersection(ExpressionUtils.getInputSlotSet(project.getProjects()),
                        correlatedSlots).isEmpty()) {
                    throw new AnalysisException(
                            "outer query's column is not supported in subquery's output "
                                    + analyzedResult.getLogicalPlan());
                }
            }
        }

        return new ScalarSubquery(analyzedResult.getLogicalPlan(), analyzedResult.getCorrelatedSlots(),
                limitOneIsEliminated);
    }

    /**
     * Ensures the subquery plan outputs exactly one column.
     *
     * @param plan the analyzed subquery plan
     * @throws AnalysisException when the number of output columns is not 1
     */
    private void checkOutputColumn(LogicalPlan plan) {
        int outputColumnCount = plan.getOutput().size();
        if (outputColumnCount == 1) {
            return;
        }
        throw new AnalysisException("Multiple columns returned by subquery are not yet supported. Found "
                + outputColumnCount);
    }

    /**
     * Rejects a subquery for which {@code hasCorrelatedSlotsUnderAgg()} reports true.
     *
     * @param analyzedResult the analyzed subquery
     * @throws AnalysisException when correlated slots are found under an aggregate
     */
    private void checkNoCorrelatedSlotsUnderAgg(AnalyzedResult analyzedResult) {
        if (!analyzedResult.hasCorrelatedSlotsUnderAgg()) {
            return;
        }
        String message = "Unsupported correlated subquery with grouping and/or aggregation "
                + analyzedResult.getLogicalPlan();
        throw new AnalysisException(message);
    }

    /**
     * Rejects a correlated subquery whose plan root is a limit node; uncorrelated
     * subqueries always pass.
     *
     * @param analyzedResult the analyzed subquery
     * @throws AnalysisException when the subquery is correlated and its root is a limit
     */
    private void checkRootIsLimit(AnalyzedResult analyzedResult) {
        boolean correlatedWithRootLimit = analyzedResult.isCorrelated() && analyzedResult.rootIsLimit();
        if (correlatedWithRootLimit) {
            throw new AnalysisException("Unsupported correlated subquery with a LIMIT clause "
                    + analyzedResult.getLogicalPlan());
        }
    }

    /**
     * Runs the analyzer on the subquery's plan in a child context and collects the result.
     *
     * <p>The child context is created from the parent context and its CTE context. A fresh
     * {@link Scope} built from the current scope's outer scope, slots and asterisk slots is
     * installed as the child's outer scope; the correlated slots are read from that fresh
     * scope once analysis has run.
     *
     * @param expr the subquery expression whose query plan should be analyzed
     * @return the rewritten plan of the child context together with the correlated slots
     * @throws IllegalStateException when this analyzer was created without a CascadesContext
     */
    private AnalyzedResult analyzeSubquery(SubqueryExpr expr) {
        if (cascadesContext == null) {
            throw new IllegalStateException("Missing CascadesContext");
        }
        CascadesContext subqueryContext = CascadesContext.newContextWithCteContext(
                cascadesContext, expr.getQueryPlan(), cascadesContext.getCteContext());

        // build a new scope instead of passing the current one directly: only its outer scope,
        // slots and asterisk slots are needed, otherwise unexpected errors may occur
        Scope enclosingScope = getScope();
        Scope subqueryScope = new Scope(
                enclosingScope.getOuterScope(),
                enclosingScope.getSlots(),
                enclosingScope.getAsteriskSlots());
        subqueryContext.setOuterScope(subqueryScope);
        subqueryContext.newAnalyzer().analyze();

        LogicalPlan analyzedPlan = (LogicalPlan) subqueryContext.getRewritePlan();
        return new AnalyzedResult(analyzedPlan, subqueryScope.getCorrelatedSlots());
    }

    /**
     * Returns the scope this analyzer was constructed with.
     */
    public Scope getScope() {
        return scope;
    }

    /**
     * Returns the context this analyzer was constructed with; may be null if null was passed in.
     */
    public CascadesContext getCascadesContext() {
        return cascadesContext;
    }

    /**
     * Immutable result of analyzing one subquery: the analyzed plan plus the slots of the
     * outer query that the subquery refers to (its correlated slots).
     */
    private static class AnalyzedResult {
        private final LogicalPlan logicalPlan;
        private final List<Slot> correlatedSlots;

        /**
         * @param logicalPlan the analyzed subquery plan, must not be null
         * @param correlatedSlots the correlated slots; null is treated as "none"
         * @throws NullPointerException when {@code logicalPlan} is null
         */
        public AnalyzedResult(LogicalPlan logicalPlan, Collection<Slot> correlatedSlots) {
            this.logicalPlan = Objects.requireNonNull(logicalPlan, "logicalPlan can not be null");
            // always hold an immutable list, also for null input, so callers of
            // getCorrelatedSlots() see the same (unmodifiable) behavior in both cases
            this.correlatedSlots = correlatedSlots == null
                    ? ImmutableList.of() : ImmutableList.copyOf(correlatedSlots);
        }

        public LogicalPlan getLogicalPlan() {
            return logicalPlan;
        }

        public List<Slot> getCorrelatedSlots() {
            return correlatedSlots;
        }

        /** Returns true when the subquery refers to at least one correlated slot. */
        public boolean isCorrelated() {
            return !correlatedSlots.isEmpty();
        }

        /**
         * Returns true when the subquery is correlated and a {@link LogicalAggregate} found by
         * {@code hasCorrelatedSlotsUnderNode} reports that it contains the correlated slots.
         */
        public boolean hasCorrelatedSlotsUnderAgg() {
            return !correlatedSlots.isEmpty()
                    && hasCorrelatedSlotsUnderNode(logicalPlan,
                            ImmutableSet.copyOf(correlatedSlots), LogicalAggregate.class);
        }

        /**
         * Breadth-first search from {@code rootPlan} for nodes whose class is exactly
         * {@code clazz} (subclasses do not match).
         *
         * <p>For a matching node, {@code plan.containsSlots(slots)} decides the outcome; the
         * children of a matching node are never enqueued, so the search does not look below it.
         *
         * @param rootPlan the plan to start from
         * @param slots the slots to look for
         * @param clazz the exact plan class to match
         * @return true as soon as one matching node reports that it contains the slots
         */
        private static <T> boolean hasCorrelatedSlotsUnderNode(Plan rootPlan,
                                                               ImmutableSet<Slot> slots, Class<T> clazz) {
            ArrayDeque<Plan> planQueue = new ArrayDeque<>();
            planQueue.add(rootPlan);
            while (!planQueue.isEmpty()) {
                Plan plan = planQueue.poll();
                if (plan.getClass().equals(clazz)) {
                    if (plan.containsSlots(slots)) {
                        return true;
                    }
                } else {
                    for (Plan child : plan.children()) {
                        planQueue.add(child);
                    }
                }
            }
            return false;
        }

        /** Returns true when the plan root is a {@link LogicalLimit}. */
        public boolean rootIsLimit() {
            return logicalPlan instanceof LogicalLimit;
        }

        /** Returns true when the plan root is a {@link LogicalLimit} whose offset is not 0. */
        public boolean rootIsLimitWithOffset() {
            return logicalPlan instanceof LogicalLimit && ((LogicalLimit<?>) logicalPlan).getOffset() != 0;
        }

        /** Returns true when the plan root is a {@link LogicalLimit} whose limit is 0. */
        public boolean rootIsLimitZero() {
            return logicalPlan instanceof LogicalLimit && ((LogicalLimit<?>) logicalPlan).getLimit() == 0;
        }
    }

    /**
     * Immutable summary of one plan node as seen by the correlated-slot validation: the node's
     * (possibly overridden) plan type, whether its expressions use correlated slots, and, for
     * aggregate nodes, the aggregate itself and whether it has group-by expressions.
     */
    private static class PlanNodeCorrelatedInfo {
        private final PlanType planType;
        private final boolean containCorrelatedSlots;
        private final boolean hasGroupBy;
        // null for nodes that are not aggregates
        private final LogicalAggregate aggregate;

        /**
         * Creates the info for a non-aggregate node.
         *
         * @param planType the type to record for the node
         * @param containCorrelatedSlots whether the node uses correlated slots
         */
        public PlanNodeCorrelatedInfo(PlanType planType, boolean containCorrelatedSlots) {
            this(planType, containCorrelatedSlots, null);
        }

        /**
         * @param planType the type to record for the node
         * @param containCorrelatedSlots whether the node uses correlated slots
         * @param aggregate the aggregate node, or null when the node is not an aggregate
         */
        public PlanNodeCorrelatedInfo(PlanType planType, boolean containCorrelatedSlots,
                LogicalAggregate aggregate) {
            this.planType = planType;
            this.containCorrelatedSlots = containCorrelatedSlots;
            this.aggregate = aggregate;
            this.hasGroupBy = aggregate != null && !aggregate.getGroupByExpressions().isEmpty();
        }
    }

    /**
     * Plan visitor that produces a {@link PlanNodeCorrelatedInfo} for each visited node and
     * rejects project, aggregate, join and sort nodes whose expressions use a correlated slot.
     */
    private static class CorrelatedSlotsValidator
            extends PlanVisitor<PlanNodeCorrelatedInfo, Void> {
        private final ImmutableSet<Slot> correlatedSlots;

        public CorrelatedSlotsValidator(ImmutableSet<Slot> correlatedSlots) {
            this.correlatedSlots = correlatedSlots;
        }

        /**
         * Default handling: records the node's type and whether it uses correlated slots,
         * without rejecting anything.
         */
        @Override
        public PlanNodeCorrelatedInfo visit(Plan plan, Void context) {
            return new PlanNodeCorrelatedInfo(plan.getType(), findCorrelatedSlots(plan));
        }

        /**
         * Rejects a project that uses correlated slots. A project whose expressions contain a
         * window expression is recorded as {@code LOGICAL_WINDOW} instead of its own type.
         *
         * @throws AnalysisException when the project uses a correlated slot
         */
        public PlanNodeCorrelatedInfo visitLogicalProject(LogicalProject plan, Void context) {
            if (findCorrelatedSlots(plan)) {
                throw new AnalysisException("access outer query's column in project is not supported");
            }
            PlanType planType = ExpressionUtils.containsWindowExpression(
                    ((LogicalProject<?>) plan).getProjects()) ? PlanType.LOGICAL_WINDOW : plan.getType();
            return new PlanNodeCorrelatedInfo(planType, false);
        }

        /**
         * Rejects an aggregate that uses correlated slots; otherwise records the aggregate node
         * itself so group-by information is available to the caller.
         *
         * @throws AnalysisException when the aggregate uses a correlated slot
         */
        public PlanNodeCorrelatedInfo visitLogicalAggregate(LogicalAggregate plan, Void context) {
            if (findCorrelatedSlots(plan)) {
                throw new AnalysisException("access outer query's column in aggregate is not supported");
            }
            return new PlanNodeCorrelatedInfo(plan.getType(), false, plan);
        }

        /**
         * Rejects a join that uses correlated slots.
         *
         * @throws AnalysisException when the join uses a correlated slot
         */
        public PlanNodeCorrelatedInfo visitLogicalJoin(LogicalJoin plan, Void context) {
            if (findCorrelatedSlots(plan)) {
                throw new AnalysisException("access outer query's column in join is not supported");
            }
            return new PlanNodeCorrelatedInfo(plan.getType(), false);
        }

        /**
         * Rejects a sort that uses correlated slots.
         *
         * @throws AnalysisException when the sort uses a correlated slot
         */
        public PlanNodeCorrelatedInfo visitLogicalSort(LogicalSort plan, Void context) {
            if (findCorrelatedSlots(plan)) {
                throw new AnalysisException("access outer query's column in order by is not supported");
            }
            return new PlanNodeCorrelatedInfo(plan.getType(), false);
        }

        /**
         * Returns true when the input slots of any of the plan's expressions intersect the
         * correlated slots.
         */
        private boolean findCorrelatedSlots(Plan plan) {
            return plan.getExpressions().stream().anyMatch(expression -> !Sets
                    .intersection(correlatedSlots, expression.getInputSlots()).isEmpty());
        }
    }

    /**
     * Validates one path of plan nodes, given as a list that is scanned from its last element
     * back to its first.
     *
     * <p>Once a node using correlated slots has been seen, every node scanned afterwards is
     * checked: limit, generate, window and join nodes are rejected; at most one aggregate is
     * allowed and it must not have group-by expressions; project, sort and subquery alias
     * nodes are always allowed; any other node type is rejected only after the aggregate.
     * Finally, more than one node using correlated slots on the path is rejected.
     *
     * @param nodeInfoList the node infos of the path; may be empty
     * @return the aggregate found after the correlated node, or null if there is none
     * @throws AnalysisException when the path violates one of the rules above
     */
    private LogicalAggregate validateNodeInfoList(List<PlanNodeCorrelatedInfo> nodeInfoList) {
        if (nodeInfoList.isEmpty()) {
            return null;
        }

        LogicalAggregate aggAfterCorrelatedNode = null;
        int correlatedNodeCount = 0;
        boolean seenCorrelatedNode = false;
        boolean seenAggNode = false;
        for (int idx = nodeInfoList.size() - 1; idx >= 0; idx--) {
            PlanNodeCorrelatedInfo info = nodeInfoList.get(idx);
            if (seenCorrelatedNode) {
                switch (info.planType) {
                    case LOGICAL_LIMIT:
                        throw new AnalysisException(
                                "limit is not supported in correlated subquery");
                    case LOGICAL_GENERATE:
                        throw new AnalysisException(
                                "access outer query's column before lateral view is not supported");
                    case LOGICAL_WINDOW:
                        throw new AnalysisException(
                                "access outer query's column before window function is not supported");
                    case LOGICAL_JOIN:
                        throw new AnalysisException(
                                "access outer query's column before join is not supported");
                    case LOGICAL_AGGREGATE:
                        if (seenAggNode) {
                            throw new AnalysisException(
                                    "access outer query's column before two agg nodes is not supported");
                        }
                        if (info.hasGroupBy) {
                            // TODO support later
                            throw new AnalysisException(
                                    "access outer query's column before agg with group by is not supported");
                        }
                        seenAggNode = true;
                        aggAfterCorrelatedNode = info.aggregate;
                        break;
                    case LOGICAL_SORT:
                        // any sort node is fine, it will be removed by ELIMINATE_ORDER_BY_UNDER_SUBQUERY
                    case LOGICAL_PROJECT:
                    case LOGICAL_SUBQUERY_ALIAS:
                        // project and subquery alias nodes are always fine as well
                        break;
                    default:
                        if (seenAggNode) {
                            throw new AnalysisException(
                                    "only project, sort and subquery alias node is allowed after agg node");
                        }
                        break;
                }
            }
            if (info.containCorrelatedSlots) {
                correlatedNodeCount++;
                seenCorrelatedNode = true;
            }
        }

        // only support 1 correlated node for now
        if (correlatedNodeCount > 1) {
            throw new AnalysisException(
                    "access outer query's column in two places is not supported");
        }
        return aggAfterCorrelatedNode;
    }

    /**
     * Walks the plan depth-first while keeping in {@code nodeInfoList} the node infos of the
     * path from the starting plan down to the current node. Each time a node without children
     * is reached, the current path is validated with {@code validateNodeInfoList}, and a
     * non-null aggregate it returns is added to {@code topAgg}.
     *
     * @param plan the current plan node
     * @param validator the visitor producing the info for each node (and rejecting some nodes)
     * @param nodeInfoList the path accumulated so far; restored to its entry state on return
     * @param topAgg collects the aggregates returned by the path validation
     * @throws AnalysisException when the validator or a path validation rejects the plan
     */
    private void validateSubquery(Plan plan, CorrelatedSlotsValidator validator,
            List<PlanNodeCorrelatedInfo> nodeInfoList, Set<LogicalAggregate> topAgg) {
        // push this node onto the path
        nodeInfoList.add(plan.accept(validator, null));

        if (plan.children().isEmpty()) {
            // reached a leaf: the path is complete, validate it
            LogicalAggregate pathAggregate = validateNodeInfoList(nodeInfoList);
            if (pathAggregate != null) {
                topAgg.add(pathAggregate);
            }
        } else {
            for (Plan childPlan : plan.children()) {
                validateSubquery(childPlan, validator, nodeInfoList, topAgg);
            }
        }

        // pop this node again before returning to the parent
        int lastIndex = nodeInfoList.size() - 1;
        nodeInfoList.remove(lastIndex);
    }
}