LeadingHint.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.hint;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.jobs.joinorder.hypergraph.bitmap.LongBitmap;
import org.apache.doris.nereids.trees.expressions.ExprId;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.plans.DistributeType;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalFilter;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.logical.LogicalRelation;
import org.apache.doris.nereids.trees.plans.logical.LogicalSubQueryAlias;
import org.apache.doris.nereids.util.JoinUtils;
import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.Maps;
import com.google.common.collect.Sets;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Set;
import java.util.Stack;
/**
 * Leading hint: fixes the join order, and optionally the distribute type of each join, of a query block.
 * e.g. leading(t1 shuffle {t2 broadcast t3})
 */
public class LeadingHint extends Hint {
    // hint text as written by the user, returned by getExplainString() when the hint is not successful
    private String originalString = "";
    // hint parameters with a "join" (or "shuffle"/"broadcast") operator between every two adjacent operands
    private List<String> addJoinParameters;
    // addJoinParameters in reverse polish notation, consumed by generateLeadingJoinPlan()
    private List<String> normalizedParameters;
    // table names in the order they appear in the hint
    private final List<String> tablelist = new ArrayList<>();
    // distribute hints keyed by the hash code of "shuffle"/"broadcast", so every occurrence of the same keyword
    // shares one entry; key 0 is overwritten with the hint of the most recently built join
    private final Map<Integer, DistributeHint> distributeHints = new HashMap<>();
    // relation id -> logical plan recorded when binding
    private final Map<RelationId, LogicalPlan> relationIdToScanMap = Maps.newLinkedHashMap();
    // (relation id, table name) pairs; relation id is unique but table name is not
    private final List<Pair<RelationId, String>> relationIdAndTableName = new ArrayList<>();
    // not written by this class; presumably filled by callers through getExprIdToTableNameMap()
    private final Map<ExprId, String> exprIdToTableNameMap = Maps.newLinkedHashMap();
    // (bitmap of tables the predicate needs, predicate); entries are removed once placed on a scan or a join
    private final List<Pair<Long, Expression>> filters = new ArrayList<>();
    // join condition -> join type it had in the original plan
    private final Map<Expression, JoinType> conditionJoinType = Maps.newLinkedHashMap();
    // join constraints consulted by getJoinConstraint()
    private final List<JoinConstraint> joinConstraintList = new ArrayList<>();
    // set by callers through setInnerJoinBitmap()
    private Long innerJoinBitmap = 0L;
    // bitmap of every table named in the hint, computed by setTotalBitmap()
    private Long totalBitmap = 0L;

    public LeadingHint(String hintName) {
        super(hintName);
    }

    /**
     * Leading hint data structure before using.
     *
     * @param hintName Leading
     * @param parameters table names mixed with braces and distribute keywords, as written in the sql
     * @param originalString hint text as written in the sql, used by explain when the hint is not successful
     */
    public LeadingHint(String hintName, List<String> parameters, String originalString) {
        super(hintName);
        this.originalString = originalString;
        addJoinParameters = insertJoinIntoParameters(parameters);
        normalizedParameters = parseIntoReversePolishNotation(addJoinParameters);
    }

    /**
     * Insert an explicit "join" operator between adjacent operands of the leading parameters.
     * A "shuffle"/"broadcast" keyword takes the place of the "join" that would precede the next operand,
     * and no operator is placed right after "{" or right before "}".
     * e.g. [t1, {, t2, shuffle, t3, }] -> [t1, join, {, t2, shuffle, t3, }]
     *
     * @param list leading parameters as written in the sql
     * @return list of strings with join operators inserted between tables
     */
    public static List<String> insertJoinIntoParameters(List<String> list) {
        List<String> output = new ArrayList<>();
        for (String item : list) {
            if (item.equals("shuffle") || item.equals("broadcast")) {
                // the distribute keyword replaces the "join" appended after the previous operand
                output.remove(output.size() - 1);
                output.add(item);
                continue;
            } else if (item.equals("{")) {
                // no operator directly after an opening brace
                output.add(item);
                continue;
            } else if (item.equals("}")) {
                // drop the "join" appended after the last operand inside the braces
                output.remove(output.size() - 1);
                output.add(item);
            } else {
                output.add(item);
            }
            output.add("join");
        }
        // the last operand (or "}") was followed by a "join" that has no right operand;
        // guard against an empty parameter list, where there is nothing to remove
        if (!output.isEmpty()) {
            output.remove(output.size() - 1);
        }
        return output;
    }

