RangeInference.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.rules.expression.rules;

import org.apache.doris.nereids.rules.expression.ExpressionRewriteContext;
import org.apache.doris.nereids.trees.expressions.And;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.InPredicate;
import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Or;
import org.apache.doris.nereids.trees.expressions.literal.ComparableLiteral;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
import org.apache.doris.nereids.util.ExpressionUtils;

import com.google.common.collect.BoundType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import com.google.common.collect.Multimaps;
import com.google.common.collect.Range;
import com.google.common.collect.RangeSet;
import com.google.common.collect.Sets;
import com.google.common.collect.TreeRangeSet;

import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map.Entry;
import java.util.Set;
import java.util.function.BinaryOperator;
import java.util.stream.Collectors;

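/**
 * Expression visitor that infers, for a predicate, a {@link ValueDesc} describing the values
 * its reference expression may take: a range, a discrete set, empty, or unknown.
 */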
public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, ExpressionRewriteContext> {

    /**
     * Computes the value description of {@code expr}.
     *
     * <p>The expression is visited with a newly created {@link RangeInference}
     * rather than with this instance.
     *
     * @param expr the expression to analyze
     * @param context the rewrite context handed to every visit method
     * @return the {@link ValueDesc} produced by visiting {@code expr}
     */
    public ValueDesc getValue(Expression expr, ExpressionRewriteContext context) {
        RangeInference inference = new RangeInference();
        return expr.accept(inference, context);
    }

    /**
     * Fallback for every expression kind without a dedicated visit method:
     * the expression is wrapped in an {@link UnknownValue} whose reference is the expression itself.
     */
    @Override
    public ValueDesc visit(Expression expr, ExpressionRewriteContext context) {
        ValueDesc unknown = new UnknownValue(context, expr);
        return unknown;
    }

    /**
     * Builds the value description of a comparison predicate.
     *
     * <p>The predicate is converted through {@link ValueDesc#range} only when its right child is a
     * non-null {@link ComparableLiteral} of numeric or date-like type; in every other case the
     * whole predicate is wrapped in an {@link UnknownValue}.
     *
     * @param context the rewrite context
     * @param predicate the comparison predicate to describe
     * @return the inferred description, or an {@link UnknownValue} for unsupported operands
     */
    private ValueDesc buildRange(ExpressionRewriteContext context, ComparisonPredicate predicate) {
        Expression literal = predicate.child(1);
        // only handle non-null literals of `NumericType` and `DateLikeType`
        boolean supported = !literal.isNullLiteral()
                && literal instanceof ComparableLiteral
                && (literal.getDataType().isNumericType() || literal.getDataType().isDateLikeType());
        return supported ? ValueDesc.range(context, predicate) : new UnknownValue(context, predicate);
    }

    /** Describes a {@code >} predicate by delegating to {@code buildRange}. */
    @Override
    public ValueDesc visitGreaterThan(GreaterThan greaterThan, ExpressionRewriteContext context) {
        return this.buildRange(context, greaterThan);
    }

    /** Describes a {@code >=} predicate by delegating to {@code buildRange}. */
    @Override
    public ValueDesc visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, ExpressionRewriteContext context) {
        return this.buildRange(context, greaterThanEqual);
    }

    /** Describes a {@code <} predicate by delegating to {@code buildRange}. */
    @Override
    public ValueDesc visitLessThan(LessThan lessThan, ExpressionRewriteContext context) {
        return this.buildRange(context, lessThan);
    }

    /** Describes a {@code <=} predicate by delegating to {@code buildRange}. */
    @Override
    public ValueDesc visitLessThanEqual(LessThanEqual lessThanEqual, ExpressionRewriteContext context) {
        return this.buildRange(context, lessThanEqual);
    }

    /** Describes an {@code =} predicate by delegating to {@code buildRange}. */
    @Override
    public ValueDesc visitEqualTo(EqualTo equalTo, ExpressionRewriteContext context) {
        return this.buildRange(context, equalTo);
    }

    /**
     * Describes an in-predicate as a {@link DiscreteValue} when its option list is small enough and
     * consists solely of non-null comparable literals of numeric or date-like type;
     * otherwise the predicate is wrapped in an {@link UnknownValue}.
     */
    @Override
    public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) {
        // option list exceeds InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
        if (inPredicate.getOptions().size() > InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE) {
            return new UnknownValue(context, inPredicate);
        }
        // some option is null or is not a comparable literal
        if (!ExpressionUtils.isAllNonNullComparableLiteral(inPredicate.getOptions())) {
            return new UnknownValue(context, inPredicate);
        }
        // only handle `NumericType` and `DateLikeType`
        boolean numericOrDateLike = ExpressionUtils.matchNumericType(inPredicate.getOptions())
                || ExpressionUtils.matchDateLikeType(inPredicate.getOptions());
        if (!numericOrDateLike) {
            return new UnknownValue(context, inPredicate);
        }
        return ValueDesc.discrete(context, inPredicate);
    }

    /** Splits the conjunction into its conjuncts and merges their descriptions by intersection. */
    @Override
    public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
        List<Expression> conjuncts = ExpressionUtils.extractConjunction(and);
        return simplify(context, conjuncts, ValueDesc::intersect, true);
    }

    /** Splits the disjunction into its disjuncts and merges their descriptions by union. */
    @Override
    public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
        List<Expression> disjuncts = ExpressionUtils.extractDisjunction(or);
        return simplify(context, disjuncts, ValueDesc::union, false);
    }

    /**
     * Merges the value descriptions of the operands of an AND / OR expression.
     *
     * <p>Each operand is converted to a {@link ValueDesc}; the descriptions are grouped by their
     * reference expression (first-seen order is kept), and each group is folded left to right
     * with {@code op}. For OR, discrete values and ranges of a group are pre-merged through
     * {@link ValueDesc#unionDiscreteAndRange} before folding.
     *
     * @param context the rewrite context
     * @param predicates the conjuncts (for AND) or disjuncts (for OR)
     * @param op the merge operator applied within one reference group
     * @param isAnd {@code true} when {@code predicates} are conjuncts, {@code false} for disjuncts
     * @return the single merged description when all operands share one reference, otherwise an
     *         {@link UnknownValue} wrapping the per-reference results
     */
    private ValueDesc simplify(ExpressionRewriteContext context, List<Expression> predicates,
            BinaryOperator<ValueDesc> op, boolean isAnd) {

        boolean convertIsNullToEmptyValue = isAnd && predicates.stream().anyMatch(expr -> expr instanceof NullLiteral);
        Multimap<Expression, ValueDesc> groupByReference
                = Multimaps.newListMultimap(new LinkedHashMap<>(), ArrayList::new);
        for (Expression predicate : predicates) {
            // EmptyValue(a) = IsNull(a) and null, it is not equal to IsNull(a).
            // Only when the and expression contains at least one null literal in its conjunctions
            // is EmptyValue(a) equivalent to IsNull(a).
            // So for expression and(IsNull(a), IsNull(b), ..., null), a and b can convert to EmptyValue.
            // What's more, if a is not nullable, then EmptyValue(a) always equals IsNull(a),
            // but we don't consider this case here, we should fold IsNull(a) to FALSE using other rule.
            ValueDesc valueDesc;
            if (convertIsNullToEmptyValue && predicate instanceof IsNull) {
                valueDesc = new EmptyValue(context, ((IsNull) predicate).child());
            } else {
                valueDesc = predicate.accept(this, context);
            }
            // put() appends to the list of this reference, no cast of the multimap view needed
            groupByReference.put(valueDesc.reference, valueDesc);
        }

        List<ValueDesc> valuePerRefs = Lists.newArrayList();
        for (Entry<Expression, Collection<ValueDesc>> referenceValues : groupByReference.asMap().entrySet()) {
            Expression reference = referenceValues.getKey();
            // typed copy instead of a raw-type cast of the collection view
            List<ValueDesc> valuePerReference = new ArrayList<>(referenceValues.getValue());
            if (!isAnd) {
                valuePerReference = ValueDesc.unionDiscreteAndRange(context, reference, valuePerReference);
            }

            // merge per reference
            ValueDesc simplifiedValue = valuePerReference.get(0);
            for (int i = 1; i < valuePerReference.size(); i++) {
                simplifiedValue = op.apply(simplifiedValue, valuePerReference.get(i));
            }

            valuePerRefs.add(simplifiedValue);
        }

        if (valuePerRefs.size() == 1) {
            return valuePerRefs.get(0);
        }

        // use UnknownValue to wrap different references
        return new UnknownValue(context, valuePerRefs, isAnd);
    }

    /**
     * Base class of the value descriptions produced by {@link RangeInference}.
     *
     * <p>Each description carries the rewrite context it was built with and a {@code reference}
     * expression. Descriptions are grouped by that reference before they are merged.
     */
    public abstract static class ValueDesc {
        ExpressionRewriteContext context;
        Expression reference;

        public ValueDesc(ExpressionRewriteContext context, Expression reference) {
            this.context = context;
            this.reference = reference;
        }

        public Expression getReference() {
            return reference;
        }

        public ExpressionRewriteContext getExpressionRewriteContext() {
            return context;
        }

        /**
         * Combines this description with {@code other} under OR semantics.
         *
         * @param other the description to union with
         * @return the combined description
         */
        public abstract ValueDesc union(ValueDesc other);

        /**
         * Unions a range with a discrete value set (OR semantics).
         *
         * @param context the rewrite context used for a newly created description
         * @param range the range operand
         * @param discrete the discrete operand
         * @param reverseOrder when the operands cannot be merged, {@code true} lists the source values
         *         as (discrete, range), {@code false} as (range, discrete)
         * @return {@code range} itself when it contains every discrete value, otherwise an
         *         {@link UnknownValue} holding both operands as OR-ed source values
         */
        public static ValueDesc union(ExpressionRewriteContext context,
                RangeValue range, DiscreteValue discrete, boolean reverseOrder) {
            // use Range.contains, the same membership test as intersect(context, range, discrete)
            if (discrete.values.stream().allMatch(x -> range.range.contains(x))) {
                return range;
            }
            List<ValueDesc> sourceValues = reverseOrder
                    ? ImmutableList.of(discrete, range)
                    : ImmutableList.of(range, discrete);
            return new UnknownValue(context, sourceValues, false);
        }

        /**
         * Merges only the discrete values and the ranges of {@code valueDescs}; every other kind of
         * description is passed through unchanged.
         *
         * <p>The result lists, in this order: one {@link DiscreteValue} holding the discrete values
         * not covered by any range (if any remain), one {@link RangeValue} per merged range, and
         * then the remaining descriptions in their original order.
         *
         * @param context the rewrite context used for the newly created descriptions
         * @param reference the reference assigned to the newly created descriptions
         * @param valueDescs the descriptions to merge
         * @return the merged list of descriptions
         */
        public static List<ValueDesc> unionDiscreteAndRange(ExpressionRewriteContext context,
                Expression reference, List<ValueDesc> valueDescs) {
            // Since in-predicate's options is a list, the discrete values need to keep the options' order.
            // If the order is not kept, the resulting in-predicate's option list will not equal
            // the input in-predicate, later nereids will need to simplify the new in-predicate,
            // which then causes a dead loop.
            Set<ComparableLiteral> discreteValues = Sets.newLinkedHashSet();
            for (ValueDesc valueDesc : valueDescs) {
                if (valueDesc instanceof DiscreteValue) {
                    discreteValues.addAll(((DiscreteValue) valueDesc).getValues());
                }
            }

            // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00)
            RangeSet<ComparableLiteral> rangeSet = TreeRangeSet.create();
            for (ValueDesc valueDesc : valueDescs) {
                if (valueDesc instanceof RangeValue) {
                    Range<ComparableLiteral> range = ((RangeValue) valueDesc).range;
                    rangeSet.add(range);
                    // an open endpoint that also appears as a discrete value closes that side
                    if (range.hasLowerBound()
                            && range.lowerBoundType() == BoundType.OPEN
                            && discreteValues.contains(range.lowerEndpoint())) {
                        rangeSet.add(Range.singleton(range.lowerEndpoint()));
                    }
                    if (range.hasUpperBound()
                            && range.upperBoundType() == BoundType.OPEN
                            && discreteValues.contains(range.upperEndpoint())) {
                        rangeSet.add(Range.singleton(range.upperEndpoint()));
                    }
                }
            }

            // discrete values already covered by some range are redundant
            if (!rangeSet.isEmpty()) {
                discreteValues.removeIf(x -> rangeSet.contains(x));
            }

            List<ValueDesc> result = Lists.newArrayListWithExpectedSize(valueDescs.size());
            if (!discreteValues.isEmpty()) {
                result.add(new DiscreteValue(context, reference, discreteValues));
            }
            for (Range<ComparableLiteral> range : rangeSet.asRanges()) {
                result.add(new RangeValue(context, reference, range));
            }
            for (ValueDesc valueDesc : valueDescs) {
                if (!(valueDesc instanceof DiscreteValue) && !(valueDesc instanceof RangeValue)) {
                    result.add(valueDesc);
                }
            }

            return result;
        }

        /**
         * Combines this description with {@code other} under AND semantics.
         *
         * @param other the description to intersect with
         * @return the combined description
         */
        public abstract ValueDesc intersect(ValueDesc other);

        /**
         * Intersects a range with a discrete value set (AND semantics).
         *
         * @param context the rewrite context used for the newly created description
         * @param range the range operand; its reference is used for the result
         * @param discrete the discrete operand
         * @return a {@link DiscreteValue} with the discrete values contained in the range, in their
         *         original order, or an {@link EmptyValue} when none is contained
         */
        public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) {
            // Since in-predicate's options is a list, the discrete values need to keep the options' order.
            // If the order is not kept, the resulting in-predicate's option list will not equal
            // the input in-predicate, later nereids will need to simplify the new in-predicate,
            // which then causes a dead loop.
            Set<ComparableLiteral> newValues = discrete.values.stream().filter(x -> range.range.contains(x))
                    .collect(Collectors.toCollection(
                            () -> Sets.newLinkedHashSetWithExpectedSize(discrete.values.size())));
            if (newValues.isEmpty()) {
                return new EmptyValue(context, range.reference);
            } else {
                return new DiscreteValue(context, range.reference, newValues);
            }
        }

        /**
         * Converts a comparison predicate whose right operand is a {@link ComparableLiteral}.
         *
         * <p>{@code =} yields a single-element {@link DiscreteValue}; {@code >=}, {@code >},
         * {@code <=} and {@code <} yield a {@link RangeValue} with the matching range. For any other
         * predicate class the returned {@link RangeValue} carries a {@code null} range.
         */
        private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) {
            ComparableLiteral value = (ComparableLiteral) predicate.right();
            if (predicate instanceof EqualTo) {
                return new DiscreteValue(context, predicate.left(), Sets.newHashSet(value));
            }
            Range<ComparableLiteral> range = null;
            if (predicate instanceof GreaterThanEqual) {
                range = Range.atLeast(value);
            } else if (predicate instanceof GreaterThan) {
                range = Range.greaterThan(value);
            } else if (predicate instanceof LessThanEqual) {
                range = Range.atMost(value);
            } else if (predicate instanceof LessThan) {
                range = Range.lessThan(value);
            }

            return new RangeValue(context, predicate.left(), range);
        }

        /**
         * Converts an in-predicate into a {@link DiscreteValue} over its options, with the compare
         * expression as reference. Every option is cast to {@link ComparableLiteral}.
         */
        private static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) {
            // Since in-predicate's options is a list, the discrete values need to keep the options' order.
            // If the order is not kept, the resulting in-predicate's option list will not equal
            // the input in-predicate, later nereids will need to simplify the new in-predicate,
            // which then causes a dead loop.
            // Set<ComparableLiteral> literals = (Set) Utils.fastToImmutableSet(in.getOptions());
            Set<ComparableLiteral> literals = in.getOptions().stream()
                    .map(ComparableLiteral.class::cast)
                    .collect(Collectors.toCollection(
                            () -> Sets.newLinkedHashSetWithExpectedSize(in.getOptions().size())));
            return new DiscreteValue(context, in.getCompareExpr(), literals);
        }
    }

    /**
     * Value description that matches no value (an empty range).
     *
     * <p>Intersecting it with anything keeps it empty, while a union simply yields the other operand.
     */
    public static class EmptyValue extends ValueDesc {

        public EmptyValue(ExpressionRewriteContext context, Expression reference) {
            super(context, reference);
        }

        /** Intersection with an empty description stays empty: returns this instance. */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            return this;
        }

        /** Union with an empty description adds nothing: returns {@code other} unchanged. */
        @Override
        public ValueDesc union(ValueDesc other) {
            return other;
        }
    }

    /**
     * Wraps a {@code ComparisonPredicate} as a {@link com.google.common.collect.Range},
     * for example:
     * a > 1 => (1...+∞)
     */
    public static class RangeValue extends ValueDesc {
        Range<ComparableLiteral> range;

        public RangeValue(ExpressionRewriteContext context, Expression reference, Range<ComparableLiteral> range) {
            super(context, reference);
            this.range = range;
        }

        public Range<ComparableLiteral> getRange() {
            return range;
        }

        /**
         * Unions this range with {@code other} (OR semantics).
         *
         * @param other the description to union with
         * @return {@code this} for an {@link EmptyValue}; the span of both ranges when {@code other}
         *         is a connected {@link RangeValue}; the result of the static range/discrete union
         *         for a {@link DiscreteValue}; otherwise an {@link UnknownValue} holding both
         *         operands as OR-ed source values
         */
        @Override
        public ValueDesc union(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.union(this);
            }
            if (other instanceof RangeValue) {
                RangeValue o = (RangeValue) other;
                if (range.isConnected(o.range)) {
                    return new RangeValue(context, reference, range.span(o.range));
                }
                return new UnknownValue(context, ImmutableList.of(this, other), false);
            }
            if (other instanceof DiscreteValue) {
                return union(context, this, (DiscreteValue) other, false);
            }
            return new UnknownValue(context, ImmutableList.of(this, other), false);
        }

        /**
         * Intersects this range with {@code other} (AND semantics).
         *
         * @param other the description to intersect with
         * @return the {@link EmptyValue} for an {@link EmptyValue} operand or for ranges without a
         *         non-empty intersection; a single-element {@link DiscreteValue} when two ranges
         *         intersect in exactly one closed point; a {@link RangeValue} for any other non-empty
         *         range intersection; the result of the static range/discrete intersection for a
         *         {@link DiscreteValue}; otherwise an {@link UnknownValue} holding both operands as
         *         AND-ed source values
         */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.intersect(this);
            }
            if (other instanceof RangeValue) {
                RangeValue o = (RangeValue) other;
                // Range.intersection may only be called on connected ranges
                if (range.isConnected(o.range)) {
                    Range<ComparableLiteral> newRange = range.intersection(o.range);
                    if (!newRange.isEmpty()) {
                        // a range like [x, x] holds a single value: describe it as a discrete value
                        if (newRange.hasLowerBound() && newRange.hasUpperBound()
                                && newRange.lowerEndpoint().compareTo(newRange.upperEndpoint()) == 0
                                && newRange.lowerBoundType() == BoundType.CLOSED
                                && newRange.upperBoundType() == BoundType.CLOSED) {
                            return new DiscreteValue(context, reference, Sets.newHashSet(newRange.lowerEndpoint()));
                        } else {
                            return new RangeValue(context, reference, newRange);
                        }
                    }
                }
                return new EmptyValue(context, reference);
            }
            if (other instanceof DiscreteValue) {
                return intersect(context, this, (DiscreteValue) other);
            }
            return new UnknownValue(context, ImmutableList.of(this, other), true);
        }

        @Override
        public String toString() {
            return range == null ? "UnknownRange" : range.toString();
        }
    }

    /**
     * Wraps an {@code InPredicate} as a {@code Set} of literals,
     * for example:
     * a in (1,2,3) => [1,2,3]
     */
    public static class DiscreteValue extends ValueDesc {
        final Set<ComparableLiteral> values;

        public DiscreteValue(ExpressionRewriteContext context,
                Expression reference, Set<ComparableLiteral> values) {
            super(context, reference);
            this.values = values;
        }

        public Set<ComparableLiteral> getValues() {
            return values;
        }

        /**
         * Unions this value set with {@code other} (OR semantics).
         *
         * <p>Two discrete sets are merged into one insertion-ordered set that lists the values of
         * {@code other} first, followed by the values of this set not already present.
         */
        @Override
        public ValueDesc union(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.union(this);
            }
            if (other instanceof RangeValue) {
                return union(context, (RangeValue) other, this, true);
            }
            if (other instanceof DiscreteValue) {
                Set<ComparableLiteral> merged = Sets.newLinkedHashSet(((DiscreteValue) other).values);
                merged.addAll(values);
                return new DiscreteValue(context, reference, merged);
            }
            return new UnknownValue(context, ImmutableList.of(this, other), false);
        }

        /**
         * Intersects this value set with {@code other} (AND semantics).
         *
         * <p>Two discrete sets keep only the values present in both, in the order of this set;
         * an empty result is reported as an {@link EmptyValue}.
         */
        @Override
        public ValueDesc intersect(ValueDesc other) {
            if (other instanceof EmptyValue) {
                return other.intersect(this);
            }
            if (other instanceof RangeValue) {
                return intersect(context, (RangeValue) other, this);
            }
            if (other instanceof DiscreteValue) {
                Set<ComparableLiteral> common = Sets.newLinkedHashSet(values);
                common.retainAll(((DiscreteValue) other).values);
                return common.isEmpty()
                        ? new EmptyValue(context, reference)
                        : new DiscreteValue(context, reference, common);
            }
            return new UnknownValue(context, ImmutableList.of(this, other), true);
        }

        @Override
        public String toString() {
            return values.toString();
        }
    }

    /**
     * Value description for a result that is not an empty, range or discrete description.
     *
     * <p>It either wraps a single expression (no source values), or a list of source
     * descriptions combined with AND or OR according to {@link #isAnd()}.
     */
    public static class UnknownValue extends ValueDesc {
        private final List<ValueDesc> sourceValues;
        private final boolean isAnd;

        private UnknownValue(ExpressionRewriteContext context, Expression expr) {
            super(context, expr);
            this.isAnd = false;
            this.sourceValues = ImmutableList.of();
        }

        private UnknownValue(ExpressionRewriteContext context,
                List<ValueDesc> sourceValues, boolean isAnd) {
            super(context, getReference(context, sourceValues, isAnd));
            this.isAnd = isAnd;
            this.sourceValues = ImmutableList.copyOf(sourceValues);
        }

        // The reference decides whether two ValueDescs can be reduced.
        // For ValueDesc A op ValueDesc B, a reduction such as A op B = A is only possible
        // when A and B have equal references.
        // If the references differ, A op B always yields UnknownValue(A op B).
        //
        // for example:
        // 1. RangeValue(a < 10, reference=a) union RangeValue(a > 20, reference=a)
        //    = UnknownValue1(a < 10 or a > 20, reference=a)
        // 2. RangeValue(a < 10, reference=a) union RangeValue(b > 20, reference=b)
        //    = UnknownValue2(a < 10 or b > 20, reference=(a < 10 or b > 20))
        // then given EmptyValue(, reference=a) E,
        // 1. since E and UnknownValue1 have equal references,
        //    E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1,
        // 2. since E and UnknownValue2 have different references,
        //    E union UnknownValue2 = UnknownValue3(E union UnknownValue2, reference=E union UnknownValue2)
        private static Expression getReference(ExpressionRewriteContext context,
                List<ValueDesc> sourceValues, boolean isAnd) {
            Expression first = sourceValues.get(0).reference;
            // the first element is the baseline, so only the remaining ones are compared against it
            boolean allShareReference = sourceValues.stream()
                    .skip(1)
                    .allMatch(source -> first.equals(source.reference));
            if (allShareReference) {
                return first;
            }
            return SimplifyRange.INSTANCE.getExpression(context, sourceValues, isAnd);
        }

        public List<ValueDesc> getSourceValues() {
            return sourceValues;
        }

        public boolean isAnd() {
            return this.isAnd;
        }

        @Override
        public ValueDesc union(ValueDesc other) {
            // a union with EmptyValue is delegated to EmptyValue.union(this), which returns this;
            // anything else is kept as an OR of both operands
            return other instanceof EmptyValue
                    ? other.union(this)
                    : new UnknownValue(context, ImmutableList.of(this, other), false);
        }

        @Override
        public ValueDesc intersect(ValueDesc other) {
            // an intersection with EmptyValue is delegated to EmptyValue.intersect(this), which
            // returns the EmptyValue; anything else is kept as an AND of both operands
            return other instanceof EmptyValue
                    ? other.intersect(this)
                    : new UnknownValue(context, ImmutableList.of(this, other), true);
        }
    }
}