PushDownAggThroughJoinOnPkFk.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.properties.DataTrait;
import org.apache.doris.nereids.properties.FuncDeps;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Project;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.JoinUtils;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Sets;
import org.apache.thrift.annotation.Nullable;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Set;
/**
* Push down agg through join with foreign key:
* Agg(group by fk/pk)
* |
* Join(pk = fk)
* / \
* pk fk
* ======>
* Join(pk = fk)
* / \
* | Agg(group by fk)
* | |
* pk fk
*/
public class PushDownAggThroughJoinOnPkFk implements RewriteRuleFactory {
@Override
public List<Rule> buildRules() {
return ImmutableList.of(
logicalAggregate(
innerLogicalJoin()
.when(j -> !j.isMarkJoin()
&& j.getOtherJoinConjuncts().isEmpty()))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance))
.thenApply(ctx -> pushAgg(ctx.root, ctx.root.child()))
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK),
logicalAggregate(
logicalProject(
innerLogicalJoin()
.when(j -> j.getJoinType().isInnerJoin()
&& !j.isMarkJoin()
&& j.getOtherJoinConjuncts().isEmpty()))
.when(Project::isAllSlots))
.when(agg -> agg.getGroupByExpressions().stream().allMatch(Slot.class::isInstance))
.thenApply(ctx -> pushAgg(ctx.root, ctx.root.child().child()))
.toRule(RuleType.PUSH_DOWN_AGG_THROUGH_JOIN_ON_PKFK)
);
}
private @Nullable Plan pushAgg(LogicalAggregate<?> agg, LogicalJoin<?, ?> join) {
InnerJoinCluster innerJoinCluster = new InnerJoinCluster();
innerJoinCluster.collectContiguousInnerJoins(join);
if (!innerJoinCluster.isValid()) {
return null;
}
for (Entry<BitSet, LogicalJoin<?, ?>> e : innerJoinCluster.getJoinsMap().entrySet()) {
LogicalJoin<?, ?> subJoin = e.getValue();
Pair<Plan, Plan> primaryAndForeign = tryExtractPrimaryForeign(subJoin);
if (primaryAndForeign == null) {
continue;
}
LogicalAggregate<?> newAgg =
eliminatePrimaryOutput(agg, subJoin, primaryAndForeign.first, primaryAndForeign.second);
if (newAgg == null) {
return null;
}
LogicalJoin<?, ?> newJoin = innerJoinCluster
.constructJoinWithPrimary(e.getKey(), subJoin, primaryAndForeign.first);
if (newJoin != null && newJoin.left() == primaryAndForeign.first) {
newJoin = (LogicalJoin<?, ?>) newJoin
.withChildren(newJoin.left(), newAgg.withChildren(newJoin.right()));
if (Sets.union(newJoin.left().getOutputSet(), newJoin.right().getOutputSet())
.containsAll(newJoin.getInputSlots())) {
return newJoin;
}
} else if (newJoin != null && newJoin.right() == primaryAndForeign.first) {
newJoin = (LogicalJoin<?, ?>) newJoin
.withChildren(newAgg.withChildren(newJoin.left()), newJoin.right());
if (Sets.union(newJoin.left().getOutputSet(), newJoin.right().getOutputSet())
.containsAll(newJoin.getInputSlots())) {
return newJoin;
}
}
}
return null;
}
// eliminate the slot of primary plan in agg
private LogicalAggregate<?> eliminatePrimaryOutput(LogicalAggregate<?> agg, Plan child,
Plan primary, Plan foreign) {
Set<Slot> aggInputs = agg.getInputSlots();
if (primary.getOutputSet().stream().noneMatch(aggInputs::contains)) {
return agg;
}
Set<Slot> primaryOutputSet = primary.getOutputSet();
Set<Slot> primarySlots = Sets.intersection(aggInputs, primaryOutputSet);
DataTrait dataTrait = child.getLogicalProperties().getTrait();
FuncDeps funcDeps = dataTrait.getAllValidFuncDeps(Sets.union(foreign.getOutputSet(), primary.getOutputSet()));
HashMap<Slot, Slot> primaryToForeignDeps = new HashMap<>();
for (Slot slot : primarySlots) {
Set<Set<Slot>> replacedSlotSets = funcDeps.findDeterminats(ImmutableSet.of(slot));
for (Set<Slot> replacedSlots : replacedSlotSets) {
if (primaryOutputSet.stream().noneMatch(replacedSlots::contains)
&& replacedSlots.size() == 1) {
primaryToForeignDeps.put(slot, replacedSlots.iterator().next());
break;
}
}
}
Set<Expression> newGroupBySlots = constructNewGroupBy(agg, primaryOutputSet, primaryToForeignDeps);
List<NamedExpression> newOutput = constructNewOutput(
agg, primaryOutputSet, primaryToForeignDeps, funcDeps, primary);
if (newGroupBySlots == null || newOutput == null) {
return null;
}
return agg.withGroupByAndOutput(ImmutableList.copyOf(newGroupBySlots), ImmutableList.copyOf(newOutput));
}
private @Nullable Set<Expression> constructNewGroupBy(LogicalAggregate<?> agg, Set<Slot> primaryOutputs,
Map<Slot, Slot> primaryToForeignBiDeps) {
Set<Expression> newGroupBySlots = new HashSet<>();
for (Expression expression : agg.getGroupByExpressions()) {
if (!(expression instanceof Slot)) {
return null;
}
if (primaryOutputs.contains((Slot) expression)
&& !primaryToForeignBiDeps.containsKey((Slot) expression)) {
return null;
}
expression = primaryToForeignBiDeps.getOrDefault(expression, (Slot) expression);
newGroupBySlots.add(expression);
}
return newGroupBySlots;
}
private @Nullable List<NamedExpression> constructNewOutput(LogicalAggregate<?> agg, Set<Slot> primaryOutput,
Map<Slot, Slot> primaryToForeignDeps, FuncDeps funcDeps, Plan primaryPlan) {
List<NamedExpression> newOutput = new ArrayList<>();
for (NamedExpression expression : agg.getOutputExpressions()) {
// There are three cases for output expressions:
// 1. Slot: the slot is from primary plan, we need to replace it with
// the corresponding slot from foreign plan,
// or skip it when it isn't in group by.
// 2. Count: the count is from primary plan,
// we need to replace the slot in the count with the corresponding slot
// from foreign plan
if (expression instanceof Slot && primaryPlan.getOutput().contains(expression)) {
if (primaryToForeignDeps.containsKey(expression)) {
expression = primaryToForeignDeps.getOrDefault(expression, expression.toSlot());
} else {
continue;
}
}
if (expression instanceof Alias
&& expression.child(0) instanceof Count
&& expression.child(0).child(0) instanceof Slot) {
// count(slot) can be rewritten by circle deps
Slot slot = (Slot) expression.child(0).child(0);
if (primaryToForeignDeps.containsKey(slot)
&& funcDeps.isCircleDeps(
ImmutableSet.of(slot), ImmutableSet.of(primaryToForeignDeps.get(slot)))) {
expression = (NamedExpression) expression.rewriteUp(e ->
e instanceof Slot
? primaryToForeignDeps.getOrDefault((Slot) e, (Slot) e)
: e);
}
}
if (!(expression instanceof Slot)
&& expression.getInputSlots().stream().anyMatch(primaryOutput::contains)) {
return null;
}
newOutput.add(expression);
}
return newOutput;
}
// try to extract primary key table and foreign key table
private @Nullable Pair<Plan, Plan> tryExtractPrimaryForeign(LogicalJoin<?, ?> join) {
Plan primary;
Plan foreign;
if (JoinUtils.canEliminateByFk(join, join.left(), join.right())) {
primary = join.left();
foreign = join.right();
} else if (JoinUtils.canEliminateByFk(join, join.right(), join.left())) {
primary = join.right();
foreign = join.left();
} else {
return null;
}
return Pair.of(primary, foreign);
}
/**
* This class flattens nested join clusters and optimizes aggregation pushdown.
*
* Example of flattening:
* Join1 Join1 Join2
* / \ / \ / \
* a Join2 =====> a b b c
* / \
* b c
*
* After flattening, we attempt to push down aggregations for each join.
* For instance, if b is a primary key table and c is a foreign key table:
*
* Original (can't push down): After flattening (can push down):
* agg(Join1) Join1 Join2
* / \ / \ / \
* a Join2 =====> a b b agg(c)
* / \
* b c
*
* Finally, we can reorganize the join tree:
* Join2
* / \
* agg(c) Join1
* / \
* a b
*/
static class InnerJoinCluster {
private final Map<BitSet, LogicalJoin<?, ?>> innerJoins = new HashMap<>();
private final List<Plan> leaf = new ArrayList<>();
void collectContiguousInnerJoins(Plan plan) {
if (!isSlotProject(plan) && !isInnerJoin(plan)) {
leaf.add(plan);
return;
}
for (Plan child : plan.children()) {
collectContiguousInnerJoins(child);
}
if (isInnerJoin(plan)) {
LogicalJoin<?, ?> join = (LogicalJoin<?, ?>) plan;
Set<Slot> inputSlots = join.getInputSlots();
BitSet childrenIndices = new BitSet();
List<Plan> children = new ArrayList<>();
for (int i = 0; i < leaf.size(); i++) {
if (!Sets.intersection(leaf.get(i).getOutputSet(), inputSlots).isEmpty()) {
childrenIndices.set(i);
children.add(leaf.get(i));
}
}
if (childrenIndices.cardinality() == 2) {
join = join.withChildren(children);
}
innerJoins.put(childrenIndices, join);
}
}
boolean isValid() {
// we cannot handle the case that there is any join with more than 2 children
return innerJoins.keySet().stream().allMatch(x -> x.cardinality() == 2);
}
@Nullable LogicalJoin<?, ?> constructJoinWithPrimary(BitSet bitSet, LogicalJoin<?, ?> join, Plan primary) {
Set<BitSet> forbiddenJoin = new HashSet<>();
forbiddenJoin.add(bitSet);
BitSet totalBitset = new BitSet();
totalBitset.set(0, leaf.size());
totalBitset.set(leaf.indexOf(primary), false);
Plan childPlan = constructPlan(totalBitset, forbiddenJoin);
if (childPlan == null) {
return null;
}
return (LogicalJoin<?, ?>) join.withChildren(childPlan, primary);
}
@Nullable Plan constructPlan(BitSet bitSet, Set<BitSet> forbiddenJoin) {
if (bitSet.cardinality() == 1) {
return leaf.get(bitSet.nextSetBit(0));
}
BitSet currentBitset = new BitSet();
Plan currentPlan = null;
while (!currentBitset.equals(bitSet)) {
boolean addJoin = false;
for (Entry<BitSet, LogicalJoin<?, ?>> entry : innerJoins.entrySet()) {
if (forbiddenJoin.contains(entry.getKey())) {
continue;
}
if (currentBitset.isEmpty()) {
addJoin = true;
currentBitset.or(entry.getKey());
currentPlan = entry.getValue();
forbiddenJoin.add(entry.getKey());
} else if (currentBitset.intersects(entry.getKey())) {
addJoin = true;
currentBitset.or(entry.getKey());
currentPlan = currentPlan.withChildren(currentPlan, entry.getValue());
forbiddenJoin.add(entry.getKey());
}
}
if (!addJoin) {
// if we cannot find any join to add, just return null
// It means we cannot construct a join
return null;
}
}
return currentPlan;
}
Map<BitSet, LogicalJoin<?, ?>> getJoinsMap() {
return innerJoins;
}
boolean isSlotProject(Plan plan) {
return plan instanceof LogicalProject
&& ((LogicalProject<?>) (plan)).isAllSlots();
}
boolean isInnerJoin(Plan plan) {
return plan instanceof LogicalJoin
&& ((LogicalJoin<?, ?>) plan).getJoinType().isInnerJoin()
&& !((LogicalJoin<?, ?>) plan).isMarkJoin()
&& ((LogicalJoin<?, ?>) plan).getOtherJoinConjuncts().isEmpty();
}
}
}