// PlanReceiver.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.jobs.joinorder.hypergraphv2.receiver;
import org.apache.doris.nereids.jobs.JobContext;
import org.apache.doris.nereids.jobs.cascades.DeriveStatsJob;
import org.apache.doris.nereids.jobs.cascades.OptimizeGroupExpressionJob;
import org.apache.doris.nereids.jobs.joinorder.hypergraphv2.HyperGraph;
import org.apache.doris.nereids.jobs.joinorder.hypergraphv2.bitmap.LongBitmap;
import org.apache.doris.nereids.jobs.joinorder.hypergraphv2.edge.Edge;
import org.apache.doris.nereids.memo.CopyInResult;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.memo.Memo;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.expressions.literal.TinyIntLiteral;
import org.apache.doris.nereids.trees.plans.GroupPlan;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.base.Preconditions;
import java.util.ArrayList;
import java.util.BitSet;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
/**
 * The dp table of the join-order enumeration (the "dp table" of the DPhyp paper).
 * <p>
 * The receiver caches the {@link Group} built for every node set that has been emitted so far. When a new
 * csg-cmp pair is emitted it combines the two cached sub-plans into a join (plus a project when needed),
 * copies the result into the memo and stores the owning group under the united node set.
 */
public class PlanReceiver extends AbstractReceiver {
    // input slots of finalProjects; they seed the required slots of every sub-plan
    final Set<Slot> finalRequiredSlots;
    // projects taken from the hyper graph; put on top of the plan that covers all nodes when its output differs
    final List<NamedExpression> finalProjects;
    // dp table to cache all valid sub-plans, keyed by the bitmap of the nodes the plan covers
    HashMap<Long, Group> planTable = new HashMap<>();
    // for every node bitmap, the indexes of the join edges already consumed by its sub-plan
    HashMap<Long, BitSet> usdEdges = new HashMap<>();
    // limit defines the max number of csg-cmp pairs accepted by this receiver
    int limit;
    // number of csg-cmp pairs that passed the missed-edge check since the last reset()
    int emitCount = 0;
    JobContext jobContext;
    HyperGraph hyperGraph;
    // bitmap returned by hyperGraph.getNodesMap(); a union equal to it means "all nodes are joined"
    long allNodeBitmap;
    // wall-clock start (ms) of the current enumeration; refreshed by reset()
    long startTime = System.currentTimeMillis();
    // budget compared against the milliseconds elapsed since startTime
    // NOTE(review): ConnectContext.get() is dereferenced unguarded here - confirm a context is always bound
    long timeLimit = ConnectContext.get().getSessionVariable().joinReorderTimeLimit;
    // true when the pair handled by the latest emitCsgCmp call covers every node of the graph
    boolean fullKeyEmitted = false;
    // only ever assigned EmitState.NONE inside this class
    EmitState emitState = EmitState.NONE;
    // why the latest emitCsgCmp call on the full node set was rejected;
    // cleared at the start of every emitCsgCmp call and by reset()
    boolean missingEdgeFail = false;
    boolean conflictRuleFail = false;
    boolean joinTypeError = false;

    /**
     * Creates a receiver for one hyper graph.
     *
     * @param jobContext the job context whose cascades context owns the memo and the job pool
     * @param limit the max number of csg-cmp pairs this receiver accepts before reporting failure
     * @param hyperGraph the graph that provides the join edges, the final projects and the node bitmap
     */
    public PlanReceiver(JobContext jobContext, int limit, HyperGraph hyperGraph) {
        this.jobContext = jobContext;
        this.limit = limit;
        this.hyperGraph = hyperGraph;
        this.finalProjects = hyperGraph.getFinalProjects();
        this.finalRequiredSlots = ExpressionUtils.getInputSlotSet(finalProjects);
        this.allNodeBitmap = hyperGraph.getNodesMap();
    }

