UnequalPredicateInfer.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.catalog.Column;
import org.apache.doris.catalog.TableIf;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.exceptions.AnalysisException;
import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.GreaterThan;
import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
import org.apache.doris.nereids.trees.expressions.LessThan;
import org.apache.doris.nereids.trees.expressions.LessThanEqual;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.ExpressionTrait;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.util.PredicateInferUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
/**
 * This class does the following:
 * {@code
 * 1. t1.a=t2.b t2.b=t3.c -> t1.a=t2.b t2.b=t3.c (reserve all three condition)
 * 2. remove useless equal predicates (e.g. t1.a=t1.b t1.a=1 t1.b=1 -> t1.a=1 t1.b=1; t1.a=t1.b is removed)
 * 3. infer unequal predicates (e.g. t1.a<t2.b and t2.b<1 -> t1.a<1 and t1.a<t2.b and t2.b<1)
 * 4. remove useless unequal predicates (e.g. t1.a<t1.b t1.a<1 t1.b<1 -> t1.a<t1.b t1.b<1)}
 */
public class UnequalPredicateInfer {
/**InferenceGraph*/
public static class InferenceGraph {
/**
 * Relation stored in one cell of the inference graph.
 * A cell {@code [i][j]} describes how the i-th used expression compares to the j-th one;
 * e.g. {@code GT} in {@code [i][j]} is turned into {@code GreaterThan(expr_i, expr_j)}
 * when predicates are generated.
 */
public enum Relation {
    // expr_i > expr_j
    GT,
    // expr_i >= expr_j
    GTE,
    // expr_i = expr_j (always written symmetrically by set())
    EQ,
    // nothing is known about the pair
    UNDEFINED
}
/**
 * Immutable holder tying the two operands of an input predicate (after cast extraction)
 * to the relation between them, oriented as {@code pair.first <relation> pair.second}.
 */
private static class PairAndRelation {
    private final Pair<Expression, Expression> pair;
    private final Relation relation;

