RootPlanTreeRewriteJob.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.rewrite;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.Job;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.JobType;
import org.apache.doris.nereids.jobs.scheduler.JobStack;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.trees.plans.Plan;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Predicate;
/** RootPlanTreeRewriteJob */
public class RootPlanTreeRewriteJob implements RewriteJob {
private static final AtomicInteger BATCH_ID = new AtomicInteger();
private final List<Rule> rules;
private final RewriteJobBuilder rewriteJobBuilder;
private final boolean once;
private final Predicate<Plan> isTraverseChildren;
public RootPlanTreeRewriteJob(List<Rule> rules, RewriteJobBuilder rewriteJobBuilder, boolean once) {
this(rules, rewriteJobBuilder, plan -> true, once);
}
public RootPlanTreeRewriteJob(
List<Rule> rules, RewriteJobBuilder rewriteJobBuilder, Predicate<Plan> isTraverseChildren, boolean once) {
this.rules = Objects.requireNonNull(rules, "rules cannot be null");
this.rewriteJobBuilder = Objects.requireNonNull(rewriteJobBuilder, "rewriteJobBuilder cannot be null");
this.once = once;
this.isTraverseChildren = isTraverseChildren;
}
@Override
public void execute(JobContext context) {
CascadesContext cascadesContext = context.getCascadesContext();
// get plan from the cascades context
Plan root = cascadesContext.getRewritePlan();
// write rewritten root plan to cascades context by the RootRewriteJobContext
int batchId = BATCH_ID.incrementAndGet();
RootRewriteJobContext rewriteJobContext = new RootRewriteJobContext(
root, false, context, batchId);
Job rewriteJob = rewriteJobBuilder.build(rewriteJobContext, context, isTraverseChildren, rules);
context.getScheduleContext().pushJob(rewriteJob);
cascadesContext.getJobScheduler().executeJobPool(cascadesContext);
cascadesContext.setCurrentRootRewriteJobContext(null);
}
@Override
public boolean isOnce() {
return once;
}
/** RewriteJobBuilder */
public interface RewriteJobBuilder {
Job build(RewriteJobContext rewriteJobContext, JobContext jobContext,
Predicate<Plan> isTraverseChildren, List<Rule> rules);
}
/** RootRewriteJobContext */
public static class RootRewriteJobContext extends RewriteJobContext {
private final JobContext jobContext;
RootRewriteJobContext(Plan plan, boolean childrenVisited, JobContext jobContext, int batchId) {
super(plan, null, -1, childrenVisited, batchId);
this.jobContext = Objects.requireNonNull(jobContext, "jobContext cannot be null");
jobContext.getCascadesContext().setCurrentRootRewriteJobContext(this);
}
@Override
public boolean isRewriteRoot() {
return true;
}
@Override
public void setResult(Plan result) {
jobContext.getCascadesContext().setRewritePlan(result);
}
@Override
public RewriteJobContext withChildrenVisited(boolean childrenVisited) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
}
@Override
public RewriteJobContext withPlan(Plan plan) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
}
@Override
public RewriteJobContext withPlanAndChildrenVisited(Plan plan, boolean childrenVisited) {
return new RootRewriteJobContext(plan, childrenVisited, jobContext, batchId);
}
/** linkChildren */
public Plan getNewestPlan() {
JobStack jobStack = new JobStack();
LinkPlanJob linkPlanJob = new LinkPlanJob(
jobContext, this, null, false, jobStack);
jobStack.push(linkPlanJob);
while (!jobStack.isEmpty()) {
Job job = jobStack.pop();
job.execute();
}
return linkPlanJob.result;
}
}
public List<Rule> getRules() {
return rules;
}
/** use to assemble the rewriting plan */
private static class LinkPlanJob extends Job {
LinkPlanJob parentJob;
RewriteJobContext rewriteJobContext;
Plan[] childrenResult;
Plan result;
boolean linked;
JobStack jobStack;
private LinkPlanJob(JobContext context, RewriteJobContext rewriteJobContext,
LinkPlanJob parentJob, boolean linked, JobStack jobStack) {
super(JobType.LINK_PLAN, context);
this.rewriteJobContext = rewriteJobContext;
this.parentJob = parentJob;
this.linked = linked;
this.childrenResult = new Plan[rewriteJobContext.plan.arity()];
this.jobStack = jobStack;
}
@Override
public void execute() {
if (!linked) {
linked = true;
jobStack.push(this);
for (int i = rewriteJobContext.childrenContext.length - 1; i >= 0; i--) {
RewriteJobContext childContext = rewriteJobContext.childrenContext[i];
if (childContext != null) {
jobStack.push(new LinkPlanJob(context, childContext, this, false, jobStack));
}
}
} else if (rewriteJobContext.result != null) {
linkResult(rewriteJobContext.result);
} else {
Plan[] newChildren = new Plan[childrenResult.length];
for (int i = 0; i < newChildren.length; i++) {
Plan childResult = childrenResult[i];
if (childResult == null) {
childResult = rewriteJobContext.plan.child(i);
}
newChildren[i] = childResult;
}
linkResult(rewriteJobContext.plan.withChildren(newChildren));
}
}
private void linkResult(Plan result) {
if (parentJob != null) {
parentJob.childrenResult[rewriteJobContext.childIndexInParentContext] = result;
} else {
this.result = result;
}
}
}
}