ProjectOtherJoinConditionForNestedLoopJoin.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.thrift.TRuntimeFilterType;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * Extracts single-side sub-expressions from the other join conjuncts of a join that has no hash
 * conjuncts (i.e. one that will run as a nested loop join).
 *
 * <p>Each maximal sub-expression that references slots of only one child is replaced by a new
 * alias slot. That alias is computed by a project placed on top of that child. Presumably this
 * avoids re-evaluating the expression for every row pair inside the join condition.
 *
 * <pre>
 * join (l_orderkey > n_nationkey + n_regionkey)
 * +----scan(lineItem)
 * +----scan(nation)
 * =>
 * join(l_orderkey > x)
 * +----scan(lineItem)
 * +----project(n_nationkey + n_regionkey as x)
 *      +----scan(nation)
 * </pre>
 *
 * <p>The rule applies only when all of the following hold:
 * <ul>
 *   <li>the join has no hash conjuncts;</li>
 *   <li>the join is not a mark join;</li>
 *   <li>the join has at least one other conjunct;</li>
 *   <li>the current session does not allow BITMAP runtime filters (or there is no session).</li>
 * </ul>
 */
public class ProjectOtherJoinConditionForNestedLoopJoin extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalJoin()
                .when(join -> join.getHashJoinConjuncts().isEmpty()
                        && !join.isMarkJoin()
                        && !join.getOtherJoinConjuncts().isEmpty())
                .whenNot(join -> ConnectContext.get() != null
                        && ConnectContext.get().getSessionVariable()
                        .allowedRuntimeFilterType(TRuntimeFilterType.BITMAP))
                .then(join -> {
                    ReplacerContext ctx = new ReplacerContext(
                            join.child(0).getOutputSet(), join.child(1).getOutputSet());
                    List<Expression> newOtherConjuncts = new ArrayList<>();
                    for (Expression conjunct : join.getOtherJoinConjuncts()) {
                        newOtherConjuncts.add(conjunct.accept(AliasReplacer.INSTANCE, ctx));
                    }
                    if (ctx.leftAlias.isEmpty() && ctx.rightAlias.isEmpty()) {
                        // Nothing was extracted. NOTE(review): null is presumably treated by the
                        // rule framework as "no rewrite"; confirm.
                        return null;
                    }
                    Plan left = projectAliases(join.left(), ctx.leftAlias);
                    Plan right = projectAliases(join.right(), ctx.rightAlias);
                    return join.withJoinConjuncts(join.getHashJoinConjuncts(),
                                    newOtherConjuncts, join.getJoinReorderContext())
                            .withChildren(ImmutableList.of(left, right));
                }
                ).toRule(RuleType.PROJECT_OTHER_JOIN_CONDITION);
    }

    /**
     * Wraps {@code child} in a project that outputs all of the child's slots followed by
     * {@code aliases}.
     *
     * @param child the join child to extend
     * @param aliases the aliases extracted for this side; may be empty
     * @return {@code child} itself when {@code aliases} is empty, otherwise the new project
     */
    private static Plan projectAliases(Plan child, Set<Alias> aliases) {
        if (aliases.isEmpty()) {
            return child;
        }
        List<NamedExpression> projects = Lists.newArrayList(child.getOutput());
        projects.addAll(aliases);
        return new LogicalProject<>(projects, child);
    }

    /** State shared by {@link AliasReplacer} while rewriting the other conjuncts of one join. */
    private static class ReplacerContext {
        // Extracted expression -> its alias. This lets identical sub-expressions share one alias.
        final HashMap<Expression, Alias> aliasMap = new HashMap<>();
        // Output slots of the left / right join child.
        final Set<Slot> leftSlots;
        final Set<Slot> rightSlots;
        // Aliases that must be projected on top of the left / right child.
        final Set<Alias> leftAlias = new HashSet<>();
        final Set<Alias> rightAlias = new HashSet<>();

        public ReplacerContext(Set<Slot> leftSlots, Set<Slot> rightSlots) {
            this.leftSlots = leftSlots;
            this.rightSlots = rightSlots;
        }
    }

    /**
     * Replaces each maximal sub-expression whose input slots all come from one join child with
     * the slot of an alias. The alias is recorded in the context for that child.
     */
    private static class AliasReplacer extends DefaultExpressionRewriter<ReplacerContext> {
        /** Stateless shared instance; all per-join state lives in {@link ReplacerContext}. */
        public static final AliasReplacer INSTANCE = new AliasReplacer();

        private AliasReplacer() {
        }

        @Override
        public Expression visit(Expression expression, ReplacerContext ctx) {
            Set<Slot> input = expression.getInputSlots();
            // Bare slots and expressions without input slots are left in place.
            if (input.isEmpty() || expression instanceof Slot) {
                return expression;
            }
            // The left side is checked first, as in the original rule.
            if (ctx.leftSlots.containsAll(input)) {
                return replaceWithAlias(expression, ctx, ctx.leftAlias);
            }
            if (ctx.rightSlots.containsAll(input)) {
                return replaceWithAlias(expression, ctx, ctx.rightAlias);
            }
            // The expression mixes slots from both sides: descend and try its children.
            return super.visit(expression, ctx);
        }

        /**
         * Looks up or creates the alias for {@code expression}, records it for one side, and
         * returns the alias's slot to use in the join condition.
         */
        private static Expression replaceWithAlias(Expression expression, ReplacerContext ctx,
                Set<Alias> sideAliases) {
            Alias alias = ctx.aliasMap.computeIfAbsent(expression, o -> new Alias(o));
            sideAliases.add(alias);
            return alias.toSlot();
        }
    }
}