    /**
     * Emit a new plan from bottom to top
     * <p>
     * The purpose of EmitCsgCmp is to combine the optimal plans for S1 and S2 into a csg-cmp-pair.
     * It requires calculating the proper join predicate and costs of the resulting joins.
     * In the end, update dpTables.
     * <p>
     * Side effect: the missed edges found for this pair are appended to {@code edges}, so the list
     * passed in must be mutable.
     *
     * @param left the bitmap of left child tree
     * @param right the bitmap of the right child tree
     * @param edges the join conditions that can be added in this operator
     * @return {@link EmitState#SUCCESS} when the joined plan has been copied into the memo and optimized;
     *         {@link EmitState#CONTINUE} when this pair is skipped (an enforced-order edge is missing, the
     *         conflict rules reject it, or no join type can be extracted);
     *         {@link EmitState#FAIL} when the emit count limit or the time limit is exceeded
     * @throws IllegalArgumentException if no plan has been cached for {@code left} or {@code right}
     */
    @Override
    public EmitState emitCsgCmp(long left, long right, List<Edge> edges) {
        Preconditions.checkArgument(planTable.containsKey(left), "no plan cached for left node set %s", left);
        Preconditions.checkArgument(planTable.containsKey(right), "no plan cached for right node set %s", right);
        clearFailureFlags();
        fullKeyEmitted = LongBitmap.newBitmapUnion(left, right) == allNodeBitmap;

        List<Edge> missingEdges = new ArrayList<>();
        if (!processMissedEdges(left, right, edges, missingEdges)) {
            // the failure reason is only recorded when the rejected pair covers all nodes
            missingEdgeFail = fullKeyEmitted;
            return EmitState.CONTINUE;
        }

        emitCount += 1;
        if (emitCount > limit || System.currentTimeMillis() - startTime > timeLimit) {
            return EmitState.FAIL;
        }

        edges.addAll(missingEdges);
        if (!checkConflictRule(left, right, edges)) {
            conflictRuleFail = fullKeyEmitted;
            return EmitState.CONTINUE;
        }

        if ((edges.get(0).getLeftExtendedNodes() & left) == 0) {
            // the first edge's left side does not touch `left`: swap left and right
            long tmp = left;
            left = right;
            right = tmp;
        }

        Memo memo = jobContext.getCascadesContext().getMemo();
        GroupPlan leftPlan = planTable.get(left).getGroupPlan();
        GroupPlan rightPlan = planTable.get(right).getGroupPlan();

        // First, build the logical join (and project) for this pair
        List<Expression> hashConjuncts = new ArrayList<>();
        List<Expression> otherConjuncts = new ArrayList<>();
        JoinType joinType = Edge.extractJoinTypeAndConjuncts(edges, missingEdges, hashConjuncts, otherConjuncts);
        if (joinType == null) {
            joinTypeError = fullKeyEmitted;
            return EmitState.CONTINUE;
        }

        long fullKey = LongBitmap.newBitmapUnion(left, right);
        LogicalPlan logicalJoin = proposeJoin(joinType, leftPlan, rightPlan, hashConjuncts, otherConjuncts);
        LogicalPlan logicalPlan = proposeProject(logicalJoin, edges, left, right);

        // Second, copy the plan into the group of the united node set and run the optimize / stats jobs
        Group group = planTable.computeIfAbsent(fullKey,
                key -> memo.newGroup(logicalPlan.getLogicalProperties()));
        CopyInResult copyInResult = memo.copyIn(logicalPlan, group, planTable);
        proposeAllDistributedPlans(copyInResult.correspondingExpression);
        return EmitState.SUCCESS;
    }

    // forget why a previous emit on the full node set was rejected
    private void clearFailureFlags() {
        missingEdgeFail = false;
        conflictRuleFail = false;
        joinTypeError = false;
    }

