PlanUtils.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.util;

import org.apache.doris.analysis.Expr;
import org.apache.doris.analysis.SlotRef;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.analyzer.UnboundSlot;
import org.apache.doris.nereids.glue.translator.ExpressionTranslator;
import org.apache.doris.nereids.glue.translator.PlanTranslatorContext;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.WindowExpression;
import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalEmptyRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.qe.ConnectContext;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableList.Builder;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Lists;
import com.google.common.collect.Sets;
import org.apache.commons.collections.map.CaseInsensitiveMap;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Optional;
import java.util.Set;
import java.util.stream.Collectors;

/**
 * Util for plan
 */
public class PlanUtils {
    /**
     * Wrap {@code plan} in a {@link LogicalFilter} when there is something to filter on.
     *
     * @param predicates predicates handed to the new filter
     * @param plan the plan that becomes the filter's child
     * @return the new filter, or {@link Optional#empty()} when {@code predicates} is empty
     */
    public static Optional<LogicalFilter<? extends Plan>> filter(Set<Expression> predicates, Plan plan) {
        if (!predicates.isEmpty()) {
            return Optional.of(new LogicalFilter<>(predicates, plan));
        }
        return Optional.empty();
    }

    /**
     * Put a filter with {@code predicates} on top of {@code plan}; when {@code predicates} is empty
     * the original {@code plan} is returned unchanged.
     */
    public static Plan filterOrSelf(Set<Expression> predicates, Plan plan) {
        Optional<LogicalFilter<? extends Plan>> wrapped = filter(predicates, plan);
        if (wrapped.isPresent()) {
            return wrapped.get();
        }
        return plan;
    }

    /**
     * Normalize a comparison predicate on a binary plan so that its two sides correspond to the
     * children's outputs: when every input slot of the predicate's left operand is produced by
     * {@code left}, the predicate is returned as is, otherwise its commuted form is returned.
     *
     * @param expression the comparison predicate to normalize
     * @param left the left child of the binary plan
     * @return {@code expression} itself or {@code expression.commute()}
     */
    public static ComparisonPredicate maybeCommuteComparisonPredicate(ComparisonPredicate expression, Plan left) {
        Set<Slot> leftOperandSlots = expression.left().getInputSlots();
        Set<Slot> leftChildOutput = left.getOutputSet();
        if (leftChildOutput.containsAll(leftOperandSlots)) {
            return expression;
        }
        return expression.commute();
    }

    /**
     * Wrap {@code plan} in a {@link LogicalProject} when there is at least one projection.
     *
     * @param projects projections handed to the new project node
     * @param plan the plan that becomes the project's child
     * @return the new project, or {@link Optional#empty()} when {@code projects} is empty
     */
    public static Optional<LogicalProject<? extends Plan>> project(List<NamedExpression> projects, Plan plan) {
        if (!projects.isEmpty()) {
            return Optional.of(new LogicalProject<>(projects, plan));
        }
        return Optional.empty();
    }

    /**
     * Put a project with {@code projects} on top of {@code plan}; when {@code projects} is empty
     * the original {@code plan} is returned unchanged.
     */
    public static Plan projectOrSelf(List<NamedExpression> projects, Plan plan) {
        Optional<LogicalProject<? extends Plan>> wrapped = project(projects, plan);
        if (wrapped.isPresent()) {
            return wrapped.get();
        }
        return plan;
    }

    /**
     * Return a distinct aggregate over {@code plan}.
     * A plan that already is a {@link LogicalAggregate} reporting {@code isDistinct()} is returned as is;
     * otherwise a new aggregate is built from a copy of the plan's output, with {@code plan} as its child.
     */
    public static LogicalAggregate<Plan> distinct(Plan plan) {
        boolean alreadyDistinct = plan instanceof LogicalAggregate && ((LogicalAggregate<?>) plan).isDistinct();
        if (!alreadyDistinct) {
            // NOTE(review): the boolean is passed through to LogicalAggregate's constructor; see that
            // constructor for its exact meaning
            return new LogicalAggregate<>(ImmutableList.copyOf(plan.getOutput()), false, plan);
        }
        return (LogicalAggregate<Plan>) plan;
    }

