SimplifyRange.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.expression.rules;
import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Not;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.types.BooleanType;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Range;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;
/**
* This class implements the function to simplify expression range.
* for example:
* a > 1 and a > 2 => a > 2
* a > 1 or a > 2 => a > 1
* a in (1,2,3) and a > 1 => a in (2,3)
* a in (1,2,3) and a in (3,4,5) => a = 3
* a in (1,2,3) and a in (4,5,6) => false
* The logic is as follows:
* 1. for `And` expression.
* 1. extract conjunctions then build `ValueDesc` for each conjunction
* 2. grouping according to `reference`, `ValueDesc` in the same group can perform intersect
* for example:
* a > 1 and a > 2
* 1. a > 1 => RangeValueDesc((1...+∞)), a > 2 => RangeValueDesc((2...+∞))
* 2. (1...+∞) intersect (2...+∞) => (2...+∞)
* 2. for `Or` expression (similar to `And`).
* todo: support a > 10 and (a < 10 or a > 20 ) => a > 20
*/
public class SimplifyRange implements ExpressionPatternRuleFactory {
public static final SimplifyRange INSTANCE = new SimplifyRange();
@Override
public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
return ImmutableList.of(
matchesTopType(CompoundPredicate.class)
.thenApply(ctx -> SimplifyRange.rewrite(ctx.expr, ctx.rewriteContext))
);
}
/** rewrite */
public static Expression rewrite(CompoundPredicate expr, ExpressionRewriteContext context) {
ValueDesc valueDesc = expr.accept(new RangeInference(), context);
Expression toExpr = valueDesc.toExpression();
if (toExpr == null) {
// this mean cannot simplify
return valueDesc.toExpr;
}
return toExpr;
}
private static class RangeInference extends ExpressionVisitor<ValueDesc, ExpressionRewriteContext> {
@Override
public ValueDesc visit(Expression expr, ExpressionRewriteContext context) {
return new UnknownValue(context, expr);
}
private ValueDesc buildRange(ExpressionRewriteContext context, ComparisonPredicate predicate) {
Expression right = predicate.child(1);
if (right.isNullLiteral()) {
return new UnknownValue(context, predicate);
}
// only handle `NumericType` and `DateLikeType`
if (right.isLiteral() && (right.getDataType().isNumericType() || right.getDataType().isDateLikeType())) {
return ValueDesc.range(context, predicate);
}
return new UnknownValue(context, predicate);
}
@Override
public ValueDesc visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
return buildRange(context, greaterThan);
}
@Override
public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
return buildRange(context, greaterThanEqual);
}
@Override
public ValueDesc visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
return buildRange(context, lessThan);
}
@Override
public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
return buildRange(context, lessThanEqual);
}
@Override
public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
return buildRange(context, equalTo);
}
@Override
public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
// only handle `NumericType` and `DateLikeType`
if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
&& ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
&& (ExpressionUtils.matchNumericType(inPredicate.getOptions())
|| ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) {
return ValueDesc.discrete(context, inPredicate);
}
return new UnknownValue(context, inPredicate);
}
@Override
public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
return simplify(context, and, ExpressionUtils.extractConjunction(and),
ValueDesc::intersect, ExpressionUtils::and);
}
@Override
public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
return simplify(context, or, ExpressionUtils.extractDisjunction(or),
ValueDesc::union, ExpressionUtils::or);
}
private ValueDesc simplify(ExpressionRewriteContext context,
Expression originExpr, List<Expression> predicates,
BinaryOperator<ValueDesc> op, BinaryOperator<Expression> exprOp) {
Multimap<Expression, ValueDesc> groupByReference
= Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new);
for (Expression predicate : predicates) {
ValueDesc valueDesc = predicate.accept(this, null);
List<ValueDesc> valueDescs = (List<ValueDesc>) groupByReference.get(valueDesc.reference);
valueDescs.add(valueDesc);
}
List<ValueDesc> valuePerRefs = Lists.newArrayList();
for (Entry<Expression, Collection<ValueDesc>> referenceValues : groupByReference.asMap().entrySet()) {
List<ValueDesc> valuePerReference = (List) referenceValues.getValue();
// merge per reference
ValueDesc simplifiedValue = valuePerReference.get(0);
for (int i = 1; i < valuePerReference.size(); i++) {
simplifiedValue = op.apply(simplifiedValue, valuePerReference.get(i));
}
valuePerRefs.add(simplifiedValue);
}
if (valuePerRefs.size() == 1) {
return valuePerRefs.get(0);
}
// use UnknownValue to wrap different references
return new UnknownValue(context, originExpr, valuePerRefs, exprOp);
}
}
private abstract static class ValueDesc {
ExpressionRewriteContext context;
Expression toExpr;
Expression reference;
public ValueDesc(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
this.context = context;
this.toExpr = toExpr;
this.reference = reference;
}
public abstract ValueDesc union(ValueDesc other);
public static ValueDesc union(ExpressionRewriteContext context,
RangeValue range, DiscreteValue discrete, boolean reverseOrder) {
long count = discrete.values.stream().filter(x -> range.range.test(x)).count();
if (count == discrete.values.size()) {
return range;
}
Expression toExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(range.toExpr, discrete.toExpr), context);
List<ValueDesc> sourceValues = reverseOrder
? ImmutableList.of(discrete, range)
: ImmutableList.of(range, discrete);
return new UnknownValue(context, toExpr, sourceValues, ExpressionUtils::or);
}
public abstract ValueDesc intersect(ValueDesc other);
public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) {
// Since in-predicate's options is a list, the discrete values need to kept options' order.
// If not keep options' order, the result in-predicate's option list will not equals to
// the input in-predicate, later nereids will need to simplify the new in-predicate,
// then cause dead loop.
Set<Literal> newValues = discrete.values.stream()
.filter(x -> range.range.contains(x))
.collect(Collectors.toCollection(
() -> Sets.newLinkedHashSetWithExpectedSize(discrete.values.size())));
if (!newValues.isEmpty()) {
return new DiscreteValue(context, discrete.reference, discrete.toExpr, newValues);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(range.toExpr, discrete.toExpr), context);
return new EmptyValue(context, range.reference, originExpr);
}
public abstract Expression toExpression();
public static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) {
Literal value = (Literal) predicate.right();
if (predicate instanceof EqualTo) {
return new DiscreteValue(context, predicate.left(), predicate, Sets.newHashSet(value));
}
RangeValue rangeValue = new RangeValue(context, predicate.left(), predicate);
if (predicate instanceof GreaterThanEqual) {
rangeValue.range = Range.atLeast(value);
} else if (predicate instanceof GreaterThan) {
rangeValue.range = Range.greaterThan(value);
} else if (predicate instanceof LessThanEqual) {
rangeValue.range = Range.atMost(value);
} else if (predicate instanceof LessThan) {
rangeValue.range = Range.lessThan(value);
}
return rangeValue;
}
public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) {
// Since in-predicate's options is a list, the discrete values need to kept options' order.
// If not keep options' order, the result in-predicate's option list will not equals to
// the input in-predicate, later nereids will need to simplify the new in-predicate,
// then cause dead loop.
// Set<Literal> literals = (Set) Utils.fastToImmutableSet(in.getOptions());
Set<Literal> literals = in.getOptions().stream()
.map(Literal.class::cast)
.collect(Collectors.toCollection(
() -> Sets.newLinkedHashSetWithExpectedSize(in.getOptions().size())));
return new DiscreteValue(context, in.getCompareExpr(), in, literals);
}
}
private static class EmptyValue extends ValueDesc {
public EmptyValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
super(context, reference, toExpr);
}
@Override
public ValueDesc union(ValueDesc other) {
return other;
}
@Override
public ValueDesc intersect(ValueDesc other) {
return this;
}
@Override
public Expression toExpression() {
if (reference.nullable()) {
return new And(new IsNull(reference), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.FALSE;
}
}
}
/**
* use @see com.google.common.collect.Range to wrap `ComparisonPredicate`
* for example:
* a > 1 => (1...+∞)
*/
private static class RangeValue extends ValueDesc {
Range<Literal> range;
public RangeValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) {
super(context, reference, toExpr);
}
@Override
public ValueDesc union(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.union(this);
}
if (other instanceof RangeValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExpr);
rangeValue.range = range.span(o.range);
return rangeValue;
}
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}
if (other instanceof DiscreteValue) {
return union(context, this, (DiscreteValue) other, false);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}
@Override
public ValueDesc intersect(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.intersect(this);
}
if (other instanceof RangeValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
RangeValue o = (RangeValue) other;
if (range.isConnected(o.range)) {
RangeValue rangeValue = new RangeValue(context, reference, originExpr);
rangeValue.range = range.intersection(o.range);
return rangeValue;
}
return new EmptyValue(context, reference, originExpr);
}
if (other instanceof DiscreteValue) {
return intersect(context, this, (DiscreteValue) other);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}
@Override
public Expression toExpression() {
List<Expression> result = Lists.newArrayList();
if (range.hasLowerBound()) {
if (range.lowerBoundType() == BoundType.CLOSED) {
result.add(new GreaterThanEqual(reference, range.lowerEndpoint()));
} else {
result.add(new GreaterThan(reference, range.lowerEndpoint()));
}
}
if (range.hasUpperBound()) {
if (range.upperBoundType() == BoundType.CLOSED) {
result.add(new LessThanEqual(reference, range.upperEndpoint()));
} else {
result.add(new LessThan(reference, range.upperEndpoint()));
}
}
if (!result.isEmpty()) {
return ExpressionUtils.and(result);
} else {
if (reference.nullable()) {
return new Or(new Not(new IsNull(reference)), new NullLiteral(BooleanType.INSTANCE));
} else {
return BooleanLiteral.TRUE;
}
}
}
@Override
public String toString() {
return range == null ? "UnknownRange" : range.toString();
}
}
/**
* use `Set` to wrap `InPredicate`
* for example:
* a in (1,2,3) => [1,2,3]
*/
private static class DiscreteValue extends ValueDesc {
final Set<Literal> values;
public DiscreteValue(ExpressionRewriteContext context,
Expression reference, Expression toExpr, Set<Literal> values) {
super(context, reference, toExpr);
this.values = values;
}
@Override
public ValueDesc union(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.union(this);
}
if (other instanceof DiscreteValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
Set<Literal> otherValues = ((DiscreteValue) other).values;
Set<Literal> newValues = Sets.newLinkedHashSetWithExpectedSize(values.size() + otherValues.size());
newValues.addAll(values);
newValues.addAll(otherValues);
return new DiscreteValue(context, reference, originExpr, newValues);
}
if (other instanceof RangeValue) {
return union(context, (RangeValue) other, this, true);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}
@Override
public ValueDesc intersect(ValueDesc other) {
if (other instanceof EmptyValue) {
return other.intersect(this);
}
if (other instanceof DiscreteValue) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
Set<Literal> newValues = Sets.newLinkedHashSet(values);
newValues.retainAll(((DiscreteValue) other).values);
if (newValues.isEmpty()) {
return new EmptyValue(context, reference, originExpr);
} else {
return new DiscreteValue(context, reference, originExpr, newValues);
}
}
if (other instanceof RangeValue) {
return intersect(context, (RangeValue) other, this);
}
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}
@Override
public Expression toExpression() {
// NOTICE: it's related with `InPredicateToEqualToRule`
// They are same processes, so must change synchronously.
if (values.size() == 1) {
return new EqualTo(reference, values.iterator().next());
// this condition should as same as OrToIn, or else meet dead loop
} else if (values.size() < 2) {
Iterator<Literal> iterator = values.iterator();
return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next()));
} else {
return new InPredicate(reference, Lists.newArrayList(values));
}
}
@Override
public String toString() {
return values.toString();
}
}
/**
* Represents processing result.
*/
private static class UnknownValue extends ValueDesc {
private final List<ValueDesc> sourceValues;
private final BinaryOperator<Expression> mergeExprOp;
private UnknownValue(ExpressionRewriteContext context, Expression expr) {
super(context, expr, expr);
sourceValues = ImmutableList.of();
mergeExprOp = null;
}
public UnknownValue(ExpressionRewriteContext context, Expression toExpr,
List<ValueDesc> sourceValues, BinaryOperator<Expression> mergeExprOp) {
super(context, getReference(sourceValues, toExpr), toExpr);
this.sourceValues = ImmutableList.copyOf(sourceValues);
this.mergeExprOp = mergeExprOp;
}
// reference is used to simplify multiple ValueDescs.
// when ValueDesc A op ValueDesc B, only A and B's references equals,
// can reduce them, like A op B = A.
// If A and B's reference not equal, A op B will always get UnknownValue(A op B).
//
// for example:
// 1. RangeValue(a < 10, reference=a) union RangeValue(a > 20, reference=a)
// = UnknownValue1(a < 10 or a > 20, reference=a)
// 2. RangeValue(a < 10, reference=a) union RangeValue(b > 20, reference=b)
// = UnknownValue2(a < 10 or b > 20, reference=(a < 10 or b > 20))
// then given EmptyValue(, reference=a) E,
// 1. since E and UnknownValue1's reference equals, then
// E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1,
// 2. since E and UnknownValue2's reference not equals, then
// E union UnknownValue2 = UnknownValue3(E union UnknownValue2, reference=E union UnknownValue2)
private static Expression getReference(List<ValueDesc> sourceValues, Expression toExpr) {
Expression reference = sourceValues.get(0).reference;
for (int i = 1; i < sourceValues.size(); i++) {
if (!reference.equals(sourceValues.get(i).reference)) {
return toExpr;
}
}
return reference;
}
@Override
public ValueDesc union(ValueDesc other) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.or(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::or);
}
@Override
public ValueDesc intersect(ValueDesc other) {
Expression originExpr = FoldConstantRuleOnFE.evaluate(
ExpressionUtils.and(toExpr, other.toExpr), context);
return new UnknownValue(context, originExpr,
ImmutableList.of(this, other), ExpressionUtils::and);
}
@Override
public Expression toExpression() {
if (sourceValues.isEmpty()) {
return toExpr;
}
Expression result = sourceValues.get(0).toExpression();
for (int i = 1; i < sourceValues.size(); i++) {
result = mergeExprOp.apply(result, sourceValues.get(i).toExpression());
}
result = FoldConstantRuleOnFE.evaluate(result, context);
// ATTN: we must return original expr, because OrToIn is implemented with MutableState,
// newExpr will lose these states leading to dead loop by OrToIn -> SimplifyRange -> FoldConstantByFE
if (result.equals(toExpr)) {
return toExpr;
}
return result;
}
}
}