    // be aware that the requiredOutputSlots is a superset of the actual output of current node
    // check proposeProject method to get how to create a project node for the outputs of current node.
    // Side effect: records the used edges of the united node set in usdEdges.
    private Set<Slot> calculateRequiredSlots(long left, long right, List<Edge> edges) {
        // required output slots = final outputs + slots of unused edges
        // 1. add finalOutputs to requiredOutputSlots
        Set<Slot> requiredOutputSlots = new HashSet<>(this.finalRequiredSlots);

        // used edges of the united node set = used edges of both children + edges applied by this join
        BitSet usedEdgesBitmap = new BitSet();
        usedEdgesBitmap.or(usdEdges.get(left));
        usedEdgesBitmap.or(usdEdges.get(right));
        for (Edge edge : edges) {
            usedEdgesBitmap.set(edge.getIndex());
        }
        usdEdges.put(LongBitmap.newBitmapUnion(left, right), usedEdgesBitmap);

        // 2. add unused edges' input slots to requiredOutputSlots
        for (Edge edge : hyperGraph.getJoinEdges()) {
            if (!usedEdgesBitmap.get(edge.getIndex())) {
                requiredOutputSlots.addAll(edge.getInputSlots());
            }
        }
        return requiredOutputSlots;
    }

    // Collect all missed edges that are needed to connect left and right. Consider the sql below:
    // select * from t0 join t1 on t0.c1 = t1.c1 join t2 on t0.c2 = t2.c2 and t0.c1 = t1.c1 + t2.c2;
    // the hyperGraph's joinEdges are:
    // 0 = "<{0} --INNER_JOIN-- {1}>"
    // 1 = "<{0} --INNER_JOIN-- {2}>"
    // 2 = "<{0, 1} --INNER_JOIN-- {2}>"
    // the hyper predicate t0.c1 = t1.c1 + t2.c2 is encoded as hyper edge 2.
    // The hyper edge (Edge2) means 0 and 1 must be joined before joining 2 according to the paper.
    // Unfortunately, it's not correct, because we can join 0 and 2 first and then join 1.
    // Ideally, we should create a new Edge <{0, 2} --INNER_JOIN-- {1}> based on Edge2 to solve the problem.
    // The root cause is that a hyper predicate should be encoded as one or more hyper edges in different
    // scenarios, but we are not able to do so in all cases (complex expressions and outer joins).
    // So we use processMissedEdges to find all valid edges when joining 0, 1, 2 as a fallback plan.
    //
    // Returns false as soon as a missed edge with enforced order is found (missingEdges may then hold a
    // partial result); otherwise every missed edge has been appended to missingEdges and true is returned.
    private boolean processMissedEdges(long left, long right, List<Edge> edges, List<Edge> missingEdges) {
        // find all used edges
        BitSet usedEdgesBitmap = new BitSet();
        usedEdgesBitmap.or(usdEdges.get(left));
        usedEdgesBitmap.or(usdEdges.get(right));
        for (Edge edge : edges) {
            usedEdgesBitmap.set(edge.getIndex());
        }

        // find all referenced nodes
        long allReferenceNodes = LongBitmap.or(left, right);

        // find the edges which are not in usedEdgesBitmap and whose referenced nodes are a subset of
        // allReferenceNodes
        for (Edge edge : hyperGraph.getJoinEdges()) {
            boolean missed = LongBitmap.isSubset(edge.getReferenceNodes(), allReferenceNodes)
                    && !usedEdgesBitmap.get(edge.getIndex());
            if (!missed) {
                continue;
            }
            if (edge.isEnforcedOrder()) {
                return false;
            }
            // add the missed edge to missingEdges
            missingEdges.add(edge);
        }
        return true;
    }

    // push the optimize job (and the stats job when stats are not derived yet) for the expression,
    // then execute the job pool
    private void proposeAllDistributedPlans(GroupExpression groupExpression) {
        jobContext.getCascadesContext().pushJob(new OptimizeGroupExpressionJob(groupExpression,
                new JobContext(jobContext.getCascadesContext(), PhysicalProperties.ANY)));
        if (!groupExpression.isStatDerived()) {
            jobContext.getCascadesContext().pushJob(new DeriveStatsJob(groupExpression,
                    jobContext.getCascadesContext().getCurrentJobContext()));
        }
        jobContext.getCascadesContext().getJobScheduler().executeJobPool(jobContext.getCascadesContext());
    }