    /**
     * Mark as nullable every slot in {@code outputs} that is used by any of the grouping sets.
     *
     * @param groupingSets the grouping sets whose input slots must become nullable
     * @param outputs the output expressions to rewrite
     * @return the rewritten outputs, in the same order as {@code outputs}
     */
    public static List<NamedExpression> adjustNullableForRepeat(
            List<List<Expression>> groupingSets,
            List<NamedExpression> outputs) {
        Set<Slot> slotsInGroupingSets = Sets.newHashSet();
        for (List<Expression> groupingSet : groupingSets) {
            for (Expression groupingExpr : groupingSet) {
                slotsInGroupingSets.addAll(groupingExpr.getInputSlots());
            }
        }
        return outputs.stream()
                .map(output -> (NamedExpression) output.rewriteUp(node -> {
                    if (node instanceof Slot && slotsInGroupingSets.contains(node)) {
                        return ((Slot) node).withNullable(true);
                    }
                    return node;
                }))
                .collect(ImmutableList.toImmutableList());
    }

    /**
     * Merge childProjects with parentProjects: the parent projections are rewritten with the replace
     * map that {@code ExpressionUtils.generateReplaceMap} builds from the child projections.
     *
     * @param childProjects projections of the lower project
     * @param parentProjects projections of the upper project
     * @return the rewritten parent projections
     */
    public static List<NamedExpression> mergeProjections(List<? extends NamedExpression> childProjects,
            List<? extends NamedExpression> parentProjects) {
        return ExpressionUtils.replaceNamedExpressions(
                parentProjects, ExpressionUtils.generateReplaceMap(childProjects));
    }

    /**
     * Rewrite {@code targetExpression} with the replace map that
     * {@code ExpressionUtils.generateReplaceMap} builds from {@code childProjects}.
     *
     * @param childProjects projections the replace map is generated from
     * @param targetExpression expressions to rewrite
     * @return the rewritten expressions
     */
    public static List<Expression> replaceExpressionByProjections(List<NamedExpression> childProjects,
            List<Expression> targetExpression) {
        return ExpressionUtils.replace(
                targetExpression, ExpressionUtils.generateReplaceMap(childProjects));
    }

    /**
     * Check whether the target expressions can be replaced with the child projections.
     * If a target expression contains a slot whose origin expression contains a non-foldable
     * expression, and that slot exists multiple times across the target expressions, the
     * replacement is not allowed.
     * For example, target expressions: [a, a + 10], child project: [ t + random() as a ].
     * Replacing with the projects would give [ t + random(), t + random() + 10 ],
     * which calculates random two times, and that is an error.
     *
     * @param childProjects projections that define the slots
     * @param targetExpressions expressions that reference those slots
     * @return true when no slot backed by a non-foldable expression occurs more than once
     */
    public static boolean canReplaceWithProjections(List<? extends NamedExpression> childProjects,
            List<? extends Expression> targetExpressions) {
        Map<Slot, Expression> replaceMap = ExpressionUtils.generateReplaceMap(childProjects);
        Set<Slot> nonfoldableSlots = Sets.newHashSet();
        for (Entry<Slot, Expression> projection : replaceMap.entrySet()) {
            if (projection.getValue().containsNonfoldable()) {
                nonfoldableSlots.add(projection.getKey());
            }
        }
        if (nonfoldableSlots.isEmpty()) {
            return true;
        }

        // Set#add returns false for an element that is already present, which is exactly
        // the second occurrence of a non-foldable slot
        Set<Slot> seenSlots = Sets.newHashSet();
        for (Expression target : targetExpressions) {
            boolean repeated = target.anyMatch(
                    e -> (e instanceof Slot) && nonfoldableSlots.contains(e) && !seenSlots.add((Slot) e));
            if (repeated) {
                return false;
            }
        }
        return true;
    }