    /**
     * Convert the leading parameters (with join operators inserted) into reverse polish notation, using a
     * shunting-yard pass in which every operator has the same precedence and braces act as parentheses.
     * As side effects, table names are appended to the table list and every "shuffle"/"broadcast" keyword
     * registers a distribute hint.
     * e.g. [t1, join, {, t2, shuffle, t3, }] -> [t1, t2, t3, shuffle, join]
     *
     * @param list leading parameters with join operators
     * @return reverse polish notation which can be turned directly into logical joins
     */
    public List<String> parseIntoReversePolishNotation(List<String> list) {
        Stack<String> operators = new Stack<>();
        List<String> output = new ArrayList<>();
        for (String item : list) {
            if (item.equals("{")) {
                operators.push(item);
            } else if (item.equals("}")) {
                // flush the operators of the braced group, then discard its "{"
                while (!operators.peek().equals("{")) {
                    output.add(operators.pop());
                }
                operators.pop();
            } else if (item.equals("join") || item.equals("shuffle") || item.equals("broadcast")) {
                if (item.equals("shuffle")) {
                    distributeHints.put(item.hashCode(), new DistributeHint(DistributeType.SHUFFLE_RIGHT));
                } else if (item.equals("broadcast")) {
                    distributeHints.put(item.hashCode(), new DistributeHint(DistributeType.BROADCAST_RIGHT));
                }
                // equal precedence, left associative: flush everything up to the enclosing brace
                while (!operators.isEmpty() && !operators.peek().equals("{")) {
                    output.add(operators.pop());
                }
                operators.push(item);
            } else {
                // an operand, i.e. a table name
                tablelist.add(item);
                output.add(item);
            }
        }
        while (!operators.isEmpty()) {
            output.add(operators.pop());
        }
        return output;
    }

    public List<String> getTablelist() {
        return tablelist;
    }

    public Map<RelationId, LogicalPlan> getRelationIdToScanMap() {
        return relationIdToScanMap;
    }

    @Override
    public String getExplainString() {
        if (!this.isSuccess()) {
            return originalString;
        }
        StringBuilder out = new StringBuilder();
        for (String parameter : addJoinParameters) {
            if (parameter.equals("join")) {
                // implicit operator inserted by insertJoinIntoParameters, not written by the user
                continue;
            }
            if (parameter.equals("shuffle") || parameter.equals("broadcast")) {
                // only print distribute keywords whose hint reports success
                DistributeHint distributeHint = distributeHints.get(parameter.hashCode());
                if (!distributeHint.isSuccess()) {
                    continue;
                }
            }
            out.append(parameter).append(' ');
        }
        return "leading(" + out + ")";
    }

    /**
     * Get logical plan by table name recorded in leading hint. If it can not be found, the leading hint has a
     * syntax error or needs to be updated, so a null return value must be handled by the caller.
     *
     * @param name table name
     * @return logical plan recorded when binding, or null; the status is set to SYNTAX_ERROR only when
     *         the table name itself is unknown
     */
    public LogicalPlan getLogicalPlanByName(String name) {
        RelationId id = findRelationIdAndTableName(name);
        if (id == null) {
            this.setStatus(HintStatus.SYNTAX_ERROR);
            this.setErrorMessage("can not find table: " + name);
            return null;
        }
        return relationIdToScanMap.get(id);
    }

