PushDownLimitDistinctThroughJoin.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.rules.rewrite;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import java.util.List;
import java.util.stream.Collectors;
/**
 * Pushes a {@code LIMIT} that sits on top of a {@code DISTINCT} aggregate down below a join.
 *
 * <p>Matched shapes:
 * <ul>
 *   <li>limit -> distinct aggregate -> join</li>
 *   <li>limit -> distinct aggregate -> all-slot project -> join</li>
 * </ul>
 *
 * <p>The push-down happens only when the aggregate's output set equals the output set of one join
 * child and all group-by input slots come from that child. That child is the left side for
 * LEFT OUTER JOIN, the right side for RIGHT OUTER JOIN, and either side (left first) for CROSS JOIN.
 *
 * <p>A copy of {@code limit(limit + offset, offset 0) -> distinct} is then inserted above that
 * child. The original limit and aggregate stay on top of the join.
 *
 * <p>Both rules fire only when a {@link ConnectContext} is available and {@code limit + offset}
 * does not exceed the session's {@code topnOptLimitThreshold}.
 */
public class PushDownLimitDistinctThroughJoin implements RewriteRuleFactory {
    @Override
    public List<Rule> buildRules() {
        return ImmutableList.of(
                // limit -> distinct -> join
                logicalLimit(logicalAggregate(logicalJoin())
                        .when(LogicalAggregate::isDistinct))
                        .when(PushDownLimitDistinctThroughJoin::isWithinTopnThreshold)
                        .then(limit -> {
                            LogicalAggregate<LogicalJoin<Plan, Plan>> agg = limit.child();
                            LogicalJoin<Plan, Plan> join = agg.child();
                            Plan newJoin = pushLimitThroughJoin(limit, join);
                            // Bail out if nothing changed, so the rule does not fire again on its own output.
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return limit.withChildren(agg.withChildren(newJoin));
                        })
                        .toRule(RuleType.PUSH_DOWN_LIMIT_DISTINCT_THROUGH_JOIN),
                // limit -> distinct -> project -> join
                // Uses the same threshold guard as the rule above; the rewrite is identical apart
                // from the pass-through project.
                logicalLimit(logicalAggregate(logicalProject(logicalJoin()).when(LogicalProject::isAllSlots))
                        .when(LogicalAggregate::isDistinct))
                        .when(PushDownLimitDistinctThroughJoin::isWithinTopnThreshold)
                        .then(limit -> {
                            LogicalAggregate<LogicalProject<LogicalJoin<Plan, Plan>>> agg = limit.child();
                            LogicalProject<LogicalJoin<Plan, Plan>> project = agg.child();
                            LogicalJoin<Plan, Plan> join = project.child();
                            Plan newJoin = pushLimitThroughJoin(limit, join);
                            if (newJoin == null || join.children().equals(newJoin.children())) {
                                return null;
                            }
                            return limit.withChildren(agg.withChildren(project.withChildren(newJoin)));
                        })
                        .toRule(RuleType.PUSH_DOWN_LIMIT_DISTINCT_THROUGH_PROJECT_JOIN)
        );
    }

    /**
     * Checks whether the push-down is allowed for this limit.
     *
     * <p>A {@link ConnectContext} must be present, and {@code limit + offset} (the row count the
     * pushed-down copy is limited to) must not exceed the session's {@code topnOptLimitThreshold}.
     *
     * @param limit the matched limit node
     * @return {@code true} if the rewrite may proceed
     */
    private static boolean isWithinTopnThreshold(LogicalLimit<?> limit) {
        ConnectContext ctx = ConnectContext.get();
        return ctx != null
                && ctx.getSessionVariable().topnOptLimitThreshold >= limit.getLimit() + limit.getOffset();
    }

    /**
     * Builds a copy of {@code join} with {@code limit(limit + offset) -> distinct} inserted above
     * the join child on which the distinct aggregate can be evaluated.
     *
     * @param limit the limit whose child is the distinct aggregate
     * @param join the join beneath the aggregate (possibly under a pass-through project)
     * @return the rewritten join, or {@code null} if the join type or the aggregate's slots do not
     *         allow the push-down
     */
    private Plan pushLimitThroughJoin(LogicalLimit<?> limit, LogicalJoin<Plan, Plan> join) {
        LogicalAggregate<?> agg = (LogicalAggregate<?>) limit.child();
        List<Slot> groupBySlots = agg.getGroupByExpressions().stream()
                .flatMap(e -> e.getInputSlots().stream()).collect(Collectors.toList());
        Plan left = join.left();
        Plan right = join.right();
        switch (join.getJoinType()) {
            case LEFT_OUTER_JOIN:
                // Every left row survives a LEFT OUTER JOIN, so distinct values over left-only
                // columns are the same before and after the join.
                if (canPushInto(left, agg, groupBySlots)) {
                    return join.withChildren(limitDistinctOver(limit, agg, left), right);
                }
                return null;
            case RIGHT_OUTER_JOIN:
                // Symmetric to LEFT_OUTER_JOIN: every right row survives.
                if (canPushInto(right, agg, groupBySlots)) {
                    return join.withChildren(left, limitDistinctOver(limit, agg, right));
                }
                return null;
            case CROSS_JOIN:
                // Either side may receive the push-down; the left side is tried first.
                if (canPushInto(left, agg, groupBySlots)) {
                    return join.withChildren(limitDistinctOver(limit, agg, left), right);
                }
                if (canPushInto(right, agg, groupBySlots)) {
                    return join.withChildren(left, limitDistinctOver(limit, agg, right));
                }
                return null;
            default:
                return null;
        }
    }

    /**
     * Checks whether the distinct aggregate can be evaluated directly on {@code child}.
     *
     * <p>All group-by input slots must come from {@code child}, and the aggregate's output set must
     * equal {@code child}'s output set.
     */
    private static boolean canPushInto(Plan child, LogicalAggregate<?> agg, List<Slot> groupBySlots) {
        return child.getOutputSet().containsAll(groupBySlots)
                && child.getOutputSet().equals(agg.getOutputSet());
    }

    /**
     * Wraps {@code child} in a copy of the distinct aggregate, under a limit of
     * {@code limit + offset} with offset 0.
     *
     * <p>The offset is applied only by the original limit that stays above the join, so the
     * pushed-down copy must keep enough rows to cover it.
     */
    private static Plan limitDistinctOver(LogicalLimit<?> limit, LogicalAggregate<?> agg, Plan child) {
        return limit.withLimitChild(limit.getLimit() + limit.getOffset(), 0, agg.withChildren(child));
    }
}