Group.java
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
package org.apache.doris.nereids.memo;
import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.properties.LogicalProperties;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.literal.Literal;
import org.apache.doris.nereids.trees.plans.JoinType;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.logical.LogicalJoin;
import org.apache.doris.nereids.trees.plans.logical.LogicalPlan;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPlan;
import org.apache.doris.nereids.util.TreeStringUtils;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.IdentityHashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Objects;
import java.util.Optional;
import java.util.function.Function;
import java.util.stream.Collectors;
/**
 * Representation for group in cascades optimizer.
 *
 * <p>A group owns a list of logical {@link GroupExpression}s, a list of physical
 * {@link GroupExpression}s and a set of enforcers, and records, per required
 * {@link PhysicalProperties}, the lowest cost plan found so far. Two groups are equal
 * when their {@link GroupId}s are equal.
 */
public class Group {
    private final GroupId groupId;
    // Save all parent GroupExpression to avoid traversing whole Memo.
    // Parents are tracked by identity, not by equals().
    private final IdentityHashMap<GroupExpression, Void> parentExpressions = new IdentityHashMap<>();

    private final List<GroupExpression> logicalExpressions = Lists.newArrayList();
    private final List<GroupExpression> physicalExpressions = Lists.newArrayList();
    private final Map<GroupExpression, GroupExpression> enforcers = Maps.newHashMap();
    private boolean isStatsReliable = true;
    private LogicalProperties logicalProperties;

    // Map of cost lower bounds
    // Map required plan props to cost lower bound of corresponding plan
    // (insertion-ordered, so getAllProperties() and toString() follow insertion order)
    private final Map<PhysicalProperties, Pair<Cost, GroupExpression>> lowestCostPlans = Maps.newLinkedHashMap();

    private boolean isExplored = false;

    private Statistics statistics;

    private PhysicalProperties chosenProperties;

    // -1 means "not chosen yet", see setChosenGroupExpressionId
    private int chosenGroupExpressionId = -1;

    // The two lists below are filled independently by addChosenEnforcerProperties and
    // addChosenEnforcerId; toString() reads them pairwise by index.
    private final List<PhysicalProperties> chosenEnforcerPropertiesList = new ArrayList<>();
    private final List<Integer> chosenEnforcerIdList = new ArrayList<>();

    private final StructInfoMap structInfoMap = new StructInfoMap();

    /**
     * Constructor for Group.
     *
     * @param groupId the groupId in memo
     * @param groupExpression first {@link GroupExpression} in this Group
     * @param logicalProperties logical properties of this group
     */
    public Group(GroupId groupId, GroupExpression groupExpression, LogicalProperties logicalProperties) {
        this.groupId = groupId;
        addGroupExpression(groupExpression);
        this.logicalProperties = logicalProperties;
    }

    /**
     * Construct a Group without any group expression
     *
     * @param groupId the groupId in memo
     * @param logicalProperties logical properties of this group
     */
    public Group(GroupId groupId, LogicalProperties logicalProperties) {
        this.groupId = groupId;
        this.logicalProperties = logicalProperties;
    }

    public GroupId getGroupId() {
        return groupId;
    }

    /**
     * Get every physical property that currently has a lowest cost plan recorded.
     *
     * @return a new mutable list holding the keys of the lowest cost plan table
     */
    public List<PhysicalProperties> getAllProperties() {
        return new ArrayList<>(lowestCostPlans.keySet());
    }

    /**
     * Add new {@link GroupExpression} into this group.
     * An expression whose plan is a {@link LogicalPlan} goes to the logical list,
     * anything else goes to the physical list.
     *
     * @param groupExpression {@link GroupExpression} to be added
     * @return added {@link GroupExpression}
     */
    public GroupExpression addGroupExpression(GroupExpression groupExpression) {
        if (groupExpression.getPlan() instanceof LogicalPlan) {
            logicalExpressions.add(groupExpression);
        } else {
            physicalExpressions.add(groupExpression);
        }
        groupExpression.setOwnerGroup(this);
        return groupExpression;
    }