    private PairAndRelation(Pair<Expression, Expression> operands, Relation operandRelation) {
        this.pair = operands;
        this.relation = operandRelation;
    }
}
// Adjacency matrix holding the known/inferred relation between every pair of used expressions.
// graph[i][j] is the relation of usedExprs.get(i) to usedExprs.get(j).
private final Relation[][] graph;
// Slots or literals found at either end of the usable input predicates;
// the index of an element here is its row/column index in the graph.
private final List<Expression> usedExprs = new ArrayList<>();
// Input predicates that take part in the derivation; consumed by chooseInputPredicates.
private final List<ComparisonPredicate> usedPredicates = new ArrayList<>();
// Parallel to usedPredicates: usedPredicatesPairs.get(i) describes usedPredicates.get(i).
// Each entry holds the operand pair obtained from PredicateInferUtils.getPairFromCast
// plus the relation used for the graph edge.
private final List<PairAndRelation> usedPredicatesPairs = new ArrayList<>();
// Maps every element of usedExprs to its index in usedExprs.
private final Map<Expression, Integer> usedExprPosition = new HashMap<>();
// Number of elements in usedExprs (the graph is size x size).
private final int size;
// Input predicates that are not used for inference; they are passed through unchanged.
private final List<Expression> otherPredicates = new ArrayList<>();
/**
 * Builds the graph from the input predicates.
 * An input is used for inference only if it is a GT/GTE/LT/LTE/EQ comparison whose two sides
 * differ, contain no NullLiteral, reference at least one slot, and reduce (via
 * PredicateInferUtils.getPairFromCast) to a pair of slots or literals. LT/LTE inputs are commuted
 * first so that every edge is stored as GT, GTE or EQ. Every other input is put into
 * otherPredicates untouched.
 *
 * @param inputs predicates to analyse
 */
public InferenceGraph(Set<Expression> inputs) {
    Set<Expression> endpoints = new HashSet<>();
    for (Expression input : inputs) {
        if (!(input instanceof ComparisonPredicate)) {
            otherPredicates.add(input);
            continue;
        }
        ComparisonPredicate comparison = (ComparisonPredicate) input;
        Expression lhs = comparison.left();
        Expression rhs = comparison.right();
        // x = x, comparisons against NULL and slot-free comparisons carry no usable edge
        boolean trivialOrNull = lhs.equals(rhs)
                || lhs instanceof NullLiteral
                || rhs instanceof NullLiteral;
        if (trivialOrNull || (lhs.getInputSlots().isEmpty() && rhs.getInputSlots().isEmpty())) {
            otherPredicates.add(comparison);
            continue;
        }

        // orient the predicate so that only GT / GTE / EQ remain
        ComparisonPredicate oriented;
        if (comparison instanceof LessThan || comparison instanceof LessThanEqual) {
            oriented = (ComparisonPredicate) comparison.commute().withInferred(comparison.isInferred());
        } else if (comparison instanceof GreaterThan || comparison instanceof GreaterThanEqual
                || comparison instanceof EqualTo) {
            oriented = comparison;
        } else {
            otherPredicates.add(comparison);
            continue;
        }

        Optional<Pair<Expression, Expression>> operands = PredicateInferUtils.getPairFromCast(oriented);
        if (!operands.isPresent()
                || !PredicateInferUtils.isSlotOrLiteral(operands.get().first)
                || !PredicateInferUtils.isSlotOrLiteral(operands.get().second)) {
            otherPredicates.add(comparison);
            continue;
        }

        Pair<Expression, Expression> pair = operands.get();
        endpoints.add(pair.first);
        endpoints.add(pair.second);
        usedPredicates.add(comparison);
        usedPredicatesPairs.add(new PairAndRelation(pair, getType(oriented)));
    }

    usedExprs.addAll(endpoints);
    // Sorting is required to ensure the stability of the plan shape
    // and to ensure that the same results are output in the derivation of d>1 d=c and c>1 d=c
    usedExprs.sort(Comparator.comparing(ExpressionTrait::toSql));
    size = usedExprs.size();
    for (int index = 0; index < size; ++index) {
        usedExprPosition.put(usedExprs.get(index), index);
    }

    graph = new Relation[size][size];
    initGraph(graph);
    // Add one edge per usable input predicate.
    for (PairAndRelation edge : usedPredicatesPairs) {
        int from = usedExprPosition.get(edge.pair.first);
        int to = usedExprPosition.get(edge.pair.second);
        set(graph, from, to, edge.relation);
    }
}
/**
 * Marks every cell of the leading {@code size x size} area of {@code g} as UNDEFINED.
 *
 * @param g relation matrix to reset; expected to be at least size x size
 */
public void initGraph(Relation[][] g) {
    for (int row = 0; row < size; ++row) {
        for (int col = 0; col < size; ++col) {
            g[row][col] = Relation.UNDEFINED;
        }
    }
}
/**
 * Updates {@code graph[left][right]} with whatever can be derived from the two hops
 * left -> mid and mid -> right; the cell is left untouched when nothing can be derived.
 * The derivation rules (EQ+EQ, GTE+EQ / EQ+GTE, GT+anything defined) are mutually exclusive,
 * so at most one relation is derivable and it is exactly what connectInThisPath reports.
 */
private void connect(Relation[][] graph, int left, int right, int mid) {
    Relation derived = connectInThisPath(graph, left, right, mid);
    if (derived != Relation.UNDEFINED) {
        graph[left][right] = derived;
    }
}
// Calculate the relationship between left and right that is derived through mid:
//   any hop UNDEFINED          -> UNDEFINED
//   at least one hop GT        -> GT
//   EQ + EQ                    -> EQ
//   GTE + EQ or EQ + GTE       -> GTE
//   anything else (GTE + GTE)  -> UNDEFINED
private Relation connectInThisPath(final Relation[][] graph, int left, int right, int mid) {
    Relation firstHop = graph[left][mid];
    Relation secondHop = graph[mid][right];
    if (firstHop == Relation.UNDEFINED || secondHop == Relation.UNDEFINED) {
        return Relation.UNDEFINED;
    }
    if (firstHop == Relation.GT || secondHop == Relation.GT) {
        return Relation.GT;
    }
    if (firstHop == Relation.EQ && secondHop == Relation.EQ) {
        return Relation.EQ;
    }
    boolean oneGteOneEq = (firstHop == Relation.GTE && secondHop == Relation.EQ)
            || (firstHop == Relation.EQ && secondHop == Relation.GTE);
    return oneGteOneEq ? Relation.GTE : Relation.UNDEFINED;
}
/**
 * Use Floyd algorithm to deduce the inequality: every node is tried in turn as the
 * intermediate node for every ordered pair (from, to), and the matrix is updated in place.
 *
 * @param graph relation matrix to close transitively
 */
public void deduce(Relation[][] graph) {
    for (int via = 0; via < size; ++via) {
        for (int from = 0; from < size; ++from) {
            for (int to = 0; to < size; ++to) {
                connect(graph, from, to, via);
            }
        }
    }
}
/**
 * Orders all node indexes by a depth-first post-order walk over the GT/GTE edges of the graph:
 * a node is appended only after every node it is greater than (or greater-equal to) has been
 * appended.
 *
 * @return node indexes in visiting order
 */
public List<Integer> topoSort() {
    // fixed-size list of flags, all false; dfs only uses get/set on it
    Boolean[] flags = new Boolean[size];
    Arrays.fill(flags, Boolean.FALSE);
    List<Boolean> visited = Arrays.asList(flags);

    List<Integer> order = new ArrayList<>(size);
    for (int node = 0; node < size; ++node) {
        dfs(node, visited, order);
    }
    return order;
}
// Post-order depth-first visit: first descend into every node reachable from `node`
// through a GT or GTE edge, then append `node` itself to `order`.
private void dfs(int node, List<Boolean> visited, List<Integer> order) {
    if (!visited.get(node)) {
        visited.set(node, true);
        for (int next = 0; next < size; ++next) {
            Relation edge = graph[node][next];
            if (edge == Relation.GT || edge == Relation.GTE) {
                dfs(next, visited, order);
            }
        }
        order.add(node);
    }
}
/**
 * Determine whether the slots in a predicate come from only one table, i.e. whether the input
 * slots of the two expressions share exactly one distinct qualifier string.
 */
private boolean isTableFilter(int left, int right) {
    Set<String> qualifiers = new HashSet<>();
    for (int position : new int[] {left, right}) {
        for (Slot slot : usedExprs.get(position).getInputSlots()) {
            qualifiers.add(String.join(".", slot.getQualifier()));
        }
    }
    // TODO:
    // isTableFilter(abs(t1.a)#1 = abs(t1.b)#2) will return true
    // isTableFilter(abs(t1.a)#1 = abs(t2.b)#2) will also return true, which is wrong.
    // because expr(e.g. abs(a) #1) qualifiers is empty.
    // We cannot distinguish whether abs(t1.a)#1 = abs(t2.b)#2 is a TableFilter or not.
    // current code may lead to some useful predicates be removed
    return qualifiers.size() == 1;
}
// Returns true when the pair is (slot, literal) in either order and the slot's original column
// is a key column, or a partition column of a partitioned table. Index columns are not
// considered yet (see the disabled code below).
private boolean hasIndexOrPartitionColumn(Expression left, Expression right) {
    final SlotReference slot;
    if (left instanceof SlotReference && right instanceof Literal) {
        slot = (SlotReference) left;
    } else if (left instanceof Literal && right instanceof SlotReference) {
        slot = (SlotReference) right;
    } else {
        return false;
    }
    if (!slot.isColumnFromTable()) {
        return false;
    }
    Column originalColumn = slot.getOriginalColumn().get();
    if (originalColumn.isKey()) {
        return true;
    }
    if (!slot.getOriginalTable().isPresent()) {
        return false;
    }
    TableIf table = slot.getOriginalTable().get();
    /* Indexes are seldom used and are not supported temporarily
    if (tableIf.getType() != TableType.OLAP) {
        return false;
    }
    TableIndexes tableIndexes = tableIf.getTableIndexes();
    for (Index index : tableIndexes.getIndexes()) {
        IndexDef.IndexType type = index.getIndexType();
        if (type == IndexType.NGRAM_BF || type == IndexType.BLOOMFILTER) {
            continue;
        }
        Set<String> columns = new HashSet<>(index.getColumns());
        if (columns.contains(column.getName())) {
            return true;
        }
    }*/
    return table.isPartitionedTable() && table.isPartitionColumn(originalColumn.getName());
}
// Tells whether the relation `type` between left and right is exactly what the path
// left -> mid -> right yields.
private boolean checkDeducible(final Relation[][] graph, int left, int right, int mid, Relation type) {
    return connectInThisPath(graph, left, right, mid) == type;
}
// Returns a copy of `order` without the indexes contained in `equalWithConstant`
// (expressions already known to equal a constant); the relative order is preserved
// and `order` itself is not modified.
private List<Integer> removeExprEqualToConstant(List<Integer> order, Set<Integer> equalWithConstant) {
    List<Integer> orderToInfer = new ArrayList<>(order);
    orderToInfer.removeIf(equalWithConstant::contains);
    return orderToInfer;
}
/**
 * Chooses which GT/GTE relations of the deduced graph are worth emitting and writes them
 * into {@code chosen}.
 * Select predicate:
 * 1. Only single-table pairs (isTableFilter) are considered.
 * 2. Do not select predicates that can be deduced from an intermediate expr lying between the
 *    two ends in the topological order.
 * 3. If hasIndexOrPartitionColumn holds for the pair, reserve the predicate without looking for
 *    an intermediate expr.
 *
 * @param chosen matrix receiving the selected relations
 * @param equalWithConstant indexes of expressions equal to a constant; they are skipped
 */
public void chooseUnequalPredicates(Relation[][] chosen, Set<Integer> equalWithConstant) {
    List<Integer> orderToInfer = removeExprEqualToConstant(topoSort(), equalWithConstant);
    for (int hi = 1; hi < orderToInfer.size(); ++hi) {
        int left = orderToInfer.get(hi);
        for (int lo = 0; lo < hi; ++lo) {
            int right = orderToInfer.get(lo);
            Relation relation = graph[left][right];
            if (relation == Relation.EQ || relation == Relation.UNDEFINED) {
                continue;
            }
            if (!isTableFilter(left, right)) {
                continue;
            }
            boolean deducible = false;
            if (!hasIndexOrPartitionColumn(usedExprs.get(left), usedExprs.get(right))) {
                // look for a node strictly between the two ends that already implies the relation
                for (int between = lo + 1; between < hi && !deducible; ++between) {
                    int mid = orderToInfer.get(between);
                    boolean usableMid = usedExprs.get(mid) instanceof Literal
                            || (isTableFilter(left, mid) && isTableFilter(right, mid));
                    if (usableMid) {
                        deducible = checkDeducible(graph, left, right, mid, relation);
                    }
                }
            }
            if (!deducible) {
                set(chosen, left, right, relation);
            }
        }
    }
}
// Turns every GT / GTE / EQ cell of `chosen` (except the diagonal and literal-vs-literal pairs)
// into a normalized, inferred comparison predicate, in row-major order.
// Side effect: an emitted EQ cell is cleared in both directions so that a=b is not emitted
// again as b=a.
private Set<Expression> generatePredicates(Relation[][] chosen) {
    Set<Expression> generated = new LinkedHashSet<>();
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            if (i == j || isAllLiteral(i, j)) {
                continue;
            }
            Relation relation = chosen[i][j];
            Expression lhs = usedExprs.get(i);
            Expression rhs = usedExprs.get(j);
            try {
                if (relation == Relation.GT) {
                    generated.add(normalize(new GreaterThan(lhs, rhs)));
                } else if (relation == Relation.GTE) {
                    generated.add(normalize(new GreaterThanEqual(lhs, rhs)));
                } else if (relation == Relation.EQ) {
                    generated.add(normalize(new EqualTo(lhs, rhs)));
                    clear(chosen, i, j, Relation.EQ);
                }
            } catch (AnalysisException ignored) {
                // type error, just not generate this predicate, do nothing but continue
            }
        }
    }
    return generated;
}
// Moves the constant operand to the right-hand side: the predicate is commuted only when
// its left side is constant and its right side is not.
private ComparisonPredicate normalizePredicate(ComparisonPredicate expr) {
    boolean constantOnlyOnLeft = expr.left().isConstant() && !expr.right().isConstant();
    if (constantOnlyOnLeft) {
        return expr.commute();
    }
    return expr;
}
// Maps a comparison predicate class to the graph relation it stands for;
// anything other than GreaterThan / GreaterThanEqual / EqualTo maps to UNDEFINED.
private Relation getType(ComparisonPredicate comparisonPredicate) {
    if (comparisonPredicate instanceof GreaterThan) {
        return Relation.GT;
    }
    if (comparisonPredicate instanceof GreaterThanEqual) {
        return Relation.GTE;
    }
    if (comparisonPredicate instanceof EqualTo) {
        return Relation.EQ;
    }
    return Relation.UNDEFINED;
}
// Resets the cell [left][right] to UNDEFINED; for an EQ relation the mirrored cell
// [right][left] is reset as well, because EQ is stored symmetrically.
private void clear(Relation[][] graph, int left, int right, Relation type) {
    final boolean symmetric = type == Relation.EQ;
    graph[left][right] = Relation.UNDEFINED;
    if (symmetric) {
        graph[right][left] = Relation.UNDEFINED;
    }
}
// Writes `type` into the cell [left][right]; an EQ relation is also written into the
// mirrored cell [right][left].
private void set(Relation[][] graph, int left, int right, Relation type) {
    final boolean symmetric = type == Relation.EQ;
    graph[left][right] = type;
    if (symmetric) {
        graph[right][left] = type;
    }
}
// A new edge from hub1 to hub2 has been added to the graph.
// Use this edge to extend the connectivity between the graph nodes:
//   step 1 - refresh the path from every node to hub2, going through hub1;
//   step 2 - refresh the path between any two nodes, going through hub2.
private void expandGraph(Relation[][] graph, int hub1, int hub2) {
    for (int from = 0; from < size; ++from) {
        connect(graph, from, hub2, hub1);
    }
    for (int from = 0; from < size; ++from) {
        for (int to = 0; to < size; ++to) {
            connect(graph, from, to, hub2);
        }
    }
}
/**
 * Decides which of the original input predicates must be retained.
 * An input predicate is kept when
 * 1. its relation was selected in {@code chosen}; the cell is then cleared from {@code chosen}
 *    so that generatePredicates does not emit the same predicate a second time, or
 * 2. it was not selected and cannot be deduced from the selected relations either.
 * An input predicate that is not chosen but can be deduced from the chosen ones is useless
 * and is dropped.
 *
 * @param chosen relations selected by chooseEqualPredicates / chooseUnequalPredicates;
 *               modified in place (cells matching kept inputs are cleared)
 * @return the retained input predicates, in input order, with constants moved to the right side
 */
