SimplifyDecimalV3Comparison.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.trees.expressions.Cast;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.literal.DecimalV3Literal;
import org.apache.doris.nereids.types.DecimalV3Type;
import com.google.common.collect.ImmutableList;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.List;
/**
* if we have a column with decimalv3 type and set enable_decimal_conversion = false.
* we have a column named col1 with type decimalv3(15, 2)
* and we have a comparison like col1 > 0.5 + 0.1
* then the result type of 0.5 + 0.1 is decimalv2(27, 9)
* and the col1 need to convert to decimalv3(27, 9) to match the precision of right hand
* this rule simplify it from cast(col1 as decimalv3(27, 9)) > 0.6 to col1 > 0.6
*/
public class SimplifyDecimalV3Comparison implements ExpressionPatternRuleFactory {
public static SimplifyDecimalV3Comparison INSTANCE = new SimplifyDecimalV3Comparison();
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesType(ComparisonPredicate.class).then(SimplifyDecimalV3Comparison::simplify)
);
}
/** simplify */
public static Expression simplify(ComparisonPredicate cp) {
Expression left = cp.left();
Expression right = cp.right();
if (left.getDataType() instanceof DecimalV3Type
&& left instanceof Cast
&& ((Cast) left).child().getDataType() instanceof DecimalV3Type
&& ((DecimalV3Type) left.getDataType()).getScale()
>= ((DecimalV3Type) ((Cast) left).child().getDataType()).getScale()
&& right instanceof DecimalV3Literal) {
try {
return doProcess(cp, (Cast) left, (DecimalV3Literal) right);
} catch (ArithmeticException e) {
return cp;
}
}
return cp;
}
private static Expression doProcess(ComparisonPredicate cp, Cast left, DecimalV3Literal right) {
BigDecimal trailingZerosValue = right.getValue().stripTrailingZeros();
int scale = org.apache.doris.analysis.DecimalLiteral.getBigDecimalScale(trailingZerosValue);
int precision = org.apache.doris.analysis.DecimalLiteral.getBigDecimalPrecision(trailingZerosValue);
try {
trailingZerosValue = trailingZerosValue.setScale(scale, RoundingMode.UNNECESSARY);
} catch (ArithmeticException e) {
return cp;
}
Expression castChild = left.child();
if (!(castChild.getDataType() instanceof DecimalV3Type)) {
throw new AnalysisException("cast child's type should be DecimalV3Type, but its type is "
+ castChild.getDataType().toSql());
}
DecimalV3Type leftType = (DecimalV3Type) castChild.getDataType();
if (scale <= leftType.getScale() && precision - scale <= leftType.getRange()) {
// precision and scale of literal all smaller than left, we don't need the cast
DecimalV3Literal newRight = new DecimalV3Literal(
DecimalV3Type.createDecimalV3TypeLooseCheck(leftType.getPrecision(), leftType.getScale()),
trailingZerosValue.setScale(leftType.getScale(), RoundingMode.UNNECESSARY));
return cp.withChildren(castChild, newRight);
} else {
return cp;
}
}
}