AbstractBatchJobExecutor.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.executor;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.rewrite.CostBasedRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.CustomRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteBottomUpJob;
import org.apache.doris.nereids.jobs.rewrite.PlanTreeRewriteTopDownJob;
import org.apache.doris.nereids.jobs.rewrite.RewriteJob;
import org.apache.doris.nereids.jobs.rewrite.RootPlanTreeRewriteJob;
import org.apache.doris.nereids.jobs.rewrite.TopicRewriteJob;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleFactory;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.visitor.CustomRewriter;
import com.google.common.collect.ImmutableList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import java.util.Set;
import java.util.function.Predicate;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
 * Base class for executing all jobs.
 *
 * <p>Each batch of rules will be uniformly executed: subclasses supply an ordered list of
 * {@link RewriteJob}s through {@link #getJobs()}, and {@link #execute()} runs them in order against
 * the current {@link JobContext} of the {@link CascadesContext}.
 */
public abstract class AbstractBatchJobExecutor {
    /**
     * Plan classes whose children must not be traversed by rewrite jobs built on the current thread.
     * Non-null only inside {@link #notTraverseChildrenOf(Set, Supplier)}.
     */
    private static final ThreadLocal<Set<Class<? extends Plan>>> NOT_TRAVERSE_CHILDREN = new ThreadLocal<>();
    private static final Predicate<Plan> TRAVERSE_ALL_PLANS = plan -> true;

    protected CascadesContext cascadesContext;

    /**
     * @param cascadesContext context whose current job context is used by {@link #execute()}
     * @throws NullPointerException if {@code cascadesContext} is null
     */
    public AbstractBatchJobExecutor(CascadesContext cascadesContext) {
        this.cascadesContext = Objects.requireNonNull(cascadesContext, "cascadesContext can not be null");
    }

    /**
     * Builds a runnable job list: null entries are dropped and every {@link TopicRewriteJob} is replaced
     * by its own non-null jobs (one level of flattening only).
     *
     * @param jobs jobs to flatten; null elements are allowed and ignored
     * @return an immutable list of the resulting jobs, in the given order
     */
    public static List<RewriteJob> jobs(RewriteJob... jobs) {
        return Arrays.stream(jobs)
                .filter(Objects::nonNull)
                .flatMap(job -> job instanceof TopicRewriteJob
                        ? ((TopicRewriteJob) job).jobs.stream().filter(Objects::nonNull)
                        : Stream.of(job)
                ).collect(ImmutableList.toImmutableList());
    }

    /**
     * Runs {@code action} with the given plan classes marked as "do not traverse children".
     *
     * <p>The setting is read by {@link #bottomUp} / {@link #topDown} when they BUILD a job, so the
     * job construction must happen inside {@code action}. Nested calls are supported: the previous
     * setting of the current thread is restored when {@code action} finishes (normally or not).
     *
     * @param notTraverseClasses exact plan classes whose children are skipped
     * @param action code that builds the rewrite jobs
     * @param <T> result type of {@code action}
     * @return whatever {@code action} returns
     */
    public static <T> T notTraverseChildrenOf(
            Set<Class<? extends Plan>> notTraverseClasses, Supplier<T> action) {
        // remember the outer scope's setting so a nested call does not wipe it out on exit
        Set<Class<? extends Plan>> previous = NOT_TRAVERSE_CHILDREN.get();
        NOT_TRAVERSE_CHILDREN.set(notTraverseClasses);
        try {
            return action.get();
        } finally {
            if (previous == null) {
                NOT_TRAVERSE_CHILDREN.remove();
            } else {
                NOT_TRAVERSE_CHILDREN.set(previous);
            }
        }
    }

    /** Groups {@code jobs} under a named {@link TopicRewriteJob}. */
    public static TopicRewriteJob topic(String topicName, RewriteJob... jobs) {
        return new TopicRewriteJob(topicName, Arrays.asList(jobs));
    }

    /** Wraps {@code jobs} in a {@link CostBasedRewriteJob}. */
    public static RewriteJob costBased(RewriteJob... jobs) {
        return new CostBasedRewriteJob(Arrays.asList(jobs));
    }

    /** Varargs convenience for {@link #bottomUp(List)}. */
    public static RewriteJob bottomUp(RuleFactory... ruleFactories) {
        return bottomUp(Arrays.asList(ruleFactories));
    }

    /**
     * Builds a bottom-up plan-tree rewrite job over all rules of the given factories.
     * The job is created with {@code once = true}; the traverse predicate is captured now.
     */
    public static RewriteJob bottomUp(List<RuleFactory> ruleFactories) {
        List<Rule> rules = ruleFactories.stream()
                .map(RuleFactory::buildRules)
                .flatMap(List::stream)
                .collect(ImmutableList.toImmutableList());
        return new RootPlanTreeRewriteJob(rules, PlanTreeRewriteBottomUpJob::new, getTraversePredicate(), true);
    }

    /** Varargs convenience for {@link #topDown(List)}. */
    public static RewriteJob topDown(RuleFactory... ruleFactories) {
        return topDown(Arrays.asList(ruleFactories));
    }

    /** Equivalent to {@code topDown(ruleFactories, true)}. */
    public static RewriteJob topDown(List<RuleFactory> ruleFactories) {
        return topDown(ruleFactories, true);
    }

    /**
     * Builds a top-down plan-tree rewrite job over all rules of the given factories.
     * The traverse predicate is captured now, from the enclosing {@link #notTraverseChildrenOf} scope.
     *
     * @param ruleFactories factories whose rules are applied
     * @param once forwarded to {@link RootPlanTreeRewriteJob} as its {@code once} flag
     */
    public static RewriteJob topDown(List<RuleFactory> ruleFactories, boolean once) {
        List<Rule> rules = ruleFactories.stream()
                .map(RuleFactory::buildRules)
                .flatMap(List::stream)
                .collect(ImmutableList.toImmutableList());
        return new RootPlanTreeRewriteJob(rules, PlanTreeRewriteTopDownJob::new, getTraversePredicate(), once);
    }

    /** Builds a {@link CustomRewriteJob} of the given rule type backed by {@code planRewriter}. */
    public static RewriteJob custom(RuleType ruleType, Supplier<CustomRewriter> planRewriter) {
        return new CustomRewriteJob(planRewriter, ruleType);
    }

    /**
     * Executes every job of {@link #getJobs()} in order.
     *
     * <p>Before a {@link CostBasedRewriteJob} runs, the job context receives the jobs that follow it
     * (excluding other cost-based jobs) as its remaining jobs. A job whose {@code isOnce()} is false
     * is re-run until a pass leaves the job context's rewritten flag unset.
     */
    public void execute() {
        // getJobs() is abstract and may build a new list per call; read it once so every index,
        // size and subList below refers to the same list
        List<RewriteJob> allJobs = getJobs();
        for (int i = 0; i < allJobs.size(); i++) {
            JobContext jobContext = cascadesContext.getCurrentJobContext();
            RewriteJob currentJob = allJobs.get(i);
            if (currentJob instanceof CostBasedRewriteJob) {
                List<RewriteJob> remainJobs = allJobs.subList(i + 1, allJobs.size()).stream()
                        .filter(j -> !(j instanceof CostBasedRewriteJob))
                        .collect(Collectors.toList());
                jobContext.setRemainJobs(remainJobs);
            }
            do {
                jobContext.setRewritten(false);
                currentJob.execute(jobContext);
            } while (!currentJob.isOnce() && jobContext.isRewritten());
        }
    }

    /** @return the ordered jobs that {@link #execute()} runs */
    public abstract List<RewriteJob> getJobs();

    /** Returns the traverse predicate for jobs built right now on the current thread. */
    private static Predicate<Plan> getTraversePredicate() {
        Set<Class<? extends Plan>> notTraverseChildren = NOT_TRAVERSE_CHILDREN.get();
        return notTraverseChildren == null
                ? TRAVERSE_ALL_PLANS
                : new NotTraverseChildren(notTraverseChildren);
    }

    /**
     * Rejects plans whose exact runtime class is in the configured set (subclasses are not matched).
     */
    private static class NotTraverseChildren implements Predicate<Plan> {
        private final Set<Class<? extends Plan>> notTraverseChildren;

        public NotTraverseChildren(Set<Class<? extends Plan>> notTraverseChildren) {
            this.notTraverseChildren = Objects.requireNonNull(notTraverseChildren, "notTraversePlans can not be null");
        }

        @Override
        public boolean test(Plan plan) {
            return !notTraverseChildren.contains(plan.getClass());
        }
    }
}