public Set<Expression> chooseInputPredicates(Relation[][] chosen) {
    boolean[] keep = new boolean[usedPredicates.size()];
    // deduced = transitive closure of the chosen relations; every node equals itself
    Relation[][] deduced = new Relation[size][size];
    for (int i = 0; i < size; ++i) {
        for (int j = 0; j < size; ++j) {
            deduced[i][j] = (i == j) ? Relation.EQ : chosen[i][j];
        }
    }
    deduce(deduced);
    for (int i = 0; i < usedPredicates.size(); ++i) {
        PairAndRelation edge = usedPredicatesPairs.get(i);
        Relation type = edge.relation;
        int left = usedExprPosition.get(edge.pair.first);
        int right = usedExprPosition.get(edge.pair.second);
        if (chosen[left][right] == type) {
            keep[i] = true;
            clear(chosen, left, right, type);
        } else if (deduced[left][right] != type) {
            keep[i] = true;
            // The kept predicate is now part of the output, so record it with its own relation.
            // Recording it as EQ would claim left = right (in both directions) for a GT/GTE
            // predicate, and later inputs could then be dropped as "deducible" from a fact
            // that was never stated.
            set(deduced, left, right, type);
            expandGraph(deduced, left, right);
            if (type == Relation.EQ) {
                // EQ is symmetric, so the reverse edge has to be propagated too
                expandGraph(deduced, right, left);
            }
        }
    }
    Set<Expression> chooseInputs = new LinkedHashSet<>();
    for (int i = 0; i < usedPredicates.size(); ++i) {
        if (!keep[i]) {
            continue;
        }
        ComparisonPredicate input = usedPredicates.get(i);
        chooseInputs.add(normalizePredicate(input).withInferred(input.isInferred()));
    }
    return chooseInputs;
}
/**
 * Selects the EQ relations of the deduced graph that should be emitted.
 * Slot-vs-literal equalities are always selected. Slot-vs-slot equalities are selected only when
 * they are not already covered: a single-table equality is skipped once both ends are tied to a
 * literal or to an earlier selected equality, and a cross-table equality is skipped when either
 * end equals a constant.
 *
 * @param equalWithConstant output parameter; receives the index of every expression that equals
 *                          a literal
 * @return a fresh size x size matrix containing only the selected EQ relations
 */
