// PullUpJoinFromUnionAll.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.catalog.constraint.TableIdentifier;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.EqualTo;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalCatalogRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import org.apache.doris.nereids.util.ExpressionUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import javax.annotation.Nullable;
/**
* Pull up a join from under a UNION ALL (each union child may have an optional project):
*              Union
*             /     \
*       project     project
*     (optional)   (optional)
*          |           |
*        Join         Join
*       /    \       /    \
*     t1      t2   t1      t3      (t1 is the common side; t2 and t3 are the other sides)
* =====>
*            project
*               |
*             Join
*            /    \
*        Union     t1
*        /    \
*  project     project
* (optional)  (optional)
*     |           |
*     t2          t3
*/
public class PullUpJoinFromUnionAll extends OneRewriteRuleFactory {
/**
 * Builds the rewrite rule: matches a non-DISTINCT union without constant expressions and, when
 * every child is a (projected) join sharing one logically-equal side, rewrites it into
 * project(join(union(other sides), common side)). Returns null from the action whenever any
 * precondition fails, which leaves the plan untouched.
 */
@Override
public Rule build() {
    return logicalUnion()
            .when(union -> union.getQualifier() != Qualifier.DISTINCT
                    && union.getConstantExprsList().isEmpty())
            .then(union -> {
                HashMap<Plan, List<Pair<LogicalJoin<?, ?>, Plan>>> candidates = tryToExtractCommonChild(union);
                if (candidates == null) {
                    return null;
                }
                // A usable common side must be shared by every union child, i.e. its list of
                // (join, common side) pairs has as many entries as the union has children.
                int childCount = union.children().size();
                List<Pair<LogicalJoin<?, ?>, Plan>> joinsAndCommonSides = candidates.values().stream()
                        .filter(users -> users.size() == childCount)
                        .findFirst()
                        .orElse(null);
                if (joinsAndCommonSides == null) {
                    return null;
                }

                // Step 1: the union children outputs must be compatible with the rewrite.
                List<List<NamedExpression>> otherOutputsList = new ArrayList<>();
                List<Pair<Boolean, ExpressionOrIndex>> upperProjectExpressionOrIndex = new ArrayList<>();
                boolean outputsOk = checkUnionChildrenOutput(union, joinsAndCommonSides, otherOutputsList,
                        upperProjectExpressionOrIndex);
                if (!outputsOk) {
                    return null;
                }

                // Step 2: the join conditions of all children must have the same shape.
                List<Map<SlotReference, List<SlotReference>>> commonSlotToOtherSlotMaps = new ArrayList<>();
                Set<SlotReference> joinCommonSlots = new LinkedHashSet<>();
                boolean conditionsOk = checkJoinCondition(joinsAndCommonSides, commonSlotToOtherSlotMaps,
                        joinCommonSlots);
                if (!conditionsOk) {
                    return null;
                }

                // Step 3: assemble union -> join -> project.
                Map<SlotReference, List<Integer>> commonSlotToProjectsIndex = new HashMap<>();
                LogicalUnion pulledUnion = constructNewUnion(joinsAndCommonSides, otherOutputsList,
                        commonSlotToOtherSlotMaps, joinCommonSlots, commonSlotToProjectsIndex);
                LogicalJoin<LogicalUnion, Plan> pulledJoin = constructNewJoin(pulledUnion,
                        commonSlotToProjectsIndex, joinsAndCommonSides);
                return constructNewProject(union, pulledJoin, upperProjectExpressionOrIndex);
            }).toRule(RuleType.PULL_UP_JOIN_FROM_UNION_ALL);
}
/**
 * Builds the project placed on top of the new join so that the rewritten plan exposes the same
 * output columns (same expr ids and names) as the original union.
 *
 * @param originUnion the union being rewritten; its output defines the expr ids and names to keep
 * @param newJoin the join built over the new union (left child) and the common side
 * @param upperProjectExpressionOrIndex one entry per original output column: either an expression
 *        over the common side, or an index into the new union's output
 * @return the new project, or null when the number of entries does not match the union output size
 */
private LogicalProject<Plan> constructNewProject(LogicalUnion originUnion, LogicalJoin<LogicalUnion, Plan> newJoin,
        List<Pair<Boolean, ExpressionOrIndex>> upperProjectExpressionOrIndex) {
    List<Slot> unionOutput = originUnion.getOutput();
    List<Slot> pulledUnionOutput = newJoin.left().getOutput();
    int columnCount = upperProjectExpressionOrIndex.size();
    if (unionOutput.size() != columnCount) {
        return null;
    }
    List<NamedExpression> aliases = new ArrayList<>();
    for (int col = 0; col < columnCount; ++col) {
        Pair<Boolean, ExpressionOrIndex> source = upperProjectExpressionOrIndex.get(col);
        Slot originSlot = unionOutput.get(col);
        Expression aliased;
        if (source.first) {
            // Column computed from the common side: reuse the recorded expression as-is.
            aliased = source.second.exprFromCommonSide;
        } else {
            // Column coming from the other side: it is now a slot of the new union's output.
            aliased = pulledUnionOutput.get(source.second.indexOfNewUnionOutput);
        }
        aliases.add(new Alias(originSlot.getExprId(), aliased, originSlot.getName()));
    }
    return new LogicalProject<>(aliases, newJoin);
}
/**
 * Builds the join between the new union and the common side. The first original join is used as
 * the template (join type, reorder context) and its common side becomes the right child.
 *
 * @param union the new union over the other sides
 * @param commonSlotToProjectsIndex for each common-side join slot, the positions in the union
 *        output holding the slots it must be equal to
 * @param commonChild the (join, common side) pairs of the original union children
 * @return the new join with one equality per (union output slot, common slot) pair and no other
 *         conjuncts
 */
private LogicalJoin<LogicalUnion, Plan> constructNewJoin(LogicalUnion union,
        Map<SlotReference, List<Integer>> commonSlotToProjectsIndex,
        List<Pair<LogicalJoin<?, ?>, Plan>> commonChild) {
    Pair<LogicalJoin<?, ?>, Plan> template = commonChild.iterator().next();
    LogicalJoin<?, ?> templateJoin = template.first;
    List<Slot> unionSlots = union.getOutput();
    List<Expression> equalities = new ArrayList<>();
    commonSlotToProjectsIndex.forEach((commonSlot, positions) -> {
        for (Integer position : positions) {
            equalities.add(new EqualTo(unionSlots.get(position), commonSlot));
        }
    });
    return (LogicalJoin<LogicalUnion, Plan>) templateJoin
            .withJoinConjuncts(equalities, ImmutableList.of(), templateJoin.getJoinReorderContext())
            .withChildren(union, template.second);
}
/**
 * Builds the new UNION ALL whose children are projects over the non-common side of each join.
 * Each project outputs the other-side expressions collected earlier, followed by the other-side
 * slots used in the join condition (appended to the lists held in otherOutputsList).
 *
 * @param commonSlotToProjectsIndex output parameter: key is a common-side slot of the join
 *        condition, value is the list of positions (in the new union output) of the other-side
 *        slots paired with it; used later to build the condition of the new join. It is filled
 *        from the first child only.
 */
private LogicalUnion constructNewUnion(List<Pair<LogicalJoin<?, ?>, Plan>> joinsAndCommonSides,
        List<List<NamedExpression>> otherOutputsList,
        List<Map<SlotReference, List<SlotReference>>> commonSlotToOtherSlotMaps,
        Set<SlotReference> joinCommonSlots, Map<SlotReference, List<Integer>> commonSlotToProjectsIndex) {
    // 1. one project per original join, placed over the side that is not shared
    List<Plan> unionChildren = new ArrayList<>();
    for (int childIdx = 0; childIdx < joinsAndCommonSides.size(); ++childIdx) {
        Pair<LogicalJoin<?, ?>, Plan> joinAndCommon = joinsAndCommonSides.get(childIdx);
        LogicalJoin<?, ?> join = joinAndCommon.first;
        Plan otherSide = joinAndCommon.second == join.left() ? join.right() : join.left();
        List<NamedExpression> projects = otherOutputsList.get(childIdx);
        // The other-side join slots must be output as well so the new join can reference them.
        // TODO: may eliminate repeated output slots:
        // e.g. select t2.a from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
        //      select t3.a from test_like1 t1 join test_like3 t3 on t1.a=t3.a;
        // the new union child outputs t2.a/t3.a twice: once as output, once for the join condition.
        Map<SlotReference, List<SlotReference>> otherSlotsByCommonSlot = commonSlotToOtherSlotMaps.get(childIdx);
        for (SlotReference commonSlot : joinCommonSlots) {
            for (SlotReference otherSlot : otherSlotsByCommonSlot.get(commonSlot)) {
                if (childIdx == 0) {
                    int position = projects.size();
                    commonSlotToProjectsIndex.computeIfAbsent(commonSlot, k -> new ArrayList<>()).add(position);
                }
                projects.add(otherSlot);
            }
        }
        LogicalProject<Plan> otherSideProject = new LogicalProject<>(projects, otherSide);
        unionChildren.add(otherSideProject);
    }

    // 2. construct the new union and bind its children outputs
    LogicalUnion pulledUnion = new LogicalUnion(Qualifier.ALL, unionChildren);
    ImmutableList.Builder<List<SlotReference>> childrenOutputs = ImmutableList.builder();
    for (Plan unionChild : unionChildren) {
        ImmutableList.Builder<SlotReference> childOutput = ImmutableList.builder();
        for (Slot slot : unionChild.getOutput()) {
            childOutput.add((SlotReference) slot);
        }
        childrenOutputs.add(childOutput.build());
    }
    pulledUnion = (LogicalUnion) pulledUnion.withChildrenAndTheirOutputs(unionChildren, childrenOutputs.build());
    return pulledUnion.withNewOutputs(pulledUnion.buildNewOutputs());
}
/**
 * Checks whether the join conditions of all union children allow the rewrite: every hash conjunct
 * must be a slot-to-slot equality, and all joins must use the same common-side slots with the same
 * number of other-side slots per common slot.
 *
 * <p>Examples that are rejected:
 * <pre>
 * SQL1: select t2.a+1,2 from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
 *       select t3.a+1,3 from test_like1 t1 join test_like3 t3 on t1.a=t3.a and t1.b=t3.b;
 * SQL2: select t2.a+1,2 from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
 *       select t3.a+1,3 from test_like1 t1 join test_like3 t3 on t1.b=t3.a;
 * SQL3: select t2.a+1,2 from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
 *       select t3.a+1,3 from test_like1 t1 join test_like3 t3 on t1.a=t3.a and t1.a=t3.b;
 * </pre>
 *
 * @param joinsAndCommonSides the (join, common side) pair of every union child
 * @param commonSlotToOtherSlotMaps output parameter, one map per join: key is a common-side slot
 *        (expressed with the first join's slots), value is the list of other-side slots it is
 *        compared with. For
 *        <pre>
 *        SELECT t2.a + 1, 2 FROM test_like1 t1
 *        JOIN test_like2 t2 ON t1.a = t2.a AND t1.a = t2.c AND t1.b = t2.b
 *        UNION ALL
 *        SELECT t3.a + 1, 3 FROM test_like1 t1
 *        JOIN test_like3 t3 ON t1.a = t3.a AND t1.a = t3.d AND t1.b = t3.b;
 *        </pre>
 *        it becomes {{t1.a: t2.a, t2.c; t1.b: t2.b}, {t1.a: t3.a, t3.d; t1.b: t3.b}}.
 * @param joinCommonSlots output parameter collecting the common-side join slots of the first join
 * @return true when every join passes the checks
 */
private boolean checkJoinCondition(List<Pair<LogicalJoin<?, ?>, Plan>> joinsAndCommonSides,
        List<Map<SlotReference, List<SlotReference>>> commonSlotToOtherSlotMaps,
        Set<SlotReference> joinCommonSlots) {
    Map<Slot, Slot> toFirstCommonSlot = buildCommonJoinMap(joinsAndCommonSides);
    Map<SlotReference, List<SlotReference>> firstConditions = null;
    for (int joinIdx = 0; joinIdx < joinsAndCommonSides.size(); ++joinIdx) {
        boolean isFirst = joinIdx == 0;
        Pair<LogicalJoin<?, ?>, Plan> joinAndCommon = joinsAndCommonSides.get(joinIdx);
        Plan commonSide = joinAndCommon.second;
        Map<SlotReference, List<SlotReference>> conditions = new HashMap<>();
        for (Expression conjunct : joinAndCommon.first.getHashJoinConjuncts()) {
            if (!(conjunct instanceof EqualTo)) {
                return false;
            }
            EqualTo equality = (EqualTo) conjunct;
            boolean bothSlots = equality.left() instanceof SlotReference
                    && equality.right() instanceof SlotReference;
            if (!bothSlots) {
                return false;
            }
            boolean leftIsCommon = commonSide.getOutputSet().contains(equality.left());
            SlotReference commonSlot = (SlotReference) (leftIsCommon ? equality.left() : equality.right());
            SlotReference otherSlot = (SlotReference) (leftIsCommon ? equality.right() : equality.left());
            SlotReference key;
            if (isFirst) {
                joinCommonSlots.add(commonSlot);
                key = commonSlot;
            } else {
                // express the common slot with the first join's slots so the maps are comparable
                key = (SlotReference) ExpressionUtils.replace(commonSlot, toFirstCommonSlot);
            }
            conditions.computeIfAbsent(key, k -> new ArrayList<>()).add(otherSlot);
        }
        if (isFirst) {
            firstConditions = conditions;
        } else {
            // reject SQL1: different number of distinct common slots
            if (conditions.size() != firstConditions.size()) {
                return false;
            }
            // reject SQL2: different common slots
            if (!conditions.keySet().equals(firstConditions.keySet())) {
                return false;
            }
            // reject SQL3: same common slot compared with a different number of other slots
            for (Map.Entry<SlotReference, List<SlotReference>> expected : firstConditions.entrySet()) {
                if (conditions.get(expected.getKey()).size() != expected.getValue().size()) {
                    return false;
                }
            }
        }
        commonSlotToOtherSlotMaps.add(conditions);
    }
    return true;
}
/**
 * Maps every output slot of each join's common side to the slot at the same position in the
 * first join's common side (slots of the first common side map to themselves).
 */
private Map<Slot, Slot> buildCommonJoinMap(List<Pair<LogicalJoin<?, ?>, Plan>> commonChild) {
    Map<Slot, Slot> toFirstCommonSlot = new HashMap<>();
    if (commonChild.isEmpty()) {
        return toFirstCommonSlot;
    }
    List<Slot> firstCommonOutput = new ArrayList<>(commonChild.get(0).second.getOutput());
    for (Slot slot : firstCommonOutput) {
        toFirstCommonSlot.put(slot, slot);
    }
    for (Pair<LogicalJoin<?, ?>, Plan> joinAndCommon : commonChild.subList(1, commonChild.size())) {
        List<Slot> commonOutput = joinAndCommon.second.getOutput();
        for (int pos = 0; pos < commonOutput.size(); ++pos) {
            toFirstCommonSlot.put(commonOutput.get(pos), firstCommonOutput.get(pos));
        }
    }
    return toFirstCommonSlot;
}
private class ExpressionOrIndex {
Expression exprFromCommonSide = null;
int indexOfNewUnionOutput = -1;
private ExpressionOrIndex(Expression expr) {
exprFromCommonSide = expr;
}
private ExpressionOrIndex(int index) {
indexOfNewUnionOutput = index;
}
}
/**
* Checks whether the outputs of the union children allow the join to be pulled up, and collects the
* information needed to build the new union children and the top-level project.
* In the union child output, the number of outputs from the common side must be the same in each child output,
* and the outputs from the common side must be isomorphic (both a+1) and have the same index in the union output.
* In the union child output, the number of outputs from the non-common side must also be the same,
* but they do not need to be isomorphic.
* These are SQL statements that cannot be transformed:
* SQL1: select t2.a+t1.a from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
* select t3.a+1 from test_like1 t1 join test_like3 t3 on t1.a=t3.a;
* SQL2: select t2.a from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
* select t1.a from test_like1 t1 join test_like3 t3 on t1.a=t3.a;
* SQL3: select t1.a+1 from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
* select t1.a+2 from test_like1 t1 join test_like3 t3 on t1.a=t3.a;
* SQL4: select t1.a from test_like1 t1 join test_like2 t2 on t1.a=t2.a union ALL
* select 1 from test_like1 t1 join test_like3 t3 on t1.a=t3.a;
*
* @param union the union being rewritten
* @param joinsAndCommonSides the (join, common side) pair of every union child, in child order
* @param otherOutputsList output parameter that stores the outputs of the other side.
* The length of each element in otherOutputsList must be the same.
* The i-th element represents the output of the other side in the i-th child of the union.
* This parameter is used to create child nodes of a new Union
* in the constructNewUnion function.
*
* @param upperProjectExpressionOrIndex Output parameter used in the constructNewProject function to create
* the top-level project. This parameter records the output column order of
* the original union and determines, based on the new join output, the columns
* or expressions to output in the upper-level project operator. The size of
* upperProjectExpressionOrIndex must match the output size of
* the original union.
* Each Pair in the List represents an output source:
* - Pair.first (Boolean): Indicates whether the output is from
* the common side (true) or the other side (false).
* - Pair.second (ExpressionOrIndex): When Pair.first is true, it stores
* the common side's output expression. When false, it saves the output index
* of the other side. Since the new union output is not yet constructed
* at this point, only the index is stored.
* The function's check ensures that outputs at the same position in
* union children either come from the common side or from the other side.
* When the final join is constructed, the common side uses the first join's
* common side, so only the first child's outputs need to be processed to
* fill in upperProjectExpressionOrIndex.
* @return true when all children pass the checks; false when the rewrite must not be applied
*/
private boolean checkUnionChildrenOutput(LogicalUnion union,
List<Pair<LogicalJoin<?, ?>, Plan>> joinsAndCommonSides,
List<List<NamedExpression>> otherOutputsList,
List<Pair<Boolean, ExpressionOrIndex>> upperProjectExpressionOrIndex) {
List<List<SlotReference>> regularChildrenOutputs = union.getRegularChildrenOutputs();
int arity = union.arity();
if (arity == 0) {
return false;
}
// fromCommonSide is used to ensure that the outputs at the same position in the union children
// must all come from the common side or from the other side
boolean[] fromCommonSide = new boolean[regularChildrenOutputs.get(0).size()];
// checkSameExpr and commonJoinSlotMap are used to ensure that Expr from the common side have the same structure
Expression[] checkSameExpr = new Expression[regularChildrenOutputs.get(0).size()];
Map<Slot, Slot> commonJoinSlotMap = buildCommonJoinMap(joinsAndCommonSides);
// i iterates over union children, j over the output columns of child i.
// Child 0 fills fromCommonSide / checkSameExpr / upperProjectExpressionOrIndex;
// every later child is only validated against what child 0 recorded.
for (int i = 0; i < arity; ++i) {
List<SlotReference> regularChildrenOutput = regularChildrenOutputs.get(i);
Plan child = union.child(i);
// expressions of child i that must be produced by the non-common side, in column order
List<NamedExpression> otherOutputs = new ArrayList<>();
for (int j = 0; j < regularChildrenOutput.size(); ++j) {
SlotReference slot = regularChildrenOutput.get(j);
if (child instanceof LogicalProject) {
LogicalProject<Plan> project = (LogicalProject<Plan>) child;
// Look up the project expression that produces this output slot.
// NOTE(review): assumes slot is always found in project.getOutput(); indexOf would
// return -1 otherwise and the following get(index) would throw — confirm.
int index = project.getOutput().indexOf(slot);
NamedExpression expr = project.getOutputs().get(index);
Slot insideSlot;
Expression insideExpr;
Set<Slot> inputSlots = expr.getInputSlots();
// reject SQL1
// (an expression reading more than one slot is not supported)
if (inputSlots.size() > 1) {
return false;
} else if (inputSlots.size() == 1) {
if (expr instanceof Alias) {
// insideExpr is the aliased expression; insideSlot is the single slot it reads
insideSlot = inputSlots.iterator().next();
insideExpr = expr.child(0);
} else if (expr instanceof SlotReference) {
insideSlot = (Slot) expr;
insideExpr = expr;
} else {
return false;
}
Plan commonSide = joinsAndCommonSides.get(i).second;
if (i == 0) {
if (commonSide.getOutputSet().contains(insideSlot)) {
fromCommonSide[j] = true;
checkSameExpr[j] = insideExpr;
upperProjectExpressionOrIndex.add(Pair.of(true, new ExpressionOrIndex(insideExpr)));
} else {
fromCommonSide[j] = false;
// record the position first: otherOutputs.size() is the index expr gets on the next line
upperProjectExpressionOrIndex.add(Pair.of(false, new ExpressionOrIndex(
otherOutputs.size())));
otherOutputs.add(expr);
}
} else {
// reject SQL2
// (column j must come from the same kind of side as in the first child)
if (commonSide.getOutputSet().contains(insideSlot) != fromCommonSide[j]) {
return false;
}
// reject SQL3
// (after mapping slots to the first common side, the expression must equal the first child's)
if (commonSide.getOutputSet().contains(insideSlot)) {
Expression sameExpr = ExpressionUtils.replace(insideExpr, commonJoinSlotMap);
if (!sameExpr.equals(checkSameExpr[j])) {
return false;
}
} else {
otherOutputs.add(expr);
}
}
} else if (expr.getInputSlots().isEmpty()) {
// Constants must come from other side
if (i == 0) {
fromCommonSide[j] = false;
upperProjectExpressionOrIndex.add(Pair.of(false, new ExpressionOrIndex(
otherOutputs.size())));
} else {
// reject SQL4
if (fromCommonSide[j]) {
return false;
}
}
otherOutputs.add(expr);
}
} else if (child instanceof LogicalJoin) {
// No project above the join: the union reads the join's output slot directly,
// so the slot itself plays the role of both insideSlot and insideExpr.
Plan commonSide = joinsAndCommonSides.get(i).second;
if (i == 0) {
if (commonSide.getOutputSet().contains(slot)) {
fromCommonSide[j] = true;
checkSameExpr[j] = slot;
upperProjectExpressionOrIndex.add(Pair.of(true, new ExpressionOrIndex(slot)));
} else {
fromCommonSide[j] = false;
upperProjectExpressionOrIndex.add(Pair.of(false,
new ExpressionOrIndex(otherOutputs.size())));
otherOutputs.add(slot);
}
} else {
// reject SQL2
if (commonSide.getOutputSet().contains(slot) != fromCommonSide[j]) {
return false;
}
// reject SQL3
if (commonSide.getOutputSet().contains(slot)) {
Expression sameExpr = ExpressionUtils.replace(slot, commonJoinSlotMap);
if (!sameExpr.equals(checkSameExpr[j])) {
return false;
}
} else {
otherOutputs.add(slot);
}
}
}
}
otherOutputsList.add(otherOutputs);
}
return true;
}
/**
 * Groups the joins under a union by a side they share.
 *
 * <p>Every union child must be a supported join (optionally under a project). Each join is
 * compared against the candidate subtrees recorded so far; on the first candidate that is
 * logically equal to its left or right child, the join is appended to that candidate's list.
 * When no candidate matches, both of the join's children become new candidates.
 *
 * <p>For example, given
 * <pre>
 * Union
 * ├─ Join(A, B)
 * ├─ Join(A, C)
 * └─ Join(D, B)
 * </pre>
 * the returned map contains
 * <pre>
 * A -> [(Join(A,B), A), (Join(A,C), A)]
 * B -> [(Join(A,B), B), (Join(D,B), B)]
 * </pre>
 * meaning both A and B are potential common subtrees.
 *
 * @param union the union to analyze
 * @return candidate common subtrees with the joins using them, or null if some child is not a
 *         supported join
 */
private @Nullable HashMap<Plan, List<Pair<LogicalJoin<?, ?>, Plan>>> tryToExtractCommonChild(LogicalUnion union) {
    HashMap<Plan, List<Pair<LogicalJoin<?, ?>, Plan>>> candidates = new HashMap<>();
    for (Plan unionChild : union.children()) {
        LogicalJoin<? extends Plan, ? extends Plan> join = tryToGetJoin(unionChild);
        if (join == null) {
            return null;
        }
        boolean matched = false;
        for (Map.Entry<Plan, List<Pair<LogicalJoin<?, ?>, Plan>>> candidate : candidates.entrySet()) {
            LogicalPlanComparator comparator = new LogicalPlanComparator();
            Plan matchedSide = null;
            if (comparator.isLogicalEqual(join.left(), candidate.getKey())) {
                matchedSide = join.left();
            } else if (comparator.isLogicalEqual(join.right(), candidate.getKey())) {
                matchedSide = join.right();
            }
            if (matchedSide != null) {
                candidate.getValue().add(Pair.of(join, matchedSide));
                matched = true;
                break;
            }
        }
        if (!matched) {
            // Neither side is known yet: both become candidates for later joins.
            candidates.put(join.left(), Lists.newArrayList(Pair.of(join, join.left())));
            candidates.put(join.right(), Lists.newArrayList(Pair.of(join, join.right())));
        }
    }
    return candidates;
}
/**
 * Returns the join under a union child, accepting only the shapes project(join) or join, and only
 * inner, non-mark joins without other (non-hash) join conjuncts; returns null otherwise.
 */
private @Nullable LogicalJoin<?, ?> tryToGetJoin(Plan child) {
    Plan candidate = child instanceof LogicalProject ? child.child(0) : child;
    if (!(candidate instanceof LogicalJoin)) {
        return null;
    }
    LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) candidate;
    boolean supported = join.getJoinType().isInnerJoin()
            && join.getOtherJoinConjuncts().isEmpty()
            && !join.isMarkJoin();
    return supported ? join : null;
}
/**
 * Compares two plan trees for logical equality, bottom-up. Only trees made of catalog relations,
 * filters and projects are supported. While comparing, the output slots of plan1 nodes are mapped
 * to the output slots of the matching plan2 nodes, so expressions higher in plan1 can be
 * translated into plan2's slots before being compared.
 */
