PlanTreeRewriteBottomUpJob.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Predicate;
/**
* PlanTreeRewriteBottomUpJob
* The job is used for bottom-up rewrite. If any rewrite rule takes effect on a node, the newly produced plan
* is processed again starting from its leaf nodes, with all the rules. So when some rules interact with each
* other (the output of one rule enables another), the 'Bottom-Up' job should be used to handle them.
*/
public class PlanTreeRewriteBottomUpJob extends PlanTreeRewriteJob {
    // REWRITE_STATE_KEY is the mutable-state key under which each plan node stores its own 'RewriteStateContext'.
    // Different 'RewriteState' values trigger different actions,
    // so each node is processed according to its current 'RewriteState'.
    private static final String REWRITE_STATE_KEY = "rewrite_state";
    private final RewriteJobContext rewriteJobContext;
    private final List<Rule> rules;
    private final int batchId;

    enum RewriteState {
        // 'REWRITE_THIS' means the current plan node can be handled immediately. If the plan state is
        // 'REWRITE_THIS', all of its children are expected to be 'REWRITTEN' already, because the plan tree is
        // handled bottom-up.
        REWRITE_THIS,
        // 'REWRITTEN' means the current plan has been handled already; nothing else needs to be done.
        REWRITTEN,
        // 'ENSURE_CHILDREN_REWRITTEN' means the children of the current plan node must be checked first.
        // A node is in this state when it has no state for the current batch, or when it was just produced by a
        // rewrite; its subtree is traversed again before the node itself is rewritten.
        ENSURE_CHILDREN_REWRITTEN
    }

    /**
     * Creates a bottom-up rewrite job for the plan held by {@code rewriteJobContext}.
     *
     * @param rewriteJobContext the context holding the plan to rewrite and the current batch id
     * @param context the job context, forwarded to the parent job
     * @param isTraverseChildren decides whether the children of a given plan should be traversed
     * @param rules the rewrite rules to apply
     * @throws NullPointerException if {@code rewriteJobContext} or {@code rules} is null
     */
    public PlanTreeRewriteBottomUpJob(
            RewriteJobContext rewriteJobContext, JobContext context,
            Predicate<Plan> isTraverseChildren, List<Rule> rules) {
        super(JobType.BOTTOM_UP_REWRITE, context, isTraverseChildren);
        this.rewriteJobContext = Objects.requireNonNull(rewriteJobContext, "rewriteContext cannot be null");
        this.rules = Objects.requireNonNull(rules, "rules cannot be null");
        this.batchId = rewriteJobContext.batchId;
    }

    /**
     * Dispatches on the current node's {@link RewriteState} for this batch.
     * See the comments on {@link RewriteState} for the meaning of each state.
     */
    @Override
    public void execute() {
        Plan plan = rewriteJobContext.plan;
        RewriteState state = getState(plan, batchId);
        switch (state) {
            case REWRITE_THIS:
                rewriteThis();
                return;
            case ENSURE_CHILDREN_REWRITTEN:
                ensureChildrenRewritten();
                return;
            case REWRITTEN:
                rewriteJobContext.result = plan;
                return;
            default:
                throw new IllegalStateException("Unknown rewrite state: " + state);
        }
    }