public Relation[][] chooseEqualPredicates(Set<Integer> equalWithConstant) {
    Relation[][] chosen = new Relation[size][size];
    initGraph(chosen);
    // equalToLiteral[x] is the index x is tied to, or -1 when x is not tied to anything yet
    int[] equalToLiteral = new int[size];
    Arrays.fill(equalToLiteral, -1);
    // save equal predicates like a=b (no literal)
    List<Pair<Integer, Integer>> tableFilters = new ArrayList<>();
    // save equal predicates like t1.a=t2.b (no literal)
    List<Pair<Integer, Integer>> nonTableFilters = new ArrayList<>();
    for (int i = 0; i < size; ++i) {
        for (int j = i + 1; j < size; ++j) {
            if (graph[i][j] != Relation.EQ) {
                continue;
            }
            boolean firstIsLiteral = usedExprs.get(i) instanceof Literal;
            boolean secondIsLiteral = usedExprs.get(j) instanceof Literal;
            if (firstIsLiteral && secondIsLiteral) {
                continue;
            }
            if (!firstIsLiteral && !secondIsLiteral) {
                if (isTableFilter(i, j)) {
                    tableFilters.add(Pair.of(i, j));
                } else {
                    nonTableFilters.add(Pair.of(i, j));
                }
            } else {
                // exactly one side is a literal: always choose the predicate
                set(chosen, i, j, Relation.EQ);
                int literalIndex = firstIsLiteral ? i : j;
                int exprIndex = firstIsLiteral ? j : i;
                equalToLiteral[exprIndex] = literalIndex;
                equalWithConstant.add(exprIndex);
            }
        }
    }
    // a=b a=c a=1 only infer a=1 b=1 c=1, not retain a=b a=c
    for (Pair<Integer, Integer> tableFilter : tableFilters) {
        int left = tableFilter.first;
        int right = tableFilter.second;
        boolean bothAlreadyTied = equalToLiteral[left] != -1 && equalToLiteral[right] != -1;
        if (!bothAlreadyTied) {
            set(chosen, left, right, Relation.EQ);
            equalToLiteral[left] = left;
            equalToLiteral[right] = left;
        }
    }
    for (Pair<Integer, Integer> nonTableFilter : nonTableFilters) {
        int left = nonTableFilter.first;
        int right = nonTableFilter.second;
        boolean touchesConstant = equalWithConstant.contains(left) || equalWithConstant.contains(right);
        if (!touchesConstant) {
            set(chosen, left, right, Relation.EQ);
        }
    }
    return chosen;
}
// Prepares a generated predicate for output: constant moved to the right side,
// comparison type coercion applied, and the result flagged as inferred.
private Expression normalize(ComparisonPredicate cmp) {
    ComparisonPredicate oriented = normalizePredicate(cmp);
    return TypeCoercionUtils.processComparisonPredicate(oriented).withInferred(true);
}
// True when the used expressions at both indexes are literals.
private boolean isAllLiteral(int i, int j) {
    return usedExprs.get(i) instanceof Literal && usedExprs.get(j) instanceof Literal;
}
/**
 * For test: exposes the internal relation matrix itself (not a copy), so changes made by the
 * caller are visible to this graph.
 */