class LogicalPlanComparator {
    /** Slot mapping accumulated from already-matched nodes (plan1 slot -> plan2 slot). */
    private final HashMap<Expression, Expression> plan1ToPlan2 = new HashMap<>();

    /** Returns true when both trees have the same shape and every node pair compares equal. */
    public boolean isLogicalEqual(Plan plan1, Plan plan2) {
        int childCount = plan1.children().size();
        if (childCount != plan2.children().size()) {
            return false;
        }
        // Children first, so their slot mappings exist when the parents are compared.
        for (int idx = 0; idx < childCount; idx++) {
            if (!isLogicalEqual(plan1.child(idx), plan2.child(idx))) {
                return false;
            }
        }
        return !isNotSupported(plan1) && !isNotSupported(plan2) && comparePlan(plan1, plan2);
    }

    /** Returns true for any node other than a filter, a catalog relation or a project. */
    boolean isNotSupported(Plan plan) {
        boolean supported = plan instanceof LogicalFilter
                || plan instanceof LogicalCatalogRelation
                || plan instanceof LogicalProject;
        return !supported;
    }

    /**
     * Compares two single nodes (children are assumed to be already compared). On success, records
     * the position-wise mapping of plan1's output slots to plan2's output slots.
     */
    boolean comparePlan(Plan plan1, Plan plan2) {
        boolean matches;
        if (plan1 instanceof LogicalCatalogRelation && plan2 instanceof LogicalCatalogRelation) {
            TableIdentifier table1 = new TableIdentifier(((LogicalCatalogRelation) plan1).getTable());
            TableIdentifier table2 = new TableIdentifier(((LogicalCatalogRelation) plan2).getTable());
            matches = table1.equals(table2);
        } else if (plan1 instanceof LogicalProject && plan2 instanceof LogicalProject) {
            matches = sameProjections((LogicalProject<?>) plan1, (LogicalProject<?>) plan2);
        } else if (plan1 instanceof LogicalFilter && plan2 instanceof LogicalFilter) {
            matches = sameConjuncts((LogicalFilter<?>) plan1, (LogicalFilter<?>) plan2);
        } else {
            matches = false;
        }
        if (!matches) {
            return false;
        }
        List<Slot> output1 = plan1.getOutput();
        List<Slot> output2 = plan2.getOutput();
        for (int idx = 0; idx < output1.size(); idx++) {
            plan1ToPlan2.put(output1.get(idx), output2.get(idx));
        }
        return true;
    }