    /**
     * Step over one plan node when it is a {@link LogicalFilter}, a {@link LogicalLimit}, or a
     * {@link LogicalProject} whose {@code isAllSlots()} is true, returning its first child.
     * Any other plan is returned unchanged. Only a single level is skipped.
     */
    public static Plan skipProjectFilterLimit(Plan plan) {
        boolean allSlotsProject = plan instanceof LogicalProject && ((LogicalProject<?>) plan).isAllSlots();
        if (allSlotsProject || plan instanceof LogicalFilter || plan instanceof LogicalLimit) {
            return plan.child(0);
        }
        return plan;
    }

    /**
     * Collect every {@link LogicalCatalogRelation} instance found in the plan rooted at {@code rootPlan}.
     */
    public static Set<LogicalCatalogRelation> getLogicalScanFromRootPlan(LogicalPlan rootPlan) {
        Set<LogicalCatalogRelation> relations = rootPlan.collect(LogicalCatalogRelation.class::isInstance);
        return relations;
    }

    /**
     * Get the set of tables referenced by the {@link LogicalCatalogRelation}s under the plan root.
     *
     * @param plan root of the plan to inspect
     * @return an immutable set with the table of each collected relation
     */
    public static ImmutableSet<TableIf> getTableSet(LogicalPlan plan) {
        Set<LogicalCatalogRelation> relations = plan.collect(LogicalCatalogRelation.class::isInstance);
        ImmutableSet.Builder<TableIf> tables = ImmutableSet.builder();
        for (LogicalCatalogRelation relation : relations) {
            tables.add(relation.getTable());
        }
        return tables.build();
    }

    /**
     * Concatenate the outputs of all {@code children}, in child order.
     * No children yields an empty list and a single child yields that child's own output list;
     * otherwise the result is built with a single, exactly pre-sized copy.
     */
    public static List<Slot> fastGetChildrenOutputs(List<Plan> children) {
        int childCount = children.size();
        if (childCount == 0) {
            return ImmutableList.of();
        }
        if (childCount == 1) {
            return children.get(0).getOutput();
        }

        // child.output is cached by AbstractPlan.logicalProperties, so the total size
        // can be summed up front without the overhead of re-computing the outputs
        int totalSlots = 0;
        for (Plan child : children) {
            totalSlots += child.getOutput().size();
        }
        // knowing the size in advance lets the builder copy once and never resize
        Builder<Slot> merged = ImmutableList.builderWithExpectedSize(totalSlots);
        for (Plan child : children) {
            merged.addAll(child.getOutput());
        }
        return merged.build();
    }

    /**
     * Concatenate the asterisk outputs ({@code getAsteriskOutput()}) of all {@code children}, in child order.
     * No children yields an empty list and a single child yields that child's own list;
     * otherwise the result is built with a single, exactly pre-sized copy.
     */
    public static List<Slot> fastGetChildrenAsteriskOutputs(List<Plan> children) {
        int childCount = children.size();
        if (childCount == 0) {
            return ImmutableList.of();
        }
        if (childCount == 1) {
            return children.get(0).getAsteriskOutput();
        }

        // sum the sizes first so the builder below can be allocated at its final capacity
        int totalSlots = 0;
        for (Plan child : children) {
            totalSlots += child.getAsteriskOutput().size();
        }
        // knowing the size in advance lets the builder copy once and never resize
        Builder<Slot> merged = ImmutableList.builderWithExpectedSize(totalSlots);
        for (Plan child : children) {
            merged.addAll(child.getAsteriskOutput());
        }
        return merged.build();
    }

    /**
     * Union the input slots of all {@code expressions}.
     * No expressions yields an empty set and a single expression yields that expression's own slot set;
     * otherwise an immutable set is built with a pre-computed expected size.
     */
    public static Set<Slot> fastGetInputSlots(List<? extends Expression> expressions) {
        int exprCount = expressions.size();
        if (exprCount == 0) {
            return ImmutableSet.of();
        }
        if (exprCount == 1) {
            return expressions.get(0).getInputSlots();
        }

        // inputSlots is cached by Expression.inputSlots, so summing the sizes
        // does not re-compute anything
        int expectedSlots = 0;
        for (Expression expr : expressions) {
            expectedSlots += expr.getInputSlots().size();
        }
        // size the builder up front so it does not need to grow while copying
        ImmutableSet.Builder<Slot> merged = ImmutableSet.builderWithExpectedSize(expectedSlots);
        for (Expression expr : expressions) {
            merged.addAll(expr.getInputSlots());
        }
        return merged.build();
    }