    /**
     * Put a (relation id, table name) pair into the list. Every existing pair with the same relation id gets
     * its table name replaced, and every existing pair with the same table name gets its relation id replaced.
     * The pair is appended only when no existing pair matched either way.
     *
     * @param relationIdTableNamePair pair of relation id and table name to be inserted
     */
    public void putRelationIdAndTableName(Pair<RelationId, String> relationIdTableNamePair) {
        boolean isUpdate = false;
        for (Pair<RelationId, String> pair : relationIdAndTableName) {
            if (pair.first.equals(relationIdTableNamePair.first)) {
                pair.second = relationIdTableNamePair.second;
                isUpdate = true;
            }
            if (pair.second.equals(relationIdTableNamePair.second)) {
                pair.first = relationIdTableNamePair.first;
                isUpdate = true;
            }
        }
        if (!isUpdate) {
            relationIdAndTableName.add(relationIdTableNamePair);
        }
    }

    /**
     * Replace the relation id of every existing pair whose table name matches the given one.
     * The pair is appended when no existing pair has that table name.
     *
     * @param relationIdTableNamePair pair of relation id and table name to be inserted
     */
    public void updateRelationIdByTableName(Pair<RelationId, String> relationIdTableNamePair) {
        boolean isUpdate = false;
        for (Pair<RelationId, String> pair : relationIdAndTableName) {
            if (pair.second.equals(relationIdTableNamePair.second)) {
                pair.first = relationIdTableNamePair.first;
                isUpdate = true;
            }
        }
        if (!isUpdate) {
            relationIdAndTableName.add(relationIdTableNamePair);
        }
    }

    /**
     * Find the relation id of a table name. Relation id is unique, but table name is not,
     * so the first matching pair wins.
     *
     * @param name table name
     * @return relation id, or null when the name is unknown
     */
    public RelationId findRelationIdAndTableName(String name) {
        for (Pair<RelationId, String> pair : relationIdAndTableName) {
            if (pair.second.equals(name)) {
                return pair.first;
            }
        }
        return null;
    }

    // returns the first table name that appears more than once in the hint
    private Optional<String> hasSameName() {
        Set<String> tableSet = Sets.newHashSet();
        for (String table : tablelist) {
            if (!tableSet.add(table)) {
                return Optional.of(table);
            }
        }
        return Optional.empty();
    }

    public Map<ExprId, String> getExprIdToTableNameMap() {
        return exprIdToTableNameMap;
    }

    public List<Pair<Long, Expression>> getFilters() {
        return filters;
    }

    public void putConditionJoinType(Expression filter, JoinType joinType) {
        conditionJoinType.put(filter, joinType);
    }

    /**
     * Find out whether conditions can match their original join type. Types match when they are equal, or
     * when both are one-side outer joins, both are semi joins, or both are anti joins.
     *
     * @param conditions conditions needs to put on this join; each is expected to have been registered
     *                   through putConditionJoinType
     * @param joinType join type computed by join constraint
     * @return can conditions matched
     */
    public boolean isConditionJoinTypeMatched(List<Expression> conditions, JoinType joinType) {
        for (Expression condition : conditions) {
            JoinType originalJoinType = conditionJoinType.get(condition);
            if (originalJoinType.equals(joinType)
                    || originalJoinType.isOneSideOuterJoin() && joinType.isOneSideOuterJoin()
                    || originalJoinType.isSemiJoin() && joinType.isSemiJoin()
                    || originalJoinType.isAntiJoin() && joinType.isAntiJoin()) {
                continue;
            }
            return false;
        }
        return true;
    }

    public List<JoinConstraint> getJoinConstraintList() {
        return joinConstraintList;
    }

    public Long getInnerJoinBitmap() {
        return innerJoinBitmap;
    }

    public void setInnerJoinBitmap(Long innerJoinBitmap) {
        this.innerJoinBitmap = innerJoinBitmap;
    }