    // build the logical join of the two sub-plans; null is passed as the last constructor argument
    private LogicalPlan proposeJoin(JoinType joinType, Plan left, Plan right, List<Expression> hashConjuncts,
            List<Expression> otherConjuncts) {
        return new LogicalJoin<>(joinType, hashConjuncts, otherConjuncts, left, right, null);
    }

    /**
     * Registers the group of a single node as the initial entry of the dp table. When the node outputs
     * slots that are not required, a project is put on top and the project's owner group is stored instead.
     *
     * @param bitmap the bitmap of the node; exactly one bit must be set
     * @param group the group of the node
     * @throws IllegalArgumentException if the cardinality of {@code bitmap} is not 1
     */
    @Override
    public void addGroup(long bitmap, Group group) {
        Preconditions.checkArgument(LongBitmap.getCardinality(bitmap) == 1);
        // a single node has consumed no edge; the entry must exist before proposeProject reads it
        usdEdges.put(bitmap, new BitSet());
        Plan plan = proposeProject(group.getGroupPlan(), new ArrayList<>(), bitmap, bitmap);
        if (!(plan instanceof GroupPlan)) {
            CopyInResult copyInResult = jobContext.getCascadesContext().getMemo().copyIn(plan, null, false, planTable);
            group = copyInResult.correspondingExpression.getOwnerGroup();
        }
        planTable.put(bitmap, group);
    }

    /**
     * @param bitmap the bitmap of a node set
     * @return true if a plan has been cached for exactly this node set
     */
    @Override
    public boolean contain(long bitmap) {
        return planTable.containsKey(bitmap);
    }

    /**
     * Drops all cached plans and used-edge records, and restarts the emit counter, the timer and the
     * per-emit state so that the receiver can serve a new enumeration.
     */
    @Override
    public void reset() {
        emitCount = 0;
        planTable.clear();
        usdEdges.clear();
        startTime = System.currentTimeMillis();
        emitState = EmitState.NONE;
        fullKeyEmitted = false;
        // keep the failure reasons consistent with the rest of the per-emit state
        clearFailureFlags();
    }

    /**
     * @param bitmap the bitmap of a node set
     * @return the group cached for this node set, or null if none has been emitted
     */
    @Override
    public Group getBestPlan(long bitmap) {
        return planTable.get(bitmap);
    }

    // Put a project on top of `join` when its output has to be changed:
    // - the plan covers all nodes and its output differs from finalProjects: project finalProjects;
    // - otherwise: keep only the output slots that are still required (see calculateRequiredSlots),
    //   and return `join` untouched when nothing would be pruned.
    private LogicalPlan proposeProject(LogicalPlan join, List<Edge> edges, long left, long right) {
        Set<Slot> outputSet = join.getOutputSet();
        boolean coversAllNodes = LongBitmap.newBitmapUnion(left, right) == allNodeBitmap;
        if (coversAllNodes && !outputSet.equals(new HashSet<>(finalProjects))) {
            // add final project for the join cluster
            return new LogicalProject<>(finalProjects, join);
        }

        // calculate required columns by all parents
        Set<Slot> requireSlots = calculateRequiredSlots(left, right, edges);
        List<NamedExpression> allProjects = new ArrayList<>(outputSet.size());
        for (Slot slot : outputSet) {
            if (requireSlots.contains(slot)) {
                allProjects.add(slot);
            }
        }
        // no slot is required: project a constant so that the project list is not empty
        if (allProjects.isEmpty()) {
            allProjects.add(new Alias(new ExprId(-1), new TinyIntLiteral((byte) 1)));
        }
        if (outputSet.equals(new HashSet<>(allProjects))) {
            return join;
        }
        // propose logical project
        return new LogicalProject<>(allProjects, join);
    }
}