PushDownAliasThroughJoin.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.MarkJoinSlotReference;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Pushes aliases whose child is a Slot from a project down through the join beneath it.
 *
 * <p>The rule matches {@code Project(Join)} only when every project expression is either a Slot
 * or an Alias wrapping a Slot (a MarkJoinSlotReference is rejected in both positions) and at
 * least one project expression is an Alias. The aliases are re-created in new projects placed
 * directly on the join children, the join conjuncts are rewritten to reference the alias slots,
 * and the top project is reduced to plain slots.
 */
public class PushDownAliasThroughJoin extends OneRewriteRuleFactory {
    @Override
    public Rule build() {
        return logicalProject(logicalJoin())
            .when(project -> project.getProjects().stream().allMatch(expr ->
                    isNonMarkSlot(expr)
                            || (expr instanceof Alias && isNonMarkSlot(((Alias) expr).child()))))
            .when(project -> project.getProjects().stream().anyMatch(expr -> expr instanceof Alias))
            .then(project -> {
                LogicalJoin<? extends Plan, ? extends Plan> join = project.child();

                // Group the aliases by the slot they wrap: { Slot -> List<Alias<Slot>> }
                Map<Expression, List<NamedExpression>> aliasesBySlot = Maps.newHashMap();
                for (NamedExpression projectExpr : project.getProjects()) {
                    if (projectExpr instanceof Alias && ((Alias) projectExpr).child() instanceof Slot) {
                        aliasesBySlot
                                .computeIfAbsent(((Alias) projectExpr).child(), key -> Lists.newArrayList())
                                .add(projectExpr);
                    }
                }
                Preconditions.checkState(!aliasesBySlot.isEmpty(), "aliasMap should not be empty");

                // The top project only forwards slots once the aliases live below the join.
                List<NamedExpression> topProjects = project.getProjects().stream()
                        .map(NamedExpression::toSlot)
                        .collect(Collectors.toList());

                // Wrap a join child in a new project only if its output list actually changes.
                List<Slot> leftOutput = join.left().getOutput();
                List<NamedExpression> leftProjects = createNewOutput(leftOutput, aliasesBySlot);
                List<Slot> rightOutput = join.right().getOutput();
                List<NamedExpression> rightProjects = createNewOutput(rightOutput, aliasesBySlot);
                Plan newLeft = leftOutput.equals(leftProjects)
                        ? join.left()
                        : project.withProjectsAndChild(leftProjects, join.left());
                Plan newRight = rightOutput.equals(rightProjects)
                        ? join.right()
                        : project.withProjectsAndChild(rightProjects, join.right());

                // Conjuncts that used an aliased slot must now use the alias slot instead:
                //   project a.id as aid -- join a.id = b.id  =>
                //   join aid = b.id -- project a.id as aid
                // When a slot has several aliases, the first one is used for the conjuncts.
                Map<ExprId, Slot> slotReplacements = aliasesBySlot.entrySet().stream().collect(
                        Collectors.toMap(
                                entry -> ((Slot) entry.getKey()).getExprId(),
                                entry -> entry.getValue().get(0).toSlot()));
                List<Expression> hashConjuncts =
                        replaceJoinConjuncts(join.getHashJoinConjuncts(), slotReplacements);
                List<Expression> otherConjuncts =
                        replaceJoinConjuncts(join.getOtherJoinConjuncts(), slotReplacements);
                List<Expression> markConjuncts =
                        replaceJoinConjuncts(join.getMarkJoinConjuncts(), slotReplacements);

                Plan rewrittenJoin = join.withConjunctsChildren(hashConjuncts, otherConjuncts, markConjuncts,
                        newLeft, newRight, join.getJoinReorderContext());
                return project.withProjectsAndChild(topProjects, rewrittenJoin);
            }).toRule(RuleType.PUSH_DOWN_ALIAS_THROUGH_JOIN);
    }

    /** Returns true when {@code expr} is a Slot that is not a MarkJoinSlotReference. */
    private static boolean isNonMarkSlot(Expression expr) {
        return expr instanceof Slot && !(expr instanceof MarkJoinSlotReference);
    }

    /**
     * Rewrites every conjunct bottom-up, swapping each Slot whose ExprId is a key of
     * {@code replaceMaps} for the mapped slot; all other expressions are left as they are.
     *
     * @param joinConjuncts conjuncts of the join being rewritten
     * @param replaceMaps replacement slot keyed by the ExprId of the slot to replace
     * @return an immutable list with the rewritten conjuncts, in the input order
     */
    private List<Expression> replaceJoinConjuncts(List<Expression> joinConjuncts, Map<ExprId, Slot> replaceMaps) {
        ImmutableList.Builder<Expression> rewritten = ImmutableList.builder();
        for (Expression conjunct : joinConjuncts) {
            rewritten.add(conjunct.rewriteUp(node -> {
                boolean replaceable = node instanceof Slot
                        && replaceMaps.containsKey(((Slot) node).getExprId());
                return replaceable ? replaceMaps.get(((Slot) node).getExprId()) : node;
            }));
        }
        return rewritten.build();
    }

    /**
     * Builds the project list for one join child: all original output slots first, followed by
     * the aliases registered for those slots, visited in output order.
     *
     * @param oldOutput output slots of the join child
     * @param aliasMap aliases grouped by the slot they wrap
     * @return the original slots plus the matching aliases
     */
    private List<NamedExpression> createNewOutput(List<Slot> oldOutput,
            Map<Expression, List<NamedExpression>> aliasMap) {
        // we should keep all original outputs and add new alias in the output list
        // because the upper node may require both col#1 and col#1 as colAlias#2
        List<NamedExpression> output = Lists.newArrayList(oldOutput);
        for (Slot slot : oldOutput) {
            output.addAll(aliasMap.getOrDefault(slot, Collections.emptyList()));
        }
        return output;
    }
}