    public Long getTotalBitmap() {
        return totalBitmap;
    }

    /**
     * Set total bitmap used in leading before we get into leading join.
     * The status becomes SYNTAX_ERROR when a table is named twice, when a named table can not be found
     * (in which case the total bitmap is left unchanged), or when fewer tables are named than input relations.
     *
     * @param inputRelationSets relations of the query block the hint is applied to
     */
    public void setTotalBitmap(Set<RelationId> inputRelationSets) {
        Optional<String> duplicateTableName = hasSameName();
        if (duplicateTableName.isPresent()) {
            this.setStatus(HintStatus.SYNTAX_ERROR);
            this.setErrorMessage("duplicated table:" + duplicateTableName.get());
        }
        long bitmap = 0L;
        Set<RelationId> existRelationSets = new HashSet<>();
        for (String tableName : getTablelist()) {
            RelationId id = findRelationIdAndTableName(tableName);
            if (id == null) {
                this.setStatus(HintStatus.SYNTAX_ERROR);
                this.setErrorMessage("can not find table: " + tableName);
                return;
            }
            existRelationSets.add(id);
            bitmap = LongBitmap.set(bitmap, id.asInt());
        }
        if (getTablelist().size() < inputRelationSets.size()) {
            Set<RelationId> missRelationIds = new HashSet<>(inputRelationSets);
            missRelationIds.removeAll(existRelationSets);
            String missingTablenames = getMissingTableNames(missRelationIds);
            this.setStatus(HintStatus.SYNTAX_ERROR);
            this.setErrorMessage("leading should have all tables in query block, missing tables: " + missingTablenames);
        }
        this.totalBitmap = bitmap;
    }

    // space separated names of the given relations (each followed by a space), in set iteration order
    private String getMissingTableNames(Set<RelationId> missRelationIds) {
        StringBuilder missTableNames = new StringBuilder();
        for (RelationId id : missRelationIds) {
            for (Pair<RelationId, String> pair : relationIdAndTableName) {
                if (pair.first.equals(id)) {
                    missTableNames.append(pair.second).append(' ');
                }
            }
        }
        return missTableNames.toString();
    }

