BetweenToEqual.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.List;
import java.util.Map;
/**
 * Rewrites a pair of inclusive bounds that share the same literal into an equality:
 * {@code f(A, B) >= 1 AND f(A, B) <= 1  =>  f(A, B) = 1}
 * (i.e. {@code f(A, B) BETWEEN 1 AND 1  =>  f(A, B) = 1}).
 *
 * <p>Bounds may have the literal on either side; {@code 1 <= f(A, B)} is treated as
 * {@code f(A, B) >= 1}. The rewrite fires for an expression only when, among the conjuncts of an
 * {@code AND}, it has exactly two such bounds, one {@code >=} and one {@code <=}, and both use
 * equal literals.
 */
public class BetweenToEqual implements ExpressionPatternRuleFactory {
    public static final BetweenToEqual INSTANCE = new BetweenToEqual();

    @Override
    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
        return ImmutableList.of(
                matchesType(And.class).then(BetweenToEqual::rewriteBetweenToEqual)
                        .toRule(ExpressionRuleType.BETWEEN_TO_EQUAL)
        );
    }

    /**
     * Folds every {@code expr >= c AND expr <= c} pair found in the conjuncts of {@code and} into
     * {@code expr = c}.
     *
     * <p>NOTE(review): this assumes that two occurrences of an expression that compare equal
     * evaluate to the same value (deterministic); confirm nondeterministic functions cannot reach
     * this rule.
     *
     * @param and the conjunction to inspect
     * @return the new conjunction with the generated equalities first, followed by every conjunct
     *         that was not folded (in its original order and form); or {@code null} when no pair
     *         was folded
     */
    private static Expression rewriteBetweenToEqual(And and) {
        List<Expression> conjuncts = ExpressionUtils.extractConjunction(and);

        // Group the normalized bounds by their non-literal side. A LinkedHashMap ties the output
        // order to the input order rather than to hash codes, keeping rewritten plans stable.
        Map<Expression, List<ComparisonPredicate>> boundsByExpr = Maps.newLinkedHashMap();
        for (Expression conjunct : conjuncts) {
            if (isCandidate(conjunct)) {
                ComparisonPredicate bound = normalizeCandidate((ComparisonPredicate) conjunct);
                boundsByExpr.computeIfAbsent(bound.child(0), k -> Lists.newArrayList()).add(bound);
            }
        }

        // Collapse "expr >= c AND expr <= c" into "expr = c". Any other shape (a single bound,
        // more than two bounds, two bounds of the same direction, different literals) is left alone.
        Map<Expression, EqualTo> equalsByExpr = Maps.newLinkedHashMap();
        for (Map.Entry<Expression, List<ComparisonPredicate>> entry : boundsByExpr.entrySet()) {
            List<ComparisonPredicate> bounds = entry.getValue();
            if (bounds.size() == 2
                    && greaterEqualAndLessEqual(bounds.get(0), bounds.get(1))
                    && bounds.get(0).child(1).equals(bounds.get(1).child(1))) {
                equalsByExpr.put(entry.getKey(), new EqualTo(entry.getKey(), bounds.get(0).child(1)));
            }
        }
        if (equalsByExpr.isEmpty()) {
            return null;
        }

        // Emit the new equalities, then every conjunct that was not folded into one. Unfolded
        // conjuncts keep their original form; only the grouping key is computed from the
        // normalized form.
        List<Expression> newConjuncts = Lists.newArrayList(equalsByExpr.values());
        for (Expression conjunct : conjuncts) {
            boolean folded = isCandidate(conjunct)
                    && equalsByExpr.containsKey(normalizeCandidate((ComparisonPredicate) conjunct).child(0));
            if (!folded) {
                newConjuncts.add(conjunct);
            }
        }
        return ExpressionUtils.and(newConjuncts);
    }

    /**
     * Returns whether {@code expr} is a {@code >=} or {@code <=} comparison with a literal on
     * exactly one side, e.g. {@code A >= a}, {@code A <= a}, {@code a >= A} or {@code a <= A},
     * where {@code A} is a non-literal expression and {@code a} is a literal.
     */
    private static boolean isCandidate(Expression expr) {
        if (expr instanceof GreaterThanEqual || expr instanceof LessThanEqual) {
            return (expr.child(0) instanceof Literal) != (expr.child(1) instanceof Literal);
        }
        return false;
    }

    /**
     * Returns an equivalent comparison whose left child is the non-literal side.
     *
     * <p>When the literal is on the left, the operands are swapped AND the operator is mirrored
     * ({@code 1 >= A} becomes {@code A <= 1}). Swapping the operands alone would invert the
     * predicate. Callers only pass candidates (see {@link #isCandidate}), so {@code expr} is
     * either a {@code >=} or a {@code <=}.
     */
    private static ComparisonPredicate normalizeCandidate(ComparisonPredicate expr) {
        if (expr.child(1) instanceof Literal) {
            return expr;
        }
        if (expr instanceof GreaterThanEqual) {
            return new LessThanEqual(expr.child(1), expr.child(0));
        }
        // Only LessThanEqual remains for a candidate.
        return new GreaterThanEqual(expr.child(1), expr.child(0));
    }

    /** Returns whether the two comparisons are one {@code >=} and one {@code <=}, in either order. */
    private static boolean greaterEqualAndLessEqual(ComparisonPredicate cmp1, ComparisonPredicate cmp2) {
        return cmp1 instanceof GreaterThanEqual && cmp2 instanceof LessThanEqual
                || (cmp1 instanceof LessThanEqual && cmp2 instanceof GreaterThanEqual);
    }
}