    public void setStatsReliable(boolean statsReliable) {
        this.isStatsReliable = statsReliable;
    }

    public boolean isStatsReliable() {
        return isStatsReliable;
    }

    /**
     * Append the expression to the logical list and make this group its owner.
     * The plan type of the expression is not checked here.
     */
    public void addLogicalExpression(GroupExpression groupExpression) {
        groupExpression.setOwnerGroup(this);
        logicalExpressions.add(groupExpression);
    }

    /**
     * Append the expression to the physical list and make this group its owner.
     * The plan type of the expression is not checked here.
     */
    public void addPhysicalExpression(GroupExpression groupExpression) {
        groupExpression.setOwnerGroup(this);
        physicalExpressions.add(groupExpression);
    }

    /**
     * @return the live (not copied) list of logical expressions of this group
     */
    public List<GroupExpression> getLogicalExpressions() {
        return logicalExpressions;
    }

    public GroupExpression logicalExpressionsAt(int index) {
        return logicalExpressions.get(index);
    }

    /**
     * Get the only logical group expression in this group.
     * If there is no logical group expression or more than one, throw an exception.
     *
     * @return the single logical group expression in this group
     * @throws IllegalArgumentException if the group does not hold exactly one logical expression
     */
    public GroupExpression getLogicalExpression() {
        Preconditions.checkArgument(logicalExpressions.size() == 1,
                "There should be only one Logical Expression in Group");
        return logicalExpressions.get(0);
    }

    /**
     * @return the live (not copied) list of physical expressions of this group
     */
    public List<GroupExpression> getPhysicalExpressions() {
        return physicalExpressions;
    }

    /**
     * Remove groupExpression from this group.
     * The owner group of the expression is reset to null in any case.
     *
     * @param groupExpression to be removed
     * @return removed {@link GroupExpression}
     */
    public GroupExpression removeGroupExpression(GroupExpression groupExpression) {
        // use identityRemove to avoid equals() method
        List<GroupExpression> holder = groupExpression.getPlan() instanceof LogicalPlan
                ? logicalExpressions
                : physicalExpressions;
        Utils.identityRemove(holder, groupExpression);
        groupExpression.setOwnerGroup(null);
        return groupExpression;
    }

    /**
     * Detach all logical expressions from this group.
     *
     * @return the expressions that were removed, each with its owner group reset to null
     */
    public List<GroupExpression> clearLogicalExpressions() {
        return detachAll(logicalExpressions);
    }

    /**
     * Detach all physical expressions from this group.
     *
     * @return the expressions that were removed, each with its owner group reset to null
     */
    public List<GroupExpression> clearPhysicalExpressions() {
        return detachAll(physicalExpressions);
    }

    /**
     * Reset the owner group of every expression in the list, empty the list and
     * hand the former content back to the caller as a new mutable list.
     */
    private static List<GroupExpression> detachAll(List<GroupExpression> expressions) {
        List<GroupExpression> detached = new ArrayList<>(expressions);
        for (GroupExpression groupExpr : detached) {
            groupExpr.setOwnerGroup(null);
        }
        expressions.clear();
        return detached;
    }

    public void clearLowestCostPlans() {
        lowestCostPlans.clear();
    }

    /**
     * @return always -1; no lower bound is computed by this class
     */
    public double getCostLowerBound() {
        return -1D;
    }

    /**
     * Get the lowest cost {@link org.apache.doris.nereids.trees.plans.physical.PhysicalPlan}
     * which meeting the physical property constraints in this Group.
     *
     * @param physicalProperties the physical property constraints
     * @return {@link Optional} of cost and {@link GroupExpression} of physical plan pair,
     *         empty when the argument is null or nothing is recorded for it.
     */
    public Optional<Pair<Cost, GroupExpression>> getLowestCostPlan(PhysicalProperties physicalProperties) {
        if (physicalProperties == null) {
            return Optional.empty();
        }
        // a lookup on an empty table yields null, which maps to Optional.empty()
        return Optional.ofNullable(lowestCostPlans.get(physicalProperties));
    }