    /**
     * Applies the rules to the current node, after linking it with its already-rewritten children.
     * If a rule produces a new plan that is not yet rewritten in this batch, that plan is re-queued in the
     * 'ENSURE_CHILDREN_REWRITTEN' state so that its subtree is processed again. Otherwise the result is
     * recorded on the job context.
     */
    private void rewriteThis() {
        // Link the current node with the sub-plan to get the current plan which is used in the rewrite phase later.
        Plan plan = linkChildren(rewriteJobContext.plan, rewriteJobContext.childrenContext);
        RewriteResult rewriteResult = rewrite(plan, rules, rewriteJobContext);
        if (rewriteResult.hasNewPlan) {
            RewriteJobContext newJobContext = rewriteJobContext.withPlan(rewriteResult.plan);
            RewriteState state = getState(rewriteResult.plan, batchId);
            // Some eliminate rules return an already-rewritten plan. For example, the current node is eliminated
            // and its child plan is returned. That plan does not need to be handled again.
            if (state == RewriteState.REWRITTEN) {
                newJobContext.setResult(rewriteResult.plan);
                return;
            }
            // After the rewrite takes effect, the children of the new plan must be handled again.
            pushJob(new PlanTreeRewriteBottomUpJob(newJobContext, context, isTraverseChildren, rules));
            setState(rewriteResult.plan, RewriteState.ENSURE_CHILDREN_REWRITTEN, batchId);
        } else {
            // No new plan was generated, so mark the current plan as 'REWRITTEN'.
            setState(rewriteResult.plan, RewriteState.REWRITTEN, batchId);
            rewriteJobContext.setResult(rewriteResult.plan);
        }
    }

    /**
     * Marks the current node as 'REWRITE_THIS' and schedules it, then schedules its children (if they should
     * be traversed). The node's job is pushed before its children's jobs so that the children run first.
     * This assumes the job pool pops the most recently pushed job first (LIFO); confirm in PlanTreeRewriteJob.
     */
    private void ensureChildrenRewritten() {
        Plan plan = rewriteJobContext.plan;
        setState(plan, RewriteState.REWRITE_THIS, batchId);
        pushJob(new PlanTreeRewriteBottomUpJob(rewriteJobContext, context, isTraverseChildren, rules));
        // Some rules return a new plan tree with more than one new node,
        // so these new plan nodes must be transformed too.
        if (isTraverseChildren.test(plan)) {
            pushChildrenJobs(plan);
        }
    }

    /**
     * Pushes one bottom-up job per child of {@code plan}. Each child job gets a context linked to the current
     * context at the child's index.
     *
     * @param plan the plan whose children are scheduled
     */
    private void pushChildrenJobs(Plan plan) {
        List<Plan> children = plan.children();
        // Push in reverse order (rightmost child first), so that under LIFO scheduling the leftmost child
        // is processed first.
        for (int i = children.size() - 1; i >= 0; i--) {
            RewriteJobContext childRewriteJobContext = new RewriteJobContext(
                    children.get(i), rewriteJobContext, i, false, batchId);
            pushJob(new PlanTreeRewriteBottomUpJob(childRewriteJobContext, context, isTraverseChildren, rules));
        }
    }

    /**
     * Reads the rewrite state stored on {@code plan}. A plan with no stored state, or with a state recorded
     * by a different batch, is treated as 'ENSURE_CHILDREN_REWRITTEN'.
     *
     * @param plan the plan node to inspect
     * @param currentBatchId the batch id of the running job
     * @return the effective rewrite state of {@code plan} for {@code currentBatchId}
     */
    private static RewriteState getState(Plan plan, int currentBatchId) {
        Optional<RewriteStateContext> state = plan.getMutableState(REWRITE_STATE_KEY);
        if (!state.isPresent()) {
            return RewriteState.ENSURE_CHILDREN_REWRITTEN;
        }
        RewriteStateContext context = state.get();
        if (context.batchId != currentBatchId) {
            return RewriteState.ENSURE_CHILDREN_REWRITTEN;
        }
        return context.rewriteState;
    }

    /**
     * Stores {@code state}, tagged with {@code batchId}, as the mutable rewrite state of {@code plan}.
     */
    private static void setState(Plan plan, RewriteState state, int batchId) {
        plan.setMutableState(REWRITE_STATE_KEY, new RewriteStateContext(state, batchId));
    }

    /** Immutable pair of a rewrite state and the batch id in which it was recorded. */
    private static class RewriteStateContext {
        private final RewriteState rewriteState;
        private final int batchId;

        public RewriteStateContext(RewriteState rewriteState, int batchId) {
            this.rewriteState = rewriteState;
            this.batchId = batchId;
        }
    }
}