    /**
     * Check if slot is from the plan, i.e. whether any {@link LogicalCatalogRelation} under
     * {@code plan} lists the slot's expr id among its output expr ids.
     *
     * @param plan root of the plan to search; it is cast to {@link LogicalPlan}
     * @param slot the slot to look for
     * @return true when some catalog relation in the plan outputs the slot
     */
    public static boolean checkSlotFrom(Plan plan, SlotReference slot) {
        return PlanUtils.getLogicalScanFromRootPlan((LogicalPlan) plan).stream()
                .anyMatch(relation -> relation.getOutputExprIds().contains(slot.getExprId()));
    }

    /**
     * Check if the expression is a column reference: a {@link SlotReference} that carries both
     * an original column and an original table.
     */
    public static boolean isColumnRef(Expression expr) {
        if (!(expr instanceof SlotReference)) {
            return false;
        }
        SlotReference slot = (SlotReference) expr;
        return slot.getOriginalColumn().isPresent() && slot.getOriginalTable().isPresent();
    }

    /**
     * Collects aggregate functions that are not themselves windowed.
     * Aggregate functions that appear inside the window spec of a {@link WindowExpression}
     * are still collected; whatever else the window expression contains is not descended into.
     */
    public static class CollectNonWindowedAggFuncs {
        /** Collect from every expression in {@code expressions}, appending in iteration order. */
        public static List<AggregateFunction> collect(Collection<? extends Expression> expressions) {
            List<AggregateFunction> collected = new ArrayList<>();
            expressions.forEach(expression -> doCollect(expression, collected));
            return collected;
        }

        /** Collect from a single expression. */
        public static List<AggregateFunction> collect(Expression expression) {
            List<AggregateFunction> collected = new ArrayList<>();
            doCollect(expression, collected);
            return collected;
        }

        private static void doCollect(Expression expression, List<AggregateFunction> aggFunctions) {
            expression.foreach(node -> {
                if (node instanceof AggregateFunction) {
                    aggFunctions.add((AggregateFunction) node);
                    return true;
                }
                if (!(node instanceof WindowExpression)) {
                    return false;
                }
                // for a window expression only the expressions of its window spec are searched
                for (Expression specExpr : ((WindowExpression) node).getExpressionsInWindowSpec()) {
                    doCollect(specExpr, aggFunctions);
                }
                return true;
            });
        }
    }

    /** Mutable state shared across one run of {@link OutermostPlanFinder}. */
    public static class OutermostPlanFinderContext {
        /** the plan recorded by the finder; stays {@code null} until one has been recorded */
        public Plan outermostPlan;
        /** becomes true once {@link #outermostPlan} has been recorded */
        public boolean found;
    }

    /**
     * Visitor that records, in an {@link OutermostPlanFinderContext}, the first plan it reaches
     * that is neither a CTE anchor nor a CTE producer.
     */
    public static class OutermostPlanFinder extends
            DefaultPlanVisitor<Void, OutermostPlanFinderContext> {
        public static final OutermostPlanFinder INSTANCE = new OutermostPlanFinder();

        /** Record {@code plan} as the outermost plan unless one has been recorded already. */
        @Override
        public Void visit(Plan plan, OutermostPlanFinderContext ctx) {
            if (!ctx.found) {
                ctx.outermostPlan = plan;
                ctx.found = true;
            }
            return null;
        }

        /**
         * A CTE anchor is never recorded itself; while nothing has been found yet the call is
         * forwarded to the superclass' generic {@code visit}.
         * NOTE(review): presumably that default implementation walks the anchor's children — confirm
         * against DefaultPlanVisitor.
         */
        @Override
        public Void visitLogicalCTEAnchor(LogicalCTEAnchor<? extends Plan, ? extends Plan> cteAnchor,
                OutermostPlanFinderContext ctx) {
            return ctx.found ? null : super.visit(cteAnchor, ctx);
        }

        /** CTE producers are ignored entirely: nothing is recorded and nothing below is visited. */
        @Override
        public Void visitLogicalCTEProducer(LogicalCTEProducer<? extends Plan> cteProducer,
                OutermostPlanFinderContext ctx) {
            return null;
        }
    }

