PullUpPredicates.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.rules.FoldConstantRuleOnFE;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalExcept;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalGenerate;
import org.apache.doris.nereids.trees.plans.logical.LogicalIntersect;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
import org.apache.doris.nereids.trees.plans.logical.LogicalSort;
import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.trees.plans.logical.LogicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PredicateInferUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.base.Suppliers;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.IdentityHashMap;
import java.util.LinkedHashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.Supplier;

/**
 * poll up effective predicates from operator's children.
 */
public class PullUpPredicates extends PlanVisitor<ImmutableSet<Expression>, Void> {

    private static final ImmutableSet<Class<? extends Expression>> supportAggFunctions = ImmutableSet.of(
            Max.class, Min.class, AnyValue.class);
    Map<Plan, ImmutableSet<Expression>> cache = new IdentityHashMap<>();
    private final boolean getAllPredicates;
    private final ExpressionRewriteContext rewriteContext;

    public PullUpPredicates(boolean all, CascadesContext cascadesContext) {
        getAllPredicates = all;
        rewriteContext = new ExpressionRewriteContext(cascadesContext);
    }

    @Override
    public ImmutableSet<Expression> visit(Plan plan, Void context) {
        return ImmutableSet.of();
    }

    @Override
    public ImmutableSet<Expression> visitLogicalSort(LogicalSort<? extends Plan> sort, Void context) {
        return cacheOrElse(sort, () -> sort.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalLimit(LogicalLimit<? extends Plan> limit, Void context) {
        return cacheOrElse(limit, () -> limit.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalTopN(LogicalTopN<? extends Plan> topN, Void context) {
        return cacheOrElse(topN, () -> topN.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalPartitionTopN(LogicalPartitionTopN<? extends Plan> topN, Void context) {
        return cacheOrElse(topN, () -> topN.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalGenerate(LogicalGenerate<? extends Plan> generate, Void context) {
        return cacheOrElse(generate, () -> generate.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalWindow(LogicalWindow<? extends Plan> window, Void context) {
        return cacheOrElse(window, () -> window.child(0).accept(this, context));
    }

    @Override
    public ImmutableSet<Expression> visitLogicalRepeat(LogicalRepeat<? extends Plan> repeat, Void context) {
        return cacheOrElse(repeat, () -> {
            ImmutableSet<Expression> childPredicates = repeat.child().accept(this, context);
            Set<Expression> commonGroupingSetExpressions = repeat.getCommonGroupingSetExpressions();
            if (commonGroupingSetExpressions.isEmpty()) {
                return ImmutableSet.of();
            }

            Set<Expression> pulledPredicates = new LinkedHashSet<>();
            for (Expression conjunct : childPredicates) {
                Set<Slot> conjunctSlots = conjunct.getInputSlots();
                if (commonGroupingSetExpressions.containsAll(conjunctSlots)) {
                    pulledPredicates.add(conjunct);
                }
            }
            return ImmutableSet.copyOf(pulledPredicates);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalOneRowRelation(LogicalOneRowRelation r, Void context) {
        return cacheOrElse(r, () -> {
            Set<Expression> predicates = new LinkedHashSet<>();
            for (NamedExpression expr : r.getProjects()) {
                if (expr instanceof Alias && expr.child(0) instanceof Literal) {
                    predicates.add(generateEqual(expr));
                }
            }
            return ImmutableSet.copyOf(predicates);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalIntersect(LogicalIntersect intersect, Void context) {
        return cacheOrElse(intersect, () -> {
            Set<Expression> predicates = new LinkedHashSet<>();
            for (int i = 0; i < intersect.children().size(); ++i) {
                Plan child = intersect.child(i);
                Set<Expression> childFilters = child.accept(this, context);
                if (childFilters.isEmpty()) {
                    continue;
                }
                Map<Expression, Expression> replaceMap = new HashMap<>();
                for (int j = 0; j < intersect.getOutput().size(); ++j) {
                    NamedExpression output = intersect.getOutput().get(j);
                    replaceMap.put(intersect.getRegularChildOutput(i).get(j), output);
                }
                predicates.addAll(ExpressionUtils.replace(childFilters, replaceMap));
            }
            return getAvailableExpressions(ImmutableSet.copyOf(predicates), intersect);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalExcept(LogicalExcept except, Void context) {
        return cacheOrElse(except, () -> {
            if (except.arity() < 1) {
                return ImmutableSet.of();
            }
            Set<Expression> firstChildFilters = except.child(0).accept(this, context);
            if (firstChildFilters.isEmpty()) {
                return ImmutableSet.of();
            }
            Map<Expression, Expression> replaceMap = new HashMap<>();
            for (int i = 0; i < except.getOutput().size(); ++i) {
                NamedExpression output = except.getOutput().get(i);
                replaceMap.put(except.getRegularChildOutput(0).get(i), output);
            }
            return ImmutableSet.copyOf(ExpressionUtils.replace(firstChildFilters, replaceMap));
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalUnion(LogicalUnion union, Void context) {
        return cacheOrElse(union, () -> {
            if (!union.getConstantExprsList().isEmpty() && union.arity() == 0) {
                return getFiltersFromUnionConstExprs(union);
            } else if (union.getConstantExprsList().isEmpty() && union.arity() != 0) {
                return getFiltersFromUnionChild(union, context);
            } else if (!union.getConstantExprsList().isEmpty() && union.arity() != 0) {
                Set<Expression> fromChildFilters = new LinkedHashSet<>(getFiltersFromUnionChild(union, context));
                if (fromChildFilters.isEmpty()) {
                    return ImmutableSet.of();
                }
                if (!ExpressionUtils.unionConstExprsSatisfyConjuncts(union, fromChildFilters)) {
                    return ImmutableSet.of();
                }
                return ImmutableSet.copyOf(fromChildFilters);
            }
            return ImmutableSet.of();
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalFilter(LogicalFilter<? extends Plan> filter, Void context) {
        return cacheOrElse(filter, () -> {
            Set<Expression> predicates = Sets.newLinkedHashSet(filter.getConjuncts());
            predicates.addAll(filter.child().accept(this, context));
            return getAvailableExpressions(predicates, filter);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalJoin(LogicalJoin<? extends Plan, ? extends Plan> join, Void context) {
        return cacheOrElse(join, () -> {
            Set<Expression> predicates = new LinkedHashSet<>();
            Supplier<ImmutableSet<Expression>> leftPredicates = Suppliers.memoize(
                    () -> join.left().accept(this, context));
            Supplier<ImmutableSet<Expression>> rightPredicates = Suppliers.memoize(
                    () -> join.right().accept(this, context));
            switch (join.getJoinType()) {
                case CROSS_JOIN:
                case INNER_JOIN: {
                    predicates.addAll(leftPredicates.get());
                    predicates.addAll(rightPredicates.get());
                    predicates.addAll(join.getHashJoinConjuncts());
                    predicates.addAll(join.getOtherJoinConjuncts());
                    break;
                }
                case LEFT_OUTER_JOIN:
                case LEFT_SEMI_JOIN:
                case LEFT_ANTI_JOIN:
                case NULL_AWARE_LEFT_ANTI_JOIN: {
                    predicates.addAll(leftPredicates.get());
                    break;
                }
                case RIGHT_OUTER_JOIN:
                case RIGHT_SEMI_JOIN:
                case RIGHT_ANTI_JOIN: {
                    predicates.addAll(rightPredicates.get());
                    break;
                }
                default:
                    break;
            }
            return getAvailableExpressions(predicates, join);
        });
    }

    @Override
    public ImmutableSet<Expression> visitLogicalProject(LogicalProject<? extends Plan> project, Void context) {
        return cacheOrElse(project, () -> {
            ImmutableSet<Expression> childPredicates = project.child().accept(this, context);
            Set<Expression> allPredicates = Sets.newLinkedHashSet();
            /* this generateMap is used to
            * e.g LogicalProject(t.a)   the qualifier t may come from LogicalSubQueryAlias
            *       +--LogicalFilter(a>1)
            * use generateMap to make sure a>1 is pulled up and turn into t.a>1
            * */
            for (Entry<Slot, Expression> kv : generateMap(project.getProjects()).entrySet()) {
                Slot k = kv.getKey();
                Expression v = kv.getValue();
                for (Expression childPredicate : childPredicates) {
                    allPredicates.add(childPredicate.rewriteDownShortCircuit(c -> c.equals(v) ? k : c));
                }
            }
            for (NamedExpression expr : project.getProjects()) {
                if (expr instanceof Alias && expr.child(0) instanceof Literal) {
                    allPredicates.add(generateEqual(expr));
                }
            }
            return getAvailableExpressions(allPredicates, project);
        });
    }

    /* e.g. LogicalAggregate(output:max(a), min(a), avg(a))
              +--LogicalFilter(a>1, a<10)
       when a>1 is pulled up, we can have max(a)>1, min(a)>1 and avg(a)>1
       and a<10 is pulled up, we can have max(a)<10, min(a)<10 and avg(a)<10
    * */
    @Override
    public ImmutableSet<Expression> visitLogicalAggregate(LogicalAggregate<? extends Plan> aggregate, Void context) {
        return cacheOrElse(aggregate, () -> {
            ImmutableSet<Expression> childPredicates = aggregate.child().accept(this, context);
            List<NamedExpression> outputExpressions = aggregate.getOutputExpressions();
            Map<Expression, List<Slot>> expressionSlotMap
                    = Maps.newHashMapWithExpectedSize(outputExpressions.size());
            for (NamedExpression output : outputExpressions) {
                if (output instanceof Alias && supportPullUpAgg(output.child(0))) {
                    expressionSlotMap.computeIfAbsent(output.child(0).child(0),
                            k -> new ArrayList<>()).add(output.toSlot());
                }
            }
            Set<Expression> pullPredicates = new LinkedHashSet<>(childPredicates);
            for (Expression childPredicate : childPredicates) {
                if (childPredicate instanceof ComparisonPredicate) {
                    ComparisonPredicate cmp = (ComparisonPredicate) childPredicate;
                    if (cmp.left() instanceof SlotReference && cmp.right() instanceof Literal
                            && expressionSlotMap.containsKey(cmp.left())) {
                        for (Slot slot : expressionSlotMap.get(cmp.left())) {
                            Expression genPredicates = TypeCoercionUtils.processComparisonPredicate(
                                     (ComparisonPredicate) cmp.withChildren(slot, cmp.right()));
                            genPredicates = FoldConstantRuleOnFE.evaluate(genPredicates, rewriteContext);
                            pullPredicates.add(genPredicates);
                        }
                    }
                }
            }
            return getAvailableExpressions(pullPredicates, aggregate);
        });
    }

    private ImmutableSet<Expression> cacheOrElse(Plan plan, Supplier<ImmutableSet<Expression>> predicatesSupplier) {
        ImmutableSet<Expression> predicates = cache.get(plan);
        if (predicates != null) {
            return predicates;
        }
        predicates = predicatesSupplier.get();
        cache.put(plan, predicates);
        return predicates;
    }

    private ImmutableSet<Expression> getAvailableExpressions(Set<Expression> predicates, Plan plan) {
        if (predicates.isEmpty()) {
            return ImmutableSet.of();
        }
        Set<Expression> inferPredicates = new LinkedHashSet<>();
        if (getAllPredicates) {
            inferPredicates.addAll(PredicateInferUtils.inferAllPredicate(predicates));
        } else {
            inferPredicates.addAll(PredicateInferUtils.inferPredicate(predicates));
        }
        Set<Expression> newPredicates = new LinkedHashSet<>(inferPredicates.size());
        Set<Slot> outputSet = plan.getOutputSet();

        for (Expression inferPredicate : inferPredicates) {
            if (outputSet.containsAll(inferPredicate.getInputSlots())) {
                newPredicates.add(inferPredicate);
            }
        }
        return ImmutableSet.copyOf(newPredicates);
    }

    private boolean supportPullUpAgg(Expression expr) {
        return supportAggFunctions.contains(expr.getClass());
    }

    private ImmutableSet<Expression> getFiltersFromUnionChild(LogicalUnion union, Void context) {
        Set<Expression> filters = new LinkedHashSet<>();
        for (int i = 0; i < union.getArity(); ++i) {
            Plan child = union.child(i);
            Set<Expression> childFilters = child.accept(this, context);
            if (childFilters.isEmpty()) {
                return ImmutableSet.of();
            }
            Map<Expression, Expression> replaceMap = new HashMap<>();
            for (int j = 0; j < union.getOutput().size(); ++j) {
                NamedExpression output = union.getOutput().get(j);
                replaceMap.put(union.getRegularChildOutput(i).get(j), output);
            }
            Set<Expression> unionFilters = ExpressionUtils.replace(childFilters, replaceMap);
            if (0 == i) {
                filters.addAll(unionFilters);
            } else {
                filters.retainAll(unionFilters);
            }
            if (filters.isEmpty()) {
                return ImmutableSet.of();
            }
        }
        return ImmutableSet.copyOf(filters);
    }

    private ImmutableSet<Expression> getFiltersFromUnionConstExprs(LogicalUnion union) {
        List<List<NamedExpression>> constExprs = union.getConstantExprsList();
        Set<Expression> filtersFromConstExprs = new LinkedHashSet<>();
        for (int col = 0; col < union.getOutput().size(); ++col) {
            Expression compareExpr = union.getOutput().get(col);
            Set<Expression> options = new LinkedHashSet<>();
            for (List<NamedExpression> constExpr : constExprs) {
                if (constExpr.get(col) instanceof Alias
                        && ((Alias) constExpr.get(col)).child() instanceof Literal) {
                    options.add(((Alias) constExpr.get(col)).child());
                } else {
                    options.clear();
                    break;
                }
            }
            options.removeIf(option -> option instanceof NullLiteral);
            if (options.size() > 1) {
                filtersFromConstExprs.add(new InPredicate(compareExpr, options));
            } else if (options.size() == 1) {
                filtersFromConstExprs.add(new EqualTo(compareExpr, options.iterator().next()));
            }
        }
        return ImmutableSet.copyOf(filtersFromConstExprs);
    }

    private Expression generateEqual(NamedExpression expr) {
        // IsNull have better performance and compatibility than NullSafeEqualTo
        if (expr.child(0) instanceof NullLiteral) {
            return new IsNull(expr.toSlot());
        } else {
            return new EqualTo(expr.toSlot(), expr.child(0));
        }
    }

    private Map<Slot, Expression> generateMap(List<NamedExpression> namedExpressions) {
        Map<Slot, Expression> replaceMap = new LinkedHashMap<>(namedExpressions.size());
        for (NamedExpression namedExpression : namedExpressions) {
            if (namedExpression instanceof Alias) {
                replaceMap.putIfAbsent(namedExpression.toSlot(), namedExpression.child(0));
            } else if (namedExpression instanceof SlotReference) {
                replaceMap.putIfAbsent((Slot) namedExpression, namedExpression);
            }
        }
        return replaceMap;
    }
}