SimplifyInPredicate.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.literal.DateTimeV2Literal;
import org.apache.doris.nereids.trees.expressions.literal.DateV2Literal;
import org.apache.doris.nereids.types.DateTimeV2Type;

import com.google.common.collect.ImmutableList;

import java.util.List;

/**
 * Removes a redundant {@link Cast} from the compare expression of an {@link InPredicate}
 * when every option is a {@link DateTimeV2Literal} that can be represented, without losing
 * information, in the type of the expression under the cast.
 *
 * <p>Two shapes are handled:
 * <ul>
 *   <li>{@code cast(dateV2Expr) IN (datetimeV2 literals)}: rewritten to
 *       {@code dateV2Expr IN (dateV2 literals)} when every literal has a zero time part;</li>
 *   <li>{@code cast(datetimeV2Expr) IN (datetimeV2 literals)}: rewritten to
 *       {@code datetimeV2Expr IN (literals re-created with the expression's type)} when no literal
 *       carries fractional-second digits beyond the expression's scale.</li>
 * </ul>
 * In every other case the predicate is returned unchanged.
 */
public class SimplifyInPredicate implements ExpressionPatternRuleFactory {
    public static final SimplifyInPredicate INSTANCE = new SimplifyInPredicate();

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesType(InPredicate.class).then(SimplifyInPredicate::simplify)
                        .toRule(ExpressionRuleType.SIMPLIFY_IN_PREDICATE)
        );
    }

    /**
     * Tries to strip the cast around the compare expression of {@code expr}.
     *
     * @param expr the IN predicate to simplify
     * @return a new predicate over the un-cast expression with converted literals, or {@code expr}
     *         itself when the rewrite does not apply
     */
    public static Expression simplify(InPredicate expr) {
        if (expr.children().size() <= 1 || !(expr.getCompareExpr() instanceof Cast)) {
            return expr;
        }
        // NOTE(review): the cast's own target type is not inspected here; if a narrowing cast
        // (e.g. to a lower datetimev2 scale) can reach this rule, removing it may not be
        // semantics-preserving. Confirm against the callers / type coercion.
        Expression source = ((Cast) expr.getCompareExpr()).child();
        List<Expression> options = expr.children().subList(1, expr.children().size());
        if (!options.stream().allMatch(DateTimeV2Literal.class::isInstance)) {
            return expr;
        }

        if (source.getDataType().isDateV2Type()) {
            boolean allDatesOnly = options.stream()
                    .allMatch(option -> canLosslessConvertToDateV2Literal((DateTimeV2Literal) option));
            if (!allDatesOnly) {
                return expr;
            }
            ImmutableList.Builder<Expression> children = ImmutableList.builder();
            children.add(source);
            for (Expression option : options) {
                children.add(convertToDateV2Literal((DateTimeV2Literal) option));
            }
            return expr.withChildren(children.build());
        }

        if (source.getDataType().isDateTimeV2Type()) {
            DateTimeV2Type sourceType = (DateTimeV2Type) source.getDataType();
            boolean allFitScale = options.stream()
                    .allMatch(option -> canLosslessConvertToLowScaleLiteral(
                            (DateTimeV2Literal) option, sourceType.getScale()));
            if (!allFitScale) {
                return expr;
            }
            ImmutableList.Builder<Expression> children = ImmutableList.builder();
            children.add(source);
            for (Expression option : options) {
                // Re-create the literal with the un-cast expression's type.
                children.add(new DateTimeV2Literal(sourceType, ((DateTimeV2Literal) option).getStringValue()));
            }
            return expr.withChildren(children.build());
        }
        return expr;
    }

    /*
    Inheritance tree of the date literals:
    DateLiteral
      |
      +--->DateTimeLiteral
      |        |
      |        +----->DateTimeV2Literal
      +--->DateV2Literal
    */

    /** Returns true when the literal's hour, minute, second and microsecond are all zero. */
    private static boolean canLosslessConvertToDateV2Literal(DateTimeV2Literal literal) {
        return (literal.getHour() | literal.getMinute() | literal.getSecond()
                | literal.getMicroSecond()) == 0L;
    }

    /** Builds a DateV2 literal from the year / month / day of the given literal. */
    private static DateV2Literal convertToDateV2Literal(DateTimeV2Literal literal) {
        return new DateV2Literal(literal.getYear(), literal.getMonth(), literal.getDay());
    }

    /**
     * Returns true when the literal's microsecond part has no non-zero digit beyond
     * {@code targetScale} fractional-second digits.
     *
     * <p>The scale counts decimal digits, so the check divides by a power of ten:
     * {@code 10^(MAX_SCALE - targetScale)}. A power of two (bit shift) would wrongly accept
     * values such as 123456 microseconds for scale 3.
     */
    private static boolean canLosslessConvertToLowScaleLiteral(DateTimeV2Literal literal, int targetScale) {
        long unit = 1L;
        for (int scale = targetScale; scale < DateTimeV2Type.MAX_SCALE; scale++) {
            unit *= 10L;
        }
        return literal.getMicroSecond() % unit == 0;
    }
}