    /**
     * @return an immutable map from each recorded physical property to its lowest cost
     */
    public Map<PhysicalProperties, Cost> getLowestCosts() {
        return lowestCostPlans.entrySet()
                .stream()
                .collect(ImmutableMap.toImmutableMap(Entry::getKey, kv -> kv.getValue().first));
    }

    /**
     * Get the group expression of the lowest cost plan for the given properties.
     *
     * @param properties required physical properties
     * @return the recorded group expression, or null when nothing is recorded
     */
    public GroupExpression getBestPlan(PhysicalProperties properties) {
        Pair<Cost, GroupExpression> best = lowestCostPlans.get(properties);
        return best == null ? null : best.second;
    }

    /**
     * Register an enforcer in this group and make this group its owner.
     */
    public void addEnforcer(GroupExpression enforcer) {
        enforcer.setOwnerGroup(this);
        enforcers.put(enforcer, enforcer);
    }

    /**
     * @return the live (not copied) enforcer map; every key maps to itself
     */
    public Map<GroupExpression, GroupExpression> getEnforcers() {
        return enforcers;
    }

    /**
     * Set or update lowestCostPlans: properties --> Pair.of(cost, expression)
     * An existing entry is only overwritten by a strictly cheaper plan.
     */
    public void setBestPlan(GroupExpression expression, Cost cost, PhysicalProperties properties) {
        Pair<Cost, GroupExpression> current = lowestCostPlans.get(properties);
        if (current == null || current.first.getValue() > cost.getValue()) {
            lowestCostPlans.put(properties, Pair.of(cost, expression));
        }
    }

    /**
     * replace best plan with new properties
     *
     * <p>The entry stored under {@code oldProperty} is moved to {@code newProperty}; the moved
     * pair itself is not modified, so it keeps the cost it was recorded with. {@code cost} is
     * only forwarded to the lowest cost table of the group expression.
     *
     * @throws NullPointerException if nothing is recorded for {@code oldProperty}
     */
    public void replaceBestPlanProperty(PhysicalProperties oldProperty,
            PhysicalProperties newProperty, Cost cost) {
        // fail with a descriptive message instead of a bare NPE on pair.second
        Pair<Cost, GroupExpression> pair = Objects.requireNonNull(lowestCostPlans.get(oldProperty),
                () -> "no lowest cost plan recorded for property " + oldProperty + " in group " + groupId);
        GroupExpression lowestGroupExpr = pair.second;
        lowestGroupExpr.updateLowestCostTable(newProperty,
                lowestGroupExpr.getInputPropertiesList(oldProperty), cost);
        lowestCostPlans.remove(oldProperty);
        lowestCostPlans.put(newProperty, pair);
    }

    /**
     * replace oldGroupExpression with newGroupExpression in lowestCostPlans.
     * Entries are matched with equals(), and every matching entry keeps its cost.
     */
    public void replaceBestPlanGroupExpr(GroupExpression oldGroupExpression, GroupExpression newGroupExpression) {
        // collect replacements aside: entries cannot be re-inserted while iterating
        Map<PhysicalProperties, Pair<Cost, GroupExpression>> needReplaceBestExpressions = Maps.newHashMap();
        Iterator<Entry<PhysicalProperties, Pair<Cost, GroupExpression>>> iterator =
                lowestCostPlans.entrySet().iterator();
        while (iterator.hasNext()) {
            Entry<PhysicalProperties, Pair<Cost, GroupExpression>> entry = iterator.next();
            Pair<Cost, GroupExpression> pair = entry.getValue();
            if (pair.second.equals(oldGroupExpression)) {
                needReplaceBestExpressions.put(entry.getKey(), Pair.of(pair.first, newGroupExpression));
                iterator.remove();
            }
        }
        lowestCostPlans.putAll(needReplaceBestExpressions);
    }