    /**
     * Check whether joining the two given inputs is legal with respect to the join constraints, and find the
     * constraint implemented by this join. If no constraint matches, the join is an inner join.
     * NOTE(review): the checks appear to follow PostgreSQL's join_is_legal() — confirm when changing them.
     *
     * @param joinTableBitmap table bitmap below this join
     * @param leftTableBitmap table bitmap below left child
     * @param rightTableBitmap table bitmap below right child
     * @return (null, false) when the join is illegal; (null, true) for an inner join; otherwise
     *         (matched constraint, true), where the constraint's reversed flag has been set to tell whether
     *         the two inputs are swapped relative to it
     */
    public Pair<JoinConstraint, Boolean> getJoinConstraint(Long joinTableBitmap, Long leftTableBitmap,
            Long rightTableBitmap) {
        boolean reversed = false;
        boolean mustBeLeftjoin = false;

        JoinConstraint matchedJoinConstraint = null;

        for (JoinConstraint joinConstraint : joinConstraintList) {
            if (joinConstraint.getJoinType().isFullOuterJoin()) {
                // a full outer join only matches when the inputs are exactly its two sides, in either order
                if (leftTableBitmap.equals(joinConstraint.getLeftHand())
                        && rightTableBitmap.equals(joinConstraint.getRightHand())
                        || rightTableBitmap.equals(joinConstraint.getLeftHand())
                        && leftTableBitmap.equals(joinConstraint.getRightHand())) {
                    if (matchedJoinConstraint != null) {
                        return Pair.of(null, false);
                    }
                    matchedJoinConstraint = joinConstraint;
                    reversed = false;
                    break;
                } else {
                    continue;
                }
            }

            // not relevant unless its minimal right hand overlaps the proposed join
            if (!LongBitmap.isOverlap(joinConstraint.getMinRightHand(), joinTableBitmap)) {
                continue;
            }

            // not relevant if the proposed join lies entirely inside its minimal right hand
            if (LongBitmap.isSubset(joinTableBitmap, joinConstraint.getMinRightHand())) {
                continue;
            }

            // not relevant if it is already satisfied within either input
            if (LongBitmap.isSubset(joinConstraint.getMinLeftHand(), leftTableBitmap)
                    && LongBitmap.isSubset(joinConstraint.getMinRightHand(), leftTableBitmap)) {
                continue;
            }

            if (LongBitmap.isSubset(joinConstraint.getMinLeftHand(), rightTableBitmap)
                    && LongBitmap.isSubset(joinConstraint.getMinRightHand(), rightTableBitmap)) {
                continue;
            }

            // a semi join whose right hand is a proper subset of one input has already been joined there
            if (joinConstraint.getJoinType().isSemiJoin()) {
                if (LongBitmap.isSubset(joinConstraint.getRightHand(), leftTableBitmap)
                        && !joinConstraint.getRightHand().equals(leftTableBitmap)) {
                    continue;
                }
                if (LongBitmap.isSubset(joinConstraint.getRightHand(), rightTableBitmap)
                        && !joinConstraint.getRightHand().equals(rightTableBitmap)) {
                    continue;
                }
            }

            if (LongBitmap.isSubset(joinConstraint.getMinLeftHand(), leftTableBitmap)
                    && LongBitmap.isSubset(joinConstraint.getMinRightHand(), rightTableBitmap)) {
                if (matchedJoinConstraint != null) {
                    return Pair.of(null, false);
                }
                matchedJoinConstraint = joinConstraint;
                reversed = false;
            } else if (LongBitmap.isSubset(joinConstraint.getMinLeftHand(), rightTableBitmap)
                    && LongBitmap.isSubset(joinConstraint.getMinRightHand(), leftTableBitmap)) {
                if (matchedJoinConstraint != null) {
                    return Pair.of(null, false);
                }
                matchedJoinConstraint = joinConstraint;
                reversed = true;
            } else if (joinConstraint.getJoinType().isSemiJoin()
                    && joinConstraint.getRightHand().equals(rightTableBitmap)) {
                if (matchedJoinConstraint != null) {
                    return Pair.of(null, false);
                }
                matchedJoinConstraint = joinConstraint;
                reversed = false;
            } else if (joinConstraint.getJoinType().isSemiJoin()
                    && joinConstraint.getRightHand().equals(leftTableBitmap)) {
                /* Reversed semijoin case */
                if (matchedJoinConstraint != null) {
                    return Pair.of(null, false);
                }
                matchedJoinConstraint = joinConstraint;
                reversed = true;
            } else {
                // both inputs overlap the minimal right hand: accept it as a previously allowed violation
                if (LongBitmap.isOverlap(leftTableBitmap, joinConstraint.getMinRightHand())
                        && LongBitmap.isOverlap(rightTableBitmap, joinConstraint.getMinRightHand())) {
                    continue;
                }
                // otherwise only legal when it can be associated into the right hand of a left join,
                // which requires this join to be a left join as well (checked after the loop)
                if (!joinConstraint.getJoinType().isLeftJoin()
                        || LongBitmap.isOverlap(joinTableBitmap, joinConstraint.getMinLeftHand())) {
                    return Pair.of(null, false);
                }
                mustBeLeftjoin = true;
            }
        }
        if (mustBeLeftjoin && (matchedJoinConstraint == null || !matchedJoinConstraint.getJoinType().isLeftJoin()
                || !matchedJoinConstraint.isLhsStrict())) {
            return Pair.of(null, false);
        }
        // this means inner join
        if (matchedJoinConstraint == null) {
            return Pair.of(null, true);
        }
        matchedJoinConstraint.setReversed(reversed);
        return Pair.of(matchedJoinConstraint, true);
    }

