PushDownProject.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.StatementContext;
import org.apache.doris.nereids.pattern.MatchingContext;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.PreferPushDownProject;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.SetOperation.Qualifier;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalOneRowRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalUnion;
import com.google.common.collect.ArrayListMultimap;
import com.google.common.collect.ImmutableCollection;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Multimap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.function.Function;
/** push down project if the expression instance of PreferPushDownProject */
public class PushDownProject implements RewriteRuleFactory, NormalizeToSlot {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build(
logicalJoin().thenApply(this::pushDownJoinExpressions)
),
RuleType.PUSH_DOWN_PROJECT_THROUGH_JOIN.build(
logicalProject(logicalJoin()).thenApply(this::defaultPushDownProject)
),
RuleType.PUSH_DOWN_PROJECT_THROUGH_WINDOW.build(
logicalProject(logicalWindow()).thenApply(this::defaultPushDownProject)
),
RuleType.PUSH_DOWN_PROJECT_THROUGH_PARTITION_TOP_N.build(
logicalProject(logicalPartitionTopN()).thenApply(this::defaultPushDownProject)
),
// RuleType.PUSH_DOWN_PROJECT_THROUGH_DEFER_MATERIALIZE_TOP_N.build(
// logicalProject(logicalDeferMaterializeTopN()).thenApply(this::defaultPushDownProject)
// ),
RuleType.PUSH_DOWN_PROJECT_THROUGH_UNION.build(
logicalProject(
logicalUnion().when(u -> u.getQualifier() == Qualifier.ALL)
).thenApply(this::pushThroughUnion)
)
);
}
private Plan pushDownJoinExpressions(MatchingContext<LogicalJoin<Plan, Plan>> ctx) {
LogicalJoin<Plan, Plan> join = ctx.root;
Optional<Pair<List<Expression>, Map<Integer, List<NamedExpression>>>> rewriteHashJoinConjunctsResult
= pushDownProjectInExpressions(join, join.getHashJoinConjuncts(), ctx.statementContext);
Optional<Pair<List<Expression>, Map<Integer, List<NamedExpression>>>> rewriteOtherJoinConjunctsResult
= pushDownProjectInExpressions(join, join.getOtherJoinConjuncts(), ctx.statementContext);
if (!rewriteHashJoinConjunctsResult.isPresent() && !rewriteOtherJoinConjunctsResult.isPresent()) {
return join;
}
List<Expression> newHashJoinConjuncts = rewriteHashJoinConjunctsResult.isPresent()
? rewriteHashJoinConjunctsResult.get().first : join.getHashJoinConjuncts();
List<Expression> newOtherJoinConjuncts = rewriteOtherJoinConjunctsResult.isPresent()
? rewriteOtherJoinConjunctsResult.get().first : join.getOtherJoinConjuncts();
List<List<NamedExpression>> pushedOutput = new ArrayList<>();
pushedOutput.add(new ArrayList<>(join.left().getOutput()));
pushedOutput.add(new ArrayList<>(join.right().getOutput()));
if (rewriteHashJoinConjunctsResult.isPresent()) {
Map<Integer, List<NamedExpression>> childIndexToConjuncts = rewriteHashJoinConjunctsResult.get().second;
List<NamedExpression> leftConjuncts = childIndexToConjuncts.get(0);
if (leftConjuncts != null) {
pushedOutput.get(0).addAll(leftConjuncts);
}
List<NamedExpression> rightConjuncts = childIndexToConjuncts.get(1);
if (rightConjuncts != null) {
pushedOutput.get(1).addAll(rightConjuncts);
}
}
if (rewriteOtherJoinConjunctsResult.isPresent()) {
Map<Integer, List<NamedExpression>> childIndexToConjuncts = rewriteOtherJoinConjunctsResult.get().second;
List<NamedExpression> leftOtherConjuncts = childIndexToConjuncts.get(0);
if (leftOtherConjuncts != null) {
pushedOutput.get(0).addAll(leftOtherConjuncts);
}
List<NamedExpression> rightOtherConjuncts = childIndexToConjuncts.get(1);
if (rightOtherConjuncts != null) {
pushedOutput.get(1).addAll(rightOtherConjuncts);
}
}
Plan newLeft = join.left();
Plan newRight = join.right();
if (pushedOutput.get(0).size() != newLeft.getOutput().size()) {
newLeft = new LogicalProject<>(pushedOutput.get(0), newLeft);
}
if (pushedOutput.get(1).size() != newRight.getOutput().size()) {
newRight = new LogicalProject<>(pushedOutput.get(1), newRight);
}
return join.withJoinConjuncts(
newHashJoinConjuncts, newOtherJoinConjuncts,
join.getMarkJoinConjuncts(), join.getJoinReorderContext()
).withChildren(newLeft, newRight);
}
// return:
// key: rewrite the PreferPushDownProject to slot
// value: the pushed down project outputs which contains the Alias(PreferPushDownProject)
private Optional<Pair<List<Expression>, Map<Integer, List<NamedExpression>>>> pushDownProjectInExpressions(
Plan plan, Collection<Expression> expressions, StatementContext context) {
boolean changed = false;
Map<Integer, List<NamedExpression>> childIndexToPushedAlias = new LinkedHashMap<>();
List<Expression> newExpressions = new ArrayList<>();
for (Expression expression : expressions) {
Expression newExpression = expression.rewriteDownShortCircuit(e -> {
if (e instanceof PreferPushDownProject) {
List<Plan> children = plan.children();
for (int i = 0; i < children.size(); i++) {
Plan child = children.get(i);
if (child.getOutputSet().containsAll(e.getInputSlots())) {
Alias alias = new Alias(context.getNextExprId(), e);
Slot slot = alias.toSlot();
List<NamedExpression> namedExpressions
= childIndexToPushedAlias.computeIfAbsent(i, k -> new ArrayList<>());
namedExpressions.add(alias);
return slot;
}
}
}
return e;
});
newExpressions.add(newExpression);
changed |= newExpression != expression;
}
if (changed) {
return Optional.of(Pair.of(newExpressions, childIndexToPushedAlias));
}
return Optional.empty();
}
private <C extends LogicalPlan> Plan defaultPushDownProject(MatchingContext<LogicalProject<C>> ctx) {
if (!ctx.connectContext.getSessionVariable().enablePruneNestedColumns) {
return ctx.root;
}
LogicalProject<C> project = ctx.root;
C child = project.child();
PushdownProjectHelper pushdownProjectHelper
= new PushdownProjectHelper(ctx.statementContext, child);
Pair<Boolean, List<NamedExpression>> pushProjects
= pushdownProjectHelper.pushDownExpressions(project.getProjects());
if (pushProjects.first) {
List<Plan> newJoinChildren = pushdownProjectHelper.buildNewChildren();
return new LogicalProject<>(
pushProjects.second,
child.withChildren(newJoinChildren)
);
}
return project;
}
private Plan pushThroughUnion(MatchingContext<LogicalProject<LogicalUnion>> ctx) {
if (!ctx.connectContext.getSessionVariable().enablePruneNestedColumns) {
return ctx.root;
}
LogicalProject<LogicalUnion> project = ctx.root;
LogicalUnion union = project.child();
PushdownProjectHelper pushdownProjectHelper
= new PushdownProjectHelper(ctx.statementContext, project);
Pair<Boolean, List<NamedExpression>> pushProjects
= pushdownProjectHelper.pushDownExpressions(project.getProjects());
if (pushProjects.first) {
List<NamedExpression> unionOutputs = union.getOutputs();
Map<Slot, Integer> slotToColumnIndex = new LinkedHashMap<>();
for (int i = 0; i < unionOutputs.size(); i++) {
NamedExpression output = unionOutputs.get(i);
slotToColumnIndex.put(output.toSlot(), i);
}
Collection<NamedExpression> pushDownProjections
= pushdownProjectHelper.childToPushDownProjects.values();
List<Plan> newChildren = new ArrayList<>();
List<List<SlotReference>> newChildrenOutputs = new ArrayList<>();
for (Plan child : union.children()) {
List<NamedExpression> pushedOutput = replaceSlot(
ctx.statementContext,
pushDownProjections,
slot -> {
Integer sourceColumnIndex = slotToColumnIndex.get(slot);
if (sourceColumnIndex != null) {
return child.getOutput().get(sourceColumnIndex).toSlot();
}
return slot;
}
);
LogicalProject<Plan> newChild = new LogicalProject<>(
ImmutableList.<NamedExpression>builder()
.addAll(child.getOutput())
.addAll(pushedOutput)
.build(),
child
);
newChildrenOutputs.add((List) newChild.getOutput());
newChildren.add(newChild);
}
for (List<NamedExpression> originConstantExprs : union.getConstantExprsList()) {
List<NamedExpression> pushedOutput = replaceSlot(
ctx.statementContext,
pushDownProjections,
slot -> {
Integer sourceColumnIndex = slotToColumnIndex.get(slot);
if (sourceColumnIndex != null) {
return originConstantExprs.get(sourceColumnIndex).toSlot();
}
return slot;
}
);
LogicalOneRowRelation originOneRowRelation = new LogicalOneRowRelation(
ctx.statementContext.getNextRelationId(),
originConstantExprs
);
LogicalProject<Plan> newChild = new LogicalProject<>(
ImmutableList.<NamedExpression>builder()
.addAll(originOneRowRelation.getOutput())
.addAll(pushedOutput)
.build(),
originOneRowRelation
);
newChildrenOutputs.add((List) newChild.getOutput());
newChildren.add(newChild);
}
List<NamedExpression> newUnionOutputs = new ArrayList<>(union.getOutputs());
for (NamedExpression projection : pushDownProjections) {
newUnionOutputs.add(projection.toSlot());
}
return new LogicalProject<>(
pushProjects.second,
new LogicalUnion(
union.getQualifier(),
newUnionOutputs,
newChildrenOutputs,
ImmutableList.of(),
union.hasPushedFilter(),
newChildren
)
);
}
return project;
}
private List<NamedExpression> replaceSlot(
StatementContext statementContext,
Collection<NamedExpression> pushDownProjections,
Function<Slot, Slot> slotReplace) {
List<NamedExpression> pushedOutput = new ArrayList<>();
for (NamedExpression projection : pushDownProjections) {
NamedExpression newOutput = (NamedExpression) projection.rewriteUp(e -> {
if (e instanceof Slot) {
Slot newSlot = slotReplace.apply((Slot) e);
if (newSlot != null) {
return newSlot;
}
}
return e;
});
if (newOutput instanceof Alias) {
pushedOutput.add(new Alias(statementContext.getNextExprId(), newOutput.child(0)));
} else {
pushedOutput.add(new Alias(statementContext.getNextExprId(), newOutput));
}
}
return pushedOutput;
}
private static class PushdownProjectHelper {
private final Plan plan;
private final StatementContext statementContext;
private final Map<Expression, Pair<Slot, Plan>> exprToChildAndSlot;
private final Multimap<Plan, NamedExpression> childToPushDownProjects;
public PushdownProjectHelper(StatementContext statementContext, Plan plan) {
this.statementContext = statementContext;
this.plan = plan;
this.exprToChildAndSlot = new LinkedHashMap<>();
this.childToPushDownProjects = ArrayListMultimap.create();
}
public <C extends Collection<E>, E extends Expression> Pair<Boolean, C> pushDownExpressions(C expressions) {
ImmutableCollection.Builder<E> builder;
if (expressions instanceof List) {
builder = ImmutableList.builderWithExpectedSize(expressions.size());
} else {
builder = ImmutableSet.builderWithExpectedSize(expressions.size());
}
boolean extracted = false;
for (E expression : expressions) {
Optional<E> result = pushDownExpression(expression);
if (!result.isPresent()) {
builder.add(expression);
} else {
extracted = true;
builder.add(result.get());
}
}
if (extracted) {
return Pair.of(true, (C) builder.build());
} else {
return Pair.of(false, expressions);
}
}
public <E extends Expression> Optional<E> pushDownExpression(E expression) {
if (!(expression instanceof PreferPushDownProject
|| (expression instanceof Alias && expression.child(0) instanceof PreferPushDownProject))) {
return Optional.empty();
}
Pair<Slot, Plan> existPushdown = exprToChildAndSlot.get(expression);
if (existPushdown != null) {
return Optional.of((E) existPushdown.first);
}
Alias pushDownAlias = null;
if (expression instanceof Alias) {
pushDownAlias = (Alias) expression;
} else {
pushDownAlias = new Alias(statementContext.getNextExprId(), expression);
}
Set<Slot> inputSlots = expression.getInputSlots();
for (Plan child : plan.children()) {
if (child.getOutputSet().containsAll(inputSlots)) {
Slot remaimSlot = pushDownAlias.toSlot();
exprToChildAndSlot.put(expression, Pair.of(remaimSlot, child));
childToPushDownProjects.put(child, pushDownAlias);
return Optional.of((E) remaimSlot);
}
}
return Optional.empty();
}
public List<Plan> buildNewChildren() {
if (childToPushDownProjects.isEmpty()) {
return plan.children();
}
ImmutableList.Builder<Plan> newChildren = ImmutableList.builderWithExpectedSize(plan.arity());
for (Plan child : plan.children()) {
Collection<NamedExpression> newProject = childToPushDownProjects.get(child);
if (newProject.isEmpty()) {
newChildren.add(child);
} else {
newChildren.add(
new LogicalProject<>(
ImmutableList.<NamedExpression>builder()
.addAll(child.getOutput())
.addAll(newProject)
.build(),
child
)
);
}
}
return newChildren.build();
}
}
}