InferPredicateByReplace.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.rewrite;

import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.analyzer.Scope;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.Like;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.ImmutableEqualSet;
import org.apache.doris.nereids.util.PredicateInferUtils;

import com.google.common.collect.ImmutableList;
import org.jetbrains.annotations.Nullable;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

/**ReplacePredicate*/
public class InferPredicateByReplace {
    /**
     * Collect {@code expr} plus every expression reached by repeatedly descending into a sole child
     * (e.g. {@code cast(a as int)} yields {@code [cast(a as int), a]}).
     * When the walk stops at an expression that does not have exactly one child and that expression
     * reads exactly one input slot, that slot is appended too (e.g. {@code a + 1} yields {@code [a + 1, a]}).
     * A bare slot is not appended a second time.
     */
    private static List<Expression> getAllSubExpressions(Expression expr) {
        List<Expression> subExpressions = new ArrayList<>();
        Expression current = expr;
        subExpressions.add(current);
        while (current.children().size() == 1) {
            current = current.child(0);
            subExpressions.add(current);
        }
        Set<Slot> slots = current.getInputSlots();
        if (slots.size() == 1) {
            Slot onlySlot = slots.iterator().next();
            // a bare slot reports itself as its only input slot; avoid listing it twice
            if (!onlySlot.equals(current)) {
                subExpressions.add(onlySlot);
            }
        }
        return subExpressions;
    }

    /**
     * Fills the map exprPredicates: each key is an expression and each value is the set of
     * supported predicates (comparison with literal, IN with literal list, NOT of those, LIKE with
     * literal pattern) in which that expression, or one of its sub-expressions, is the tested operand.
     */
    private static class PredicatesCollector extends ExpressionVisitor<Void, Map<Expression, Set<Expression>>> {
        public static final PredicatesCollector INSTANCE = new PredicatesCollector();

        @Override
        public Void visit(Expression expr, Map<Expression, Set<Expression>> context) {
            return null;
        }

        @Override
        public Void visitInPredicate(InPredicate inPredicate, Map<Expression, Set<Expression>> context) {
            if (validInPredicate(inPredicate)) {
                register(inPredicate.getCompareExpr(), inPredicate, context);
            }
            return null;
        }

        @Override
        public Void visitComparisonPredicate(ComparisonPredicate comparisonPredicate,
                Map<Expression, Set<Expression>> context) {
            // Assumes comparisons with the literal on the left (1 < a) were already normalized to a > 1,
            // so only the left child is treated as the tested operand.
            if (validComparisonPredicate(comparisonPredicate)) {
                register(comparisonPredicate.child(0), comparisonPredicate, context);
            }
            return null;
        }

        @Override
        public Void visitNot(Not not, Map<Expression, Set<Expression>> context) {
            Expression negated = not.child(0);
            boolean supported = negated instanceof InPredicate && validInPredicate((InPredicate) negated)
                    || negated instanceof ComparisonPredicate
                    && validComparisonPredicate((ComparisonPredicate) negated);
            if (supported) {
                register(negated.child(0), not, context);
            }
            return null;
        }

        @Override
        public Void visitLike(Like like, Map<Expression, Set<Expression>> context) {
            if (like.child(1) instanceof Literal) {
                register(like.child(0), like, context);
            }
            return null;
        }

        /** Map {@code operand} and each of its sub-expressions to {@code predicate}. */
        private static void register(Expression operand, Expression predicate,
                Map<Expression, Set<Expression>> context) {
            for (Expression expr : getAllSubExpressions(operand)) {
                context.computeIfAbsent(expr, k -> new LinkedHashSet<>()).add(predicate);
            }
        }

        private boolean validComparisonPredicate(ComparisonPredicate comparisonPredicate) {
            return comparisonPredicate.right() instanceof Literal;
        }

        private boolean validInPredicate(InPredicate inPredicate) {
            return inPredicate.isLiteralChildren();
        }
    }

    /**
     * Derive the predicates that hold for {@code replaceToThis} by rewriting the predicates of each
     * member of its equivalence set.
     * For example, with replaceToThis = b, equalSet = {a, b} and exprPredicates = {a: [a < 10, a > 1]},
     * the result is [b < 10, b > 1].
     * Each rewritten predicate is re-analyzed with {@link ReplaceAnalyzer}. Predicates whose analysis
     * throws are dropped. Every returned predicate is marked as inferred.
     *
     * @param replaceToThis the expression to substitute into the predicates
     * @param equalSet the expressions known to be equal to {@code replaceToThis}
     * @param exprPredicates expression -> predicates on that expression, as built by {@link PredicatesCollector}
     * @return the inferred predicates, in insertion order
     */
    private static <T extends Expression> Set<Expression> getEqualSetAndDoReplace(T replaceToThis, Set<T> equalSet,
            Map<? extends Expression, Set<Expression>> exprPredicates) {
        ExpressionAnalyzer analyzer = new ReplaceAnalyzer(null, new Scope(ImmutableList.of()), null, false, false);
        Set<Expression> res = new LinkedHashSet<>();
        for (T equals : equalSet) {
            Set<Expression> predicates = exprPredicates.get(equals);
            if (predicates == null) {
                continue;
            }
            Map<Expression, Expression> replaceMap = new HashMap<>();
            replaceMap.put(equals, replaceToThis);
            for (Expression predicate : predicates) {
                Expression replaced = ExpressionUtils.replace(predicate, replaceMap);
                try {
                    res.add(analyzer.analyze(replaced).withInferred(true));
                } catch (Exception ignored) {
                    // Best effort: if re-analysis fails (e.g. ReplaceAnalyzer rejects a narrowing decimal
                    // cast), skip this single inference instead of failing the whole rewrite.
                }
            }
        }
        return res;
    }