    /**
     * Try to get join type of two random logical scan or join node table bitmap.
     * When the join is illegal the status is set to UNUSED and an inner/cross join type is still returned,
     * so callers must check isSuccess().
     *
     * @param left left side table bitmap
     * @param right right side table bitmap
     * @param conditions join conditions; no condition means a cross join when no constraint matches
     * @return join type
     */
    public JoinType computeJoinType(Long left, Long right, List<Expression> conditions) {
        Pair<JoinConstraint, Boolean> joinConstraintBooleanPair
                = getJoinConstraint(LongBitmap.or(left, right), left, right);
        boolean legal = joinConstraintBooleanPair.second;
        JoinConstraint joinConstraint = joinConstraintBooleanPair.first;
        if (legal && joinConstraint != null) {
            return joinConstraint.isReversed()
                    ? joinConstraint.getJoinType().swap()
                    : joinConstraint.getJoinType();
        }
        if (!legal) {
            this.setStatus(HintStatus.UNUSED);
        }
        return conditions.isEmpty() ? JoinType.CROSS_JOIN : JoinType.INNER_JOIN;
    }

    // distribute hint for a "join"/"shuffle"/"broadcast" operator; only called with those three values.
    // Marks the hint as used by leading, registers it on the statement context once, and stores it under key 0.
    private DistributeHint getDistributeJoinHint(String distributeJoinType) {
        DistributeHint distributeHint = null;
        if (distributeJoinType.equals("join")) {
            distributeHint = new DistributeHint(DistributeType.NONE);
        } else if (distributeJoinType.equals("shuffle") || distributeJoinType.equals("broadcast")) {
            distributeHint = distributeHints.get(distributeJoinType.hashCode());
        }
        distributeHint.setSuccessInLeading(true);
        if (!ConnectContext.get().getStatementContext().getHints().contains(distributeHint)) {
            ConnectContext.get().getStatementContext().addHint(distributeHint);
        }
        distributeHints.put(0, distributeHint);
        return distributeHint;
    }

    // build one leading join of two children; returns null (with the status set) when the join is not allowed
    private LogicalPlan makeJoinPlan(LogicalPlan leftChild, LogicalPlan rightChild, String distributeJoinType) {
        List<Expression> conditions = getJoinConditions(
                getFilters(), leftChild, rightChild);
        Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable(
                leftChild.getOutput(), rightChild.getOutput(), conditions);
        // leading hint would set status inside if not success
        JoinType joinType = computeJoinType(getBitmap(leftChild),
                getBitmap(rightChild), conditions);
        if (joinType == null) {
            this.setStatus(HintStatus.SYNTAX_ERROR);
            this.setErrorMessage("JoinType can not be null");
        } else if (!isConditionJoinTypeMatched(conditions, joinType)) {
            this.setStatus(HintStatus.UNUSED);
            this.setErrorMessage("condition does not matched joinType");
        }
        if (!this.isSuccess()) {
            return null;
        }
        DistributeHint distributeHint = getDistributeJoinHint(distributeJoinType);
        LogicalJoin logicalJoin = new LogicalJoin<>(joinType, pair.first,
                pair.second,
                distributeHint,
                Optional.empty(),
                leftChild,
                rightChild, null);
        logicalJoin.getJoinReorderContext().setLeadingJoin(true);
        logicalJoin.setBitmap(LongBitmap.or(getBitmap(leftChild), getBitmap(rightChild)));
        return logicalJoin;
    }