    public Statistics getStatistics() {
        return statistics;
    }

    public void setStatistics(Statistics statistics) {
        this.statistics = statistics;
    }

    public LogicalProperties getLogicalProperties() {
        return logicalProperties;
    }

    public void setLogicalProperties(LogicalProperties logicalProperties) {
        this.logicalProperties = logicalProperties;
    }

    public boolean isExplored() {
        return isExplored;
    }

    public void setExplored(boolean explored) {
        isExplored = explored;
    }

    /**
     * @return an immutable snapshot of the parent group expressions
     */
    public List<GroupExpression> getParentGroupExpressions() {
        return ImmutableList.copyOf(parentExpressions.keySet());
    }

    public void addParentExpression(GroupExpression parent) {
        parentExpressions.put(parent, null);
    }

    /**
     * remove the reference to parent groupExpression
     *
     * @param parent group expression
     * @return parentExpressions's num
     */
    public int removeParentExpression(GroupExpression parent) {
        parentExpressions.remove(parent);
        return parentExpressions.size();
    }

    /**
     * Drop every parent whose plan is a {@link PhysicalPlan}.
     */
    public void removeParentPhysicalExpressions() {
        parentExpressions.keySet().removeIf(parent -> parent.getPlan() instanceof PhysicalPlan);
    }

    /**
     * move the ownerGroup to target group.
     *
     * <p>Afterwards the enforcers, logical expressions, physical expressions and lowest cost
     * plans of this group are empty. The parent expressions are copied to the target but are
     * not removed from this group.
     *
     * @param target the new owner group of expressions
     */
    public void mergeTo(Group target) {
        // move parentExpressions Ownership
        for (GroupExpression parent : parentExpressions.keySet()) {
            target.addParentExpression(parent);
        }

        // move enforcers Ownership
        for (GroupExpression enforcer : enforcers.keySet()) {
            enforcer.children().set(0, target);
        }
        // TODO: dedup?
        for (GroupExpression enforcer : enforcers.keySet()) {
            target.addEnforcer(enforcer);
        }
        enforcers.clear();

        // move LogicalExpression PhysicalExpression Ownership
        Map<GroupExpression, GroupExpression> logicalSet = target.getLogicalExpressions().stream()
                .collect(Collectors.toMap(Function.identity(), Function.identity()));
        for (GroupExpression logicalExpression : logicalExpressions) {
            GroupExpression existGroupExpr = logicalSet.get(logicalExpression);
            if (existGroupExpr == null) {
                target.addLogicalExpression(logicalExpression);
                continue;
            }
            Preconditions.checkState(logicalExpression != existGroupExpr, "must not equals");
            // lowCostPlans must be physical GroupExpression, don't need to replaceBestPlanGroupExpr
            logicalExpression.mergeToNotOwnerRemove(existGroupExpr);
        }
        logicalExpressions.clear();

        // movePhysicalExpressionOwnership
        Map<GroupExpression, GroupExpression> physicalSet = target.getPhysicalExpressions().stream()
                .collect(Collectors.toMap(Function.identity(), Function.identity()));
        for (GroupExpression physicalExpression : physicalExpressions) {
            GroupExpression existGroupExpr = physicalSet.get(physicalExpression);
            if (existGroupExpr == null) {
                target.addPhysicalExpression(physicalExpression);
                continue;
            }
            Preconditions.checkState(physicalExpression != existGroupExpr, "must not equals");
            physicalExpression.getOwnerGroup().replaceBestPlanGroupExpr(physicalExpression, existGroupExpr);
            physicalExpression.mergeToNotOwnerRemove(existGroupExpr);
        }
        physicalExpressions.clear();

        // Above we already replaceBestPlanGroupExpr, but we still need to moveLowestCostPlansOwnership.
        lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> {
            // move lowestCostPlans Ownership: the target keeps its own plan unless ours is strictly cheaper
            Pair<Cost, GroupExpression> existing = target.lowestCostPlans.get(physicalProperties);
            if (existing == null || costAndGroupExpr.first.getValue() < existing.first.getValue()) {
                target.lowestCostPlans.put(physicalProperties, costAndGroupExpr);
            }
        });
        lowestCostPlans.clear();