    /**
     * Build the equivalence sets implied by the {@code EqualTo} predicates in {@code inputs}.
     * Casts are stripped via {@link PredicateInferUtils#getPairFromCast}, so
     * {@code cast(d_tinyint as int) = d_int} contributes the pair {@code d_tinyint = d_int}.
     * Pairs are only kept when both sides are slots or literals, neither side is a NULL literal,
     * and at least one side of the original equality references a slot.
     * Because the result is an equivalence set, a = b and b = a are not recorded twice.
     */
    private static ImmutableEqualSet<Expression> findEqual(Set<Expression> inputs) {
        ImmutableEqualSet.Builder<Expression> fromCastEqualSetBuilder = new ImmutableEqualSet.Builder<>();
        for (Expression input : inputs) {
            if (!(input instanceof EqualTo)) {
                continue;
            }
            EqualTo equalTo = (EqualTo) input;
            if (equalTo.left().getInputSlots().isEmpty() && equalTo.right().getInputSlots().isEmpty()) {
                continue;
            }
            PredicateInferUtils.getPairFromCast(equalTo)
                    .filter(pair -> PredicateInferUtils.isSlotOrLiteral(pair.first)
                            && PredicateInferUtils.isSlotOrLiteral(pair.second))
                    .filter(pair -> !(pair.first instanceof NullLiteral) && !(pair.second instanceof NullLiteral))
                    .ifPresent(pair -> fromCastEqualSetBuilder.addEqualPair(pair.first, pair.second));
        }
        return fromCastEqualSetBuilder.build();
    }

    /**
     * This is the exposed interface. It derives new predicates by substituting equivalent expressions.
     * Only deterministic predicates that reference exactly one input slot are used as templates.
     *
     * @param inputs the input predicates for derivation
     * @return a new insertion-ordered set containing all {@code inputs} followed by the derived predicates
     */
    public static Set<Expression> infer(Set<Expression> inputs) {
        ImmutableEqualSet<Expression> hasCastEqualSet = findEqual(inputs);
        Set<Expression> targetExprs = hasCastEqualSet.getAllItemSet();
        if (targetExprs.isEmpty()) {
            return new LinkedHashSet<>(inputs);
        }
        Map<Expression, Set<Expression>> exprPredicates = new HashMap<>();
        for (Expression input : inputs) {
            if (input.anyMatch(expr -> !((ExpressionTrait) expr).isDeterministic())
                    || input.getInputSlots().size() != 1) {
                continue;
            }
            input.accept(PredicatesCollector.INSTANCE, exprPredicates);
        }
        Set<Expression> inferPredicates = new LinkedHashSet<>(inputs);
        if (exprPredicates.isEmpty()) {
            return inferPredicates;
        }
        for (Expression expr : targetExprs) {
            // literals can be part of an equivalence set but never receive predicates
            if (expr instanceof Literal) {
                continue;
            }
            inferPredicates.addAll(
                    getEqualSetAndDoReplace(expr, hasCastEqualSet.calEqualSet(expr), exprPredicates));
        }
        return inferPredicates;
    }

    /**
     * ReplaceAnalyzer performs type conversion and type checking on the expression after replacement.
     * For decimal casts, it throws when the source type has more integral digits
     * (precision - scale) or a larger scale than the target type.
     */
    private static class ReplaceAnalyzer extends ExpressionAnalyzer {
        private ReplaceAnalyzer(Plan currentPlan, Scope scope,
                @Nullable CascadesContext cascadesContext,
                boolean enableExactMatch, boolean bindSlotInOuterScope) {
            super(currentPlan, scope, cascadesContext, enableExactMatch, bindSlotInOuterScope);
        }

        @Override
        public Expression visitCast(Cast cast, ExpressionRewriteContext context) {
            cast = (Cast) super.visitCast(cast, context);
            if (cast.getDataType().isDecimalV3Type()) {
                DecimalV3Type targetType = (DecimalV3Type) cast.getDataType();
                DecimalV3Type childType = DecimalV3Type.forType(cast.child().getDataType());
                checkNotNarrowing(cast, childType.getPrecision(), childType.getScale(),
                        targetType.getPrecision(), targetType.getScale());
            } else if (cast.getDataType().isDecimalV2Type()) {
                DecimalV2Type targetType = (DecimalV2Type) cast.getDataType();
                DecimalV2Type childType = DecimalV2Type.forType(cast.child().getDataType());
                checkNotNarrowing(cast, childType.getPrecision(), childType.getScale(),
                        targetType.getPrecision(), targetType.getScale());
            }
            return cast;
        }

        /**
         * Throw if casting from (childPrecision, childScale) to (targetPrecision, targetScale) would
         * shrink either the integral digit count or the scale.
         */
        private static void checkNotNarrowing(Cast cast, int childPrecision, int childScale,
                int targetPrecision, int targetScale) {
            boolean integralDigitsShrink = (childPrecision - childScale) > (targetPrecision - targetScale);
            boolean scaleShrinks = childScale > targetScale;
            if (integralDigitsShrink || scaleShrinks) {
                throw new AnalysisException("can not cast from origin type " + cast.child().getDataType()
                        + " to target type=" + cast.getDataType());
            }
        }
    }
}