SimplifyArithmeticComparisonRule.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Divide;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.Multiply;
import org.apache.doris.nereids.trees.expressions.Subtract;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.DaysSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.HoursSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.MinutesSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.SecondsSub;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksAdd;
import org.apache.doris.nereids.trees.expressions.functions.scalar.WeeksSub;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.TypeCoercionUtils;

import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;

/**
 * Simplify arithmetic comparison rule.
 * a + 1 > 1 => a > 0
 * a / -2 > 1 => a < -2
 */
public class SimplifyArithmeticComparisonRule implements ExpressionPatternRuleFactory {
    public static SimplifyArithmeticComparisonRule INSTANCE = new SimplifyArithmeticComparisonRule();

    // don't rearrange multiplication because divide may loss precision
    private static final Map<Class<? extends Expression>, Class<? extends Expression>> REARRANGEMENT_MAP = ImmutableMap
            .<Class<? extends Expression>, Class<? extends Expression>>builder()
            .put(Add.class, Subtract.class)
            .put(Subtract.class, Add.class)
            .put(Divide.class, Multiply.class)
            // ATTN: YearsAdd, MonthsAdd can not reverse
            //       for example, months_add(date '2024-01-31', 1) = date '2024-02-29' can not reverse to
            //       date '2024-01-31' = months_sub(date '2024-02-29', 1)
            .put(WeeksSub.class, WeeksAdd.class)
            .put(WeeksAdd.class, WeeksSub.class)
            .put(DaysSub.class, DaysAdd.class)
            .put(DaysAdd.class, DaysSub.class)
            .put(HoursSub.class, HoursAdd.class)
            .put(HoursAdd.class, HoursSub.class)
            .put(MinutesSub.class, MinutesAdd.class)
            .put(MinutesAdd.class, MinutesSub.class)
            .put(SecondsSub.class, SecondsAdd.class)
            .put(SecondsAdd.class, SecondsSub.class)
            .build();

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesType(ComparisonPredicate.class)
                        .thenApply(ctx -> simplify(ctx.expr, new ExpressionRewriteContext(ctx.cascadesContext)))
                        .toRule(ExpressionRuleType.SIMPLIFY_ARITHMETIC_COMPARISON)
        );
    }

    /** simplify */
    public static Expression simplify(ComparisonPredicate comparison, ExpressionRewriteContext context) {
        if (!couldRearrange(comparison)) {
            return comparison;
        }
        ComparisonPredicate newComparison = normalize(comparison);
        if (newComparison == null) {
            return comparison;
        }
        try {
            List<Expression> children = tryRearrangeChildren(newComparison.left(), newComparison.right(), context);
            newComparison = (ComparisonPredicate) simplify(
                    (ComparisonPredicate) newComparison.withChildren(children), context);
            return TypeCoercionUtils.processComparisonPredicate(newComparison);
        } catch (Exception e) {
            return comparison;
        }
    }

    private static boolean couldRearrange(ComparisonPredicate cmp) {
        if (!REARRANGEMENT_MAP.containsKey(cmp.left().getClass()) || cmp.left().isConstant()) {
            return false;
        }

        for (Expression child : cmp.left().children()) {
            if (child.isConstant()) {
                return true;
            }
        }
        return false;
    }

    private static List<Expression> tryRearrangeChildren(Expression left, Expression right,
            ExpressionRewriteContext context) throws Exception {
        if (!left.child(1).isConstant()) {
            throw new RuntimeException(String.format("Expected literal when arranging children for Expr %s", left));
        }
        ComparableLiteral leftLiteral = (ComparableLiteral) FoldConstantRule.evaluate(left.child(1), context);
        Expression leftExpr = left.child(0);

        Class<? extends Expression> oppositeOperator = REARRANGEMENT_MAP.get(left.getClass());
        Expression newChild = oppositeOperator.getConstructor(Expression.class, Expression.class)
                .newInstance(right, leftLiteral);

        if (left instanceof Divide && leftLiteral.compareTo(new IntegerLiteral(0)) < 0) {
            // Multiplying by a negative number will change the operator.
            return Arrays.asList(newChild, leftExpr);
        }
        return Arrays.asList(leftExpr, newChild);
    }

    // Ensure that the second child must be Literal, such as
    private static @Nullable ComparisonPredicate normalize(ComparisonPredicate comparison) {
        Expression left = comparison.left();
        Expression leftRight = left.child(1);
        if (leftRight instanceof Literal) {
            return comparison;
        }
        if (left instanceof Add) {
            // 1 + a > 1 => a + 1 > 1
            Expression newLeft = left.withChildren(leftRight, left.child(0));
            return (ComparisonPredicate) comparison.withChildren(newLeft, comparison.right());
        } else if (left instanceof Subtract) {
            // 1 - a > 1 => a + 1 < 1
            Expression newLeft = left.child(0);
            Expression newRight = new Add(leftRight, comparison.right());
            comparison = (ComparisonPredicate) comparison.withChildren(newLeft, newRight);
            return comparison.commute();
        } else {
            // Don't normalize division/multiplication because the slot sign is undecided.
            return null;
        }
    }
}