        // If statistics is null, use other statistics
        if (target.statistics == null) {
            target.statistics = this.statistics;
        }
    }

    /**
     * This function is used to check whether the group is an end node in DPHyp
     *
     * @return true only when the single logical expression of this group is a non-mark inner
     *         join that either has hash join conjuncts or whose first other join conjunct is
     *         not a {@link Literal}
     * @throws IllegalArgumentException if the group does not hold exactly one logical expression,
     *         or the inner join has no expressions at all
     */
    public boolean isValidJoinGroup() {
        Plan plan = getLogicalExpression().getPlan();
        if (!(plan instanceof LogicalJoin)) {
            return false;
        }
        // raw type kept on purpose: the child type arguments are irrelevant here
        LogicalJoin join = (LogicalJoin) plan;
        if (join.getJoinType() != JoinType.INNER_JOIN || join.isMarkJoin()) {
            return false;
        }
        Preconditions.checkArgument(!join.getExpressions().isEmpty(),
                "inner join must have join conjuncts");
        // Right now, we only support inner join with some conjuncts referencing any side of the child's output
        return !(join.getHashJoinConjuncts().isEmpty()
                && join.getOtherJoinConjuncts().get(0) instanceof Literal);
    }

    public StructInfoMap getstructInfoMap() {
        return structInfoMap;
    }

    /**
     * @return true when the single logical expression of this group is a {@link LogicalProject}
     * @throws IllegalArgumentException if the group does not hold exactly one logical expression
     */
    public boolean isProjectGroup() {
        return getLogicalExpression().getPlan() instanceof LogicalProject;
    }

    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        Group group = (Group) o;
        return groupId.equals(group.groupId);
    }

    @Override
    public int hashCode() {
        return Objects.hash(groupId);
    }

    @Override
    public String toString() {
        StringBuilder str = new StringBuilder("Group[").append(groupId).append("]\n");
        str.append(" logical expressions:\n");
        for (GroupExpression logicalExpression : logicalExpressions) {
            str.append(" ").append(logicalExpression).append("\n");
        }
        str.append(" physical expressions:\n");
        for (GroupExpression physicalExpression : physicalExpressions) {
            str.append(" ").append(physicalExpression).append("\n");
        }
        str.append(" enforcers:\n");
        for (GroupExpression enforcer : enforcers.keySet()) {
            str.append(" ").append(enforcer).append("\n");
        }
        if (!chosenEnforcerIdList.isEmpty()) {
            str.append(" chosen enforcer(id, requiredProperties):\n");
            for (int i = 0; i < chosenEnforcerIdList.size(); i++) {
                str.append(" (").append(i).append(")").append(chosenEnforcerIdList.get(i)).append(", ")
                        .append(chosenEnforcerPropertiesList.get(i)).append("\n");
            }
        }
        if (chosenGroupExpressionId != -1) {
            str.append(" chosen expression id: ").append(chosenGroupExpressionId).append("\n");
            str.append(" chosen properties: ").append(chosenProperties).append("\n");
        }
        str.append(" stats").append("\n");
        Statistics stats = getStatistics();
        str.append(stats == null ? "" : stats.detail(" "));
        str.append(" lowest Plan(cost, properties, plan, childrenRequires)");
        for (Entry<PhysicalProperties, Pair<Cost, GroupExpression>> entry : lowestCostPlans.entrySet()) {
            PhysicalProperties prop = entry.getKey();
            Cost cost = entry.getValue().first;
            GroupExpression child = entry.getValue().second;
            str.append("\n\n ").append(cost.getValue()).append(" ").append(prop)
                    .append("\n ").append(child).append("\n ")
                    .append(child.getInputPropertiesListOrEmpty(prop));
        }
        str.append("\n").append(" struct info map").append("\n");
        str.append(structInfoMap);
        return str.toString();
    }

    /**
     * Get tree like string describing group.
     *
     * @return tree like string describing group
     */
    // The tree printer works on untyped nodes, so the child lists are handed over through raw casts.
    @SuppressWarnings({"unchecked", "rawtypes"})
    public String treeString() {
        Function<Object, String> toString = obj -> {
            if (obj instanceof Group) {
                Group group = (Group) obj;
                Map<PhysicalProperties, Cost> lowestCosts = group.getLowestCosts();
                return "Group[" + group.groupId + ", lowestCosts: " + lowestCosts + "]";
            } else if (obj instanceof GroupExpression) {
                GroupExpression groupExpression = (GroupExpression) obj;
                Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lowestCostTable
                        = groupExpression.getLowestCostTable();
                Map<PhysicalProperties, PhysicalProperties> requestPropertiesMap
                        = groupExpression.getRequestPropertiesMap();
                Cost cost = groupExpression.getCost();
                return groupExpression.getPlan().toString() + " [cost: " + cost + ", lowestCostTable: "
                        + lowestCostTable + ", requestPropertiesMap: " + requestPropertiesMap + "]";
            } else if (obj instanceof Pair) {
                // print logicalExpressions or physicalExpressions
                // first is name, second is group expressions
                return ((Pair<?, ?>) obj).first.toString();
            } else {
                return obj.toString();
            }
        };

        Function<Object, List<Object>> getChildren = obj -> {
            if (obj instanceof Group) {
                Group group = (Group) obj;
                List<Object> children = new ArrayList<>();

                // to <name, children> pair
                if (!group.getLogicalExpressions().isEmpty()) {
                    children.add(Pair.of("logicalExpressions", group.getLogicalExpressions()));
                }
                if (!group.getPhysicalExpressions().isEmpty()) {
                    children.add(Pair.of("physicalExpressions", group.getPhysicalExpressions()));
                }
                return children;
            } else if (obj instanceof GroupExpression) {
                return (List) ((GroupExpression) obj).children();
            } else if (obj instanceof Pair) {
                return (List) ((Pair<String, List<GroupExpression>>) obj).second;
            } else {
                return ImmutableList.of();
            }
        };

        Function<Object, List<Object>> getExtraPlans = obj -> {
            if (obj instanceof Plan) {
                return (List) ((Plan) obj).extraPlans();
            } else {
                return ImmutableList.of();
            }
        };

        Function<Object, Boolean> displayExtraPlan = obj -> {
            if (obj instanceof Plan) {
                return ((Plan) obj).displayExtraPlanFirst();
            } else {
                return false;
            }
        };

        return TreeStringUtils.treeString(this, toString, getChildren, getExtraPlans, displayExtraPlan);
    }

    public PhysicalProperties getChosenProperties() {
        return chosenProperties;
    }

    public void setChosenProperties(PhysicalProperties chosenProperties) {
        this.chosenProperties = chosenProperties;
    }

    /**
     * Record the id of the chosen group expression. It can only be set once.
     *
     * @throws IllegalArgumentException if an id other than -1 has already been recorded
     */
    public void setChosenGroupExpressionId(int chosenGroupExpressionId) {
        Preconditions.checkArgument(this.chosenGroupExpressionId == -1,
                "chosenGroupExpressionId is already set");
        this.chosenGroupExpressionId = chosenGroupExpressionId;
    }

    public void addChosenEnforcerProperties(PhysicalProperties chosenEnforcerProperties) {
        this.chosenEnforcerPropertiesList.add(chosenEnforcerProperties);
    }

    public void addChosenEnforcerId(int chosenEnforcerId) {
        this.chosenEnforcerIdList.add(chosenEnforcerId);
    }
}