    /** Position-wise comparison of two projects' expressions, looking through aliases. */
    private boolean sameProjections(LogicalProject<?> project1, LogicalProject<?> project2) {
        int outputSize = project2.getOutput().size();
        if (project1.getOutput().size() != outputSize) {
            return false;
        }
        for (int idx = 0; idx < outputSize; idx++) {
            Expression expr1 = project1.getProjects().get(idx);
            Expression expr2 = project2.getProjects().get(idx);
            if (expr1 instanceof Alias) {
                if (!(expr2 instanceof Alias)) {
                    return false;
                }
                expr1 = expr1.child(0);
                expr2 = expr2.child(0);
            }
            if (!translate(expr1).equals(expr2)) {
                return false;
            }
        }
        return true;
    }

    /** Set comparison of two filters' conjuncts after translating the first into plan2 slots. */
    private boolean sameConjuncts(LogicalFilter<?> filter1, LogicalFilter<?> filter2) {
        Set<Expression> translated = new HashSet<>();
        for (Expression conjunct : filter1.getConjuncts()) {
            translated.add(translate(conjunct));
        }
        return translated.equals(filter2.getConjuncts());
    }

    /** Rewrites an expression of plan1 using the slot mapping collected so far. */
    private Expression translate(Expression expr) {
        return expr.rewriteUp(e -> plan1ToPlan2.getOrDefault(e, e));
    }
}
}