public Relation[][] getGraph() {
return graph;
}
}
/**
 * Infers new comparison predicates from {@code inputs} and removes the useless ones.
 * The result contains, in this order: the retained input predicates, the newly generated
 * predicates, and the inputs that could not be used for inference.
 * {@code inputs} itself is returned when it holds fewer than two predicates or when none of
 * them is usable for inference.
 *
 * @param inputs predicates to process
 * @return the rewritten predicate set
 */
public static Set<? extends Expression> inferUnequalPredicates(Set<Expression> inputs) {
    if (inputs.size() >= 2) {
        InferenceGraph inferGraph = new InferenceGraph(inputs);
        if (!inferGraph.usedExprs.isEmpty()) {
            inferGraph.deduce(inferGraph.graph);
            Set<Integer> equalWithConstant = new HashSet<>();
            InferenceGraph.Relation[][] chosen = inferGraph.chooseEqualPredicates(equalWithConstant);
            inferGraph.chooseUnequalPredicates(chosen, equalWithConstant);
            Set<Expression> result = inferGraph.chooseInputPredicates(chosen);
            result.addAll(inferGraph.generatePredicates(chosen));
            result.addAll(inferGraph.otherPredicates);
            return result;
        }
    }
    return inputs;
}
/**
 * Deduce predicates and generate all predicates without choosing: every GT / GTE / EQ relation
 * of the deduced graph is emitted, followed by the inputs that could not be used for inference.
 * {@code inputs} itself is returned when it holds fewer than two predicates or when none of
 * them is usable for inference.
 *
 * @param inputs predicates to process
 * @return all derivable predicates plus the unused inputs
 */
public static Set<? extends Expression> inferAllPredicates(Set<Expression> inputs) {
    if (inputs.size() >= 2) {
        InferenceGraph inferGraph = new InferenceGraph(inputs);
        if (!inferGraph.usedExprs.isEmpty()) {
            inferGraph.deduce(inferGraph.graph);
            Set<Expression> result = new LinkedHashSet<>(inferGraph.generatePredicates(inferGraph.graph));
            result.addAll(inferGraph.otherPredicates);
            return result;
        }
    }
    return inputs;
}
}