PlanReceiver.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraph.receiver;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupExpressionJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.Edge;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.edge.JoinEdge;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.stream.Collectors;
import java.util.stream.Stream;
/**
* The Receiver is used for cached the plan that has been emitted and build the new plan
*/
public class PlanReceiver implements AbstractReceiver {
// limit define the max number of csg-cmp pair in this Receiver
HashMap<Long, Group> planTable = new HashMap<>();
HashMap<Long, BitSet> usdEdges = new HashMap<>();
HashMap<Long, List<NamedExpression>> complexProjectMap = new HashMap<>();
int limit;
int emitCount = 0;
JobContext jobContext;
HyperGraph hyperGraph;
final Set<Slot> finalOutputs;
long startTime = System.currentTimeMillis();
long timeLimit = ConnectContext.get().getSessionVariable().joinReorderTimeLimit;
public PlanReceiver(JobContext jobContext, int limit, HyperGraph hyperGraph, Set<Slot> outputs) {
this.jobContext = jobContext;
this.limit = limit;
this.hyperGraph = hyperGraph;
this.finalOutputs = outputs;
}
/**
* Emit a new plan from bottom to top
* <p>
* The purpose of EmitCsgCmp is to combine the optimal plans for S1 and S2 into a csg-cmp-pair.
* It requires calculating the proper join predicate and costs of the resulting joins.
* In the end, update dpTables.
*
* @param left the bitmap of left child tree
* @param right the bitmap of the right child tree
* @param edges the join conditions that can be added in this operator
* @return the left and the right can be connected by the edge
*/
@Override
public boolean emitCsgCmp(long left, long right, List<JoinEdge> edges) {
Preconditions.checkArgument(planTable.containsKey(left));
Preconditions.checkArgument(planTable.containsKey(right));
processMissedEdges(left, right, edges);
emitCount += 1;
if (emitCount > limit || System.currentTimeMillis() - startTime > timeLimit) {
return false;
}
Memo memo = jobContext.getCascadesContext().getMemo();
GroupPlan leftPlan = new GroupPlan(planTable.get(left));
GroupPlan rightPlan = new GroupPlan(planTable.get(right));
// First, we implement all possible physical plans
// In this step, we don't generate logical expression because they are useless in DPhyp.
List<Expression> hashConjuncts = new ArrayList<>();
List<Expression> otherConjuncts = new ArrayList<>();
JoinType joinType = JoinEdge.extractJoinTypeAndConjuncts(edges, hashConjuncts, otherConjuncts);
if (joinType == null) {
return true;
}
long fullKey = LongBitmap.newBitmapUnion(left, right);
LogicalPlan logicalPlan = proposeJoin(joinType, leftPlan, rightPlan, hashConjuncts,
otherConjuncts);
logicalPlan = proposeProject(logicalPlan, edges, left, right);
// Second, we copy all physical plan to Group and generate properties and calculate cost
if (!planTable.containsKey(fullKey)) {
planTable.put(fullKey, memo.newGroup(logicalPlan.getLogicalProperties()));
}
Group group = planTable.get(fullKey);
CopyInResult copyInResult = memo.copyIn(logicalPlan, group, false, planTable);
proposeAllDistributedPlans(copyInResult.correspondingExpression);
return true;
}
// be aware that the requiredOutputSlots is a superset of the actual output of current node
// check proposeProject method to get how to create a project node for the outputs of current node.
private Set<Slot> calculateRequiredSlots(long left, long right, List<JoinEdge> edges) {
// required output slots = final outputs + slot of unused edges + complex project exprs(if there is any)
// 1. add finalOutputs to requiredOutputSlots
Set<Slot> requiredOutputSlots = new HashSet<>(this.finalOutputs);
BitSet usedEdgesBitmap = new BitSet();
usedEdgesBitmap.or(usdEdges.get(left));
usedEdgesBitmap.or(usdEdges.get(right));
for (Edge edge : edges) {
usedEdgesBitmap.set(edge.getIndex());
}
// 2. add unused edges' input slots to requiredOutputSlots
usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap);
for (Edge edge : hyperGraph.getJoinEdges()) {
if (!usedEdgesBitmap.get(edge.getIndex())) {
requiredOutputSlots.addAll(edge.getInputSlots());
}
}
// 3. add input slots of all complex projects which should be done by all upper level (parents) nodes
// dphyper enumerate subsets before supersets, so all subsets' complex projects should be excluded here
// because it's been processed by subsets already
long fullKey = LongBitmap.newBitmapUnion(left, right);
hyperGraph.getComplexProject().entrySet().stream()
.filter(l -> !LongBitmap.isSubset(l.getKey(), fullKey))
.flatMap(l -> l.getValue().stream())
.forEach(expr -> requiredOutputSlots.addAll(expr.getInputSlots()));
return requiredOutputSlots;
}
// add any missed edge into edges to connect left and right
private void processMissedEdges(long left, long right, List<JoinEdge> edges) {
// find all used edges
BitSet usedEdgesBitmap = new BitSet();
usedEdgesBitmap.or(usdEdges.get(left));
usedEdgesBitmap.or(usdEdges.get(right));
edges.forEach(edge -> usedEdgesBitmap.set(edge.getIndex()));
// find all referenced nodes
long allReferenceNodes = LongBitmap.or(left, right);
// find the edge which is not in usedEdgesBitmap and its referenced nodes is subset of allReferenceNodes
for (JoinEdge edge : hyperGraph.getJoinEdges()) {
long referenceNodes = LongBitmap.newBitmapUnion(edge.getLeftRequiredNodes(), edge.getRightRequiredNodes());
if (LongBitmap.isSubset(referenceNodes, allReferenceNodes)
&& !usedEdgesBitmap.get(edge.getIndex())) {
// add the missed edge to edges
edges.add(edge);
}
}
}
private void proposeAllDistributedPlans(GroupExpression groupExpression) {
jobContext.getCascadesContext().pushJob(new OptimizeGroupExpressionJob(groupExpression,
new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY, Double.MAX_VALUE)));
if (!groupExpression.isStatDerived()) {
jobContext.getCascadesContext().pushJob(new DeriveStatsJob(groupExpression,
jobContext.getCascadesContext().getCurrentJobContext()));
}
jobContext.getCascadesContext().getJobScheduler().executeJobPool(jobContext.getCascadesContext());
}
private LogicalPlan proposeJoin(JoinType joinType, Plan left, Plan right, List<Expression> hashConjuncts,
List<Expression> otherConjuncts) {
return new LogicalJoin<>(joinType, hashConjuncts, otherConjuncts, left, right, null);
}
@Override
public void addGroup(long bitmap, Group group) {
Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
usdEdges.put(bitmap, new BitSet());
Plan plan = proposeProject(new GroupPlan(group), new ArrayList<>(), bitmap, bitmap);
if (!(plan instanceof GroupPlan)) {
CopyInResult copyInResult = jobContext.getCascadesContext().getMemo().copyIn(plan, null, false, planTable);
group = copyInResult.correspondingExpression.getOwnerGroup();
}
planTable.put(bitmap, group);
usdEdges.put(bitmap, new BitSet());
}
@Override
public boolean contain(long bitmap) {
return planTable.containsKey(bitmap);
}
@Override
public void reset() {
emitCount = 0;
planTable.clear();
usdEdges.clear();
complexProjectMap.clear();
complexProjectMap.putAll(hyperGraph.getComplexProject());
startTime = System.currentTimeMillis();
}
@Override
public Group getBestPlan(long bitmap) {
return planTable.get(bitmap);
}
private LogicalPlan proposeProject(LogicalPlan join, List<JoinEdge> edges, long left, long right) {
long fullKey = LongBitmap.newBitmapUnion(left, right);
List<Slot> outputs = join.getOutput();
Set<Slot> outputSet = join.getOutputSet();
List<NamedExpression> complexProjects = new ArrayList<>();
// Calculate complex expression should be done by current(fullKey) node
// the complex projects includes final output of current node(the complex project of fullKey)
// and any complex projects don't belong to subsets of fullKey except that fullKey is not a join node
List<Long> bitmaps = complexProjectMap.keySet().stream().filter(bitmap -> LongBitmap
.isSubset(bitmap, fullKey)
&& ((!LongBitmap.isSubset(bitmap, left) && !LongBitmap.isSubset(bitmap, right))
|| left == right))
.collect(Collectors.toList());
// complexProjectMap is created by a bottom up traverse of join tree, so child node is put before parent node
// in the bitmaps
bitmaps.sort(Long::compare);
for (long bitmap : bitmaps) {
if (complexProjects.isEmpty()) {
complexProjects.addAll(complexProjectMap.get(bitmap));
} else {
// Rewrite project expression by its children
complexProjects.addAll(
PlanUtils.mergeProjections(complexProjects, complexProjectMap.get(bitmap)));
}
}
// calculate required columns by all parents
Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges);
List<NamedExpression> allProjects = Stream.concat(
outputs.stream().filter(requireSlots::contains),
complexProjects.stream().filter(e -> requireSlots.contains(e.toSlot()))
).collect(Collectors.toList());
// propose logical project
if (allProjects.isEmpty()) {
allProjects.add(ExpressionUtils.selectMinimumColumn(outputs));
}
if (outputSet.equals(new HashSet<>(allProjects))) {
return join;
}
Set<Slot> childOutputSet = join.getOutputSet();
List<NamedExpression> projects = allProjects.stream()
.filter(expr ->
childOutputSet.containsAll(expr.getInputSlots()))
.collect(Collectors.toList());
LogicalPlan project = join;
if (!outputSet.equals(new HashSet<>(projects))) {
project = new LogicalProject<>(projects, join);
}
Preconditions.checkState(!projects.isEmpty() && projects.size() == allProjects.size(),
" there are some projects left %s %s", projects, allProjects);
return project;
}
}