    /**
     * Translate a Nereids expression to a legacy {@link Expr}. Intended for expressions which do not
     * need complex expression handling or real table columns.
     *
     * <p>The expression is analyzed against an empty relation. Unbound slots are resolved by name
     * against {@code table}'s full schema, falling back to the max varchar type for unknown names
     * or when {@code table} is null.
     *
     * @param expression the Nereids expression to translate
     * @param table table whose columns give the types of unbound slots; may be null
     * @param ctx connect context whose statement context is used for both the relation id
     *            and the cascades context
     * @return the translated legacy expression
     */
    public static Expr translateToLegacyExpr(Expression expression, TableIf table, ConnectContext ctx) {
        // Take the relation id from the caller-supplied context rather than the thread-local
        // ConnectContext.get(): the thread-local may be unset (NPE) or may belong to another
        // statement than the one the CascadesContext below is created for.
        LogicalEmptyRelation plan = new LogicalEmptyRelation(
                ctx.getStatementContext().getNextRelationId(), new ArrayList<>());
        CascadesContext cascadesContext = CascadesContext.initContext(ctx.getStatementContext(), plan,
                PhysicalProperties.ANY);
        ExpressionAnalyzer analyzer = new CustomExpressionAnalyzer(table, cascadesContext);
        Expression analyzed = analyzer.analyze(expression);

        PlanTranslatorContext translatorContext = new PlanTranslatorContext(cascadesContext);
        ExpressionToExpr translator = new ExpressionToExpr();
        return analyzed.accept(translator, translatorContext);
    }

    /**
     * Expression analyzer that resolves unbound slots purely by column name against one table's
     * full schema instead of a plan scope.
     */
    private static class CustomExpressionAnalyzer extends ExpressionAnalyzer {
        // column name -> nereids data type; CaseInsensitiveMap makes the lookup ignore letter case
        private Map<String, DataType> columnTypes = new CaseInsensitiveMap();

        public CustomExpressionAnalyzer(TableIf table, CascadesContext cascadesContext) {
            super(null, new Scope(ImmutableList.of()), cascadesContext, false, false);
            if (table == null) {
                // no table: every unbound slot falls back to the default type
                return;
            }
            for (Column column : table.getFullSchema()) {
                String columnName = column.getName();
                DataType columnType = DataType.fromCatalogType(column.getType());
                columnTypes.put(columnName, columnType);
            }
        }

        /** Bind the slot to the type of the same-named column, or to max varchar when unknown. */
        @Override
        public Expression visitUnboundSlot(UnboundSlot unboundSlot, ExpressionRewriteContext context) {
            String slotName = unboundSlot.getName();
            DataType slotType = columnTypes.getOrDefault(slotName, VarcharType.MAX_VARCHAR_TYPE);
            return new SlotReference(slotName, slotType);
        }
    }

    /**
     * Expression translator that builds a {@link SlotRef} for a slot reference directly from the
     * slot's name, data type and nullability.
     */
    private static class ExpressionToExpr extends ExpressionTranslator {
        @Override
        public Expr visitSlotReference(SlotReference slotReference, PlanTranslatorContext context) {
            String slotName = slotReference.getName();
            SlotRef legacySlot = new SlotRef(
                    slotReference.getDataType().toCatalogDataType(), slotReference.nullable());
            // the slot name serves as both the label and the column name
            legacySlot.setLabel(slotName);
            legacySlot.setCol(slotName);
            legacySlot.setDisableTableName(true);
            return legacySlot;
        }
    }
}