    /**
     * Using leading to generate plan. It could fail; if it fails, the leading status is set to unused or
     * syntax error and null is returned.
     *
     * @return plan
     */
    public Plan generateLeadingJoinPlan() {
        Stack<LogicalPlan> stack = new Stack<>();
        for (String item : normalizedParameters) {
            if (item.equals("join") || item.equals("shuffle") || item.equals("broadcast")) {
                LogicalPlan rightChild = stack.pop();
                LogicalPlan leftChild = stack.pop();
                LogicalPlan joinPlan = makeJoinPlan(leftChild, rightChild, item);
                if (joinPlan == null) {
                    return null;
                }
                stack.push(joinPlan);
            } else {
                LogicalPlan logicalPlan = getLogicalPlanByName(item);
                if (logicalPlan == null) {
                    // unknown table name (status already set) or no plan recorded for its relation id;
                    // continuing would dereference the null plan
                    return null;
                }
                stack.push(makeFilterPlanIfExist(getFilters(), logicalPlan));
            }
        }

        LogicalJoin finalJoin = (LogicalJoin) stack.pop();
        // we want all filters been remove
        assert (filters.isEmpty());
        if (finalJoin != null) {
            this.setStatus(HintStatus.SUCCESS);
        }
        return finalJoin;
    }

    // currently unused: distribute hint stored under the given key, or a NONE hint when absent
    private DistributeHint getJoinHint(Integer index) {
        if (distributeHints.get(index) == null) {
            return new DistributeHint(DistributeType.NONE);
        }
        distributeHints.get(index).setSuccessInLeading(true);
        return distributeHints.get(index);
    }

    // remove and return every filter whose tables are all covered by the two children
    private List<Expression> getJoinConditions(List<Pair<Long, Expression>> filters,
            LogicalPlan left, LogicalPlan right) {
        List<Expression> joinConditions = new ArrayList<>();
        Long tablesBitMap = LongBitmap.or(getBitmap(left), getBitmap(right));
        // iterate backwards so removing by index does not skip elements
        for (int i = filters.size() - 1; i >= 0; i--) {
            Pair<Long, Expression> filterPair = filters.get(i);
            if (LongBitmap.isSubset(filterPair.first, tablesBitMap)) {
                joinConditions.add(filterPair.second);
                filters.remove(i);
            }
        }
        return joinConditions;
    }

    // currently unused: remove and return all remaining filters
    private List<Expression> getLastConditions(List<Pair<Long, Expression>> filters) {
        List<Expression> joinConditions = new ArrayList<>();
        for (int i = filters.size() - 1; i >= 0; i--) {
            Pair<Long, Expression> filterPair = filters.get(i);
            joinConditions.add(filterPair.second);
            filters.remove(i);
        }
        return joinConditions;
    }

    // wrap the scan in a filter holding (and removing from the list) every predicate that only needs this scan
    private LogicalPlan makeFilterPlanIfExist(List<Pair<Long, Expression>> filters, LogicalPlan scan) {
        Set<Expression> newConjuncts = new HashSet<>();
        for (int i = filters.size() - 1; i >= 0; i--) {
            Pair<Long, Expression> filterPair = filters.get(i);
            if (LongBitmap.isSubset(filterPair.first, getBitmap(scan))) {
                newConjuncts.add(filterPair.second);
                filters.remove(i);
            }
        }
        if (newConjuncts.isEmpty()) {
            return scan;
        } else {
            return new LogicalFilter<>(newConjuncts, scan);
        }
    }

    // table bitmap of a plan node; filters and projects are transparent, other node types yield null
    private Long getBitmap(LogicalPlan root) {
        if (root instanceof LogicalJoin) {
            return ((LogicalJoin) root).getBitmap();
        } else if (root instanceof LogicalRelation) {
            return LongBitmap.set(0L, (((LogicalRelation) root).getRelationId().asInt()));
        } else if (root instanceof LogicalFilter || root instanceof LogicalProject) {
            return getBitmap((LogicalPlan) root.child(0));
        } else if (root instanceof LogicalSubQueryAlias) {
            return LongBitmap.set(0L, (((LogicalSubQueryAlias) root).getRelationId().asInt()));
        } else {
            return null;
        }
    }
}