GroupExpression.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.memo;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.metrics.EventChannel;
import org.apache.doris.nereids.metrics.EventProducer;
import org.apache.doris.nereids.metrics.consumer.LogConsumer;
import org.apache.doris.nereids.metrics.event.CostStateUpdateEvent;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.rules.Rule;
import org.apache.doris.nereids.rules.RuleType;
import org.apache.doris.nereids.trees.expressions.StatementScopeIdGenerator;
import org.apache.doris.nereids.trees.plans.ObjectId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.util.Utils;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Joiner;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;

import java.text.DecimalFormat;
import java.util.BitSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Optional;
import java.util.stream.Collectors;

/**
 * Representation for group expression in cascades optimizer.
 *
 * <p>A group expression holds one {@link Plan} node together with the child {@link Group}s it reads from.
 * Besides that it keeps:
 * <ul>
 *   <li>a bit mask of the rules that were already applied to it,</li>
 *   <li>the lowest cost recorded so far for every output {@link PhysicalProperties},</li>
 *   <li>the mapping from a requested {@link PhysicalProperties} to the output properties chosen for it.</li>
 * </ul>
 */
public class GroupExpression {
    private static final EventProducer COST_STATE_TRACER = new EventProducer(CostStateUpdateEvent.class,
            EventChannel.getDefaultChannel().addConsumers(new LogConsumer(CostStateUpdateEvent.class,
                    EventChannel.LOG)));
    private Cost cost;
    private Group ownerGroup;
    private final List<Group> children;
    private final Plan plan;
    // One bit per rule, indexed by the ordinal of the rule's RuleType.
    private final BitSet ruleMasks;
    private boolean statDerived;

    private double estOutputRowCount = -1;

    // Record the rule that generated this plan. It's used for debugging.
    private Rule fromRule;

    // Mapping from output properties to the corresponding best cost and child properties.
    // key is the physical properties the group expression supports for its parent
    // and value is the cost and the physical properties requested from its children.
    private final Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> lowestCostTable;
    // Each physical group expression maintains a mapping from incoming requests to the chosen output.
    // key is the request physical properties
    // value is the output physical properties satisfying that request
    private final Map<PhysicalProperties, PhysicalProperties> requestPropertiesMap;

    // After mergeGroup(), source Group was cleaned up, but it may be in the Job Stack. So use this to mark and skip it.
    private boolean isUnused = false;

    private final ObjectId id = StatementScopeIdGenerator.newObjectId();

    /**
     * Just for UT.
     */
    public GroupExpression(Plan plan) {
        this(plan, Lists.newArrayList());
    }

    /**
     * Notice!!!: children will use param `children` directly, So don't modify it after this constructor outside.
     * Constructor for GroupExpression.
     *
     * <p>This expression is registered as a parent expression of every child group.
     *
     * @param plan {@link Plan} to reference
     * @param children children groups in memo
     */
    public GroupExpression(Plan plan, List<Group> children) {
        this.plan = Objects.requireNonNull(plan, "plan can not be null")
                .withGroupExpression(Optional.of(this));
        this.children = Objects.requireNonNull(children, "children can not be null");
        this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
        this.statDerived = false;
        this.lowestCostTable = Maps.newHashMap();
        this.requestPropertiesMap = Maps.newHashMap();
        for (Group child : children) {
            child.addParentExpression(this);
        }
    }

    /**
     * Look up the output properties recorded for a request.
     *
     * @param requestProperties the properties requested from this expression
     * @return the output properties recorded for that request
     * @throws NullPointerException if nothing was recorded for the request
     */
    public PhysicalProperties getOutputProperties(PhysicalProperties requestProperties) {
        return Preconditions.checkNotNull(requestPropertiesMap.get(requestProperties),
                "no output properties recorded for request: %s", requestProperties);
    }

    /**
     * Number of child groups.
     */
    public int arity() {
        return children.size();
    }

    public void setFromRule(Rule rule) {
        this.fromRule = rule;
    }

    public Group getOwnerGroup() {
        return ownerGroup;
    }

    public void setOwnerGroup(Group ownerGroup) {
        this.ownerGroup = ownerGroup;
    }

    public Plan getPlan() {
        return plan;
    }

    public Group child(int i) {
        return children.get(i);
    }

    /**
     * Replace the child at position {@code i}, keeping the parent links of both groups in sync.
     *
     * @param i index of the child to replace
     * @param group the new child group
     */
    public void setChild(int i, Group group) {
        children.get(i).removeParentExpression(this);
        children.set(i, group);
        group.addParentExpression(this);
    }

    /**
     * The child groups. This is the internal list itself, not a copy.
     */
    public List<Group> children() {
        return children;
    }

    /**
     * replaceChild.
     *
     * @param oldChild origin child group
     * @param newChild new child group
     */
    public void replaceChild(Group oldChild, Group newChild) {
        oldChild.removeParentExpression(this);
        newChild.addParentExpression(this);
        Utils.replaceList(children, oldChild, newChild);
    }

    /**
     * Whether the given rule is already marked as applied on this expression.
     */
    public boolean hasApplied(Rule rule) {
        return ruleMasks.get(rule.getRuleType().ordinal());
    }

    public boolean notApplied(Rule rule) {
        return !hasApplied(rule);
    }

    /**
     * Mark the given rule as applied on this expression.
     */
    public void setApplied(Rule rule) {
        ruleMasks.set(rule.getRuleType().ordinal());
    }

    /**
     * Mark every rule applied on this expression as applied on {@code toGroupExpression} too.
     */
    public void propagateApplied(GroupExpression toGroupExpression) {
        toGroupExpression.ruleMasks.or(ruleMasks);
    }

    public void clearApplied() {
        ruleMasks.clear();
    }

    public boolean isStatDerived() {
        return statDerived;
    }

    public void setStatDerived(boolean statDerived) {
        this.statDerived = statDerived;
    }

    /**
     * Check this GroupExpression isUnused. See detail of `isUnused` in its comment.
     *
     * @throws IllegalStateException if the flag disagrees with the state of children and owner group
     */
    public boolean isUnused() {
        if (isUnused) {
            // an unused expression must already be detached from its children and its owner
            Preconditions.checkState(children.isEmpty() && ownerGroup == null);
            return true;
        }
        Preconditions.checkState(ownerGroup != null);
        return false;
    }

    public void setUnused(boolean isUnused) {
        this.isUnused = isUnused;
    }

    /**
     * The lowest cost table. This is the internal map itself, not a copy.
     */
    public Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> getLowestCostTable() {
        return lowestCostTable;
    }

    /**
     * Children input properties recorded together with the lowest cost of {@code require}.
     *
     * @throws IllegalStateException if no cost was recorded for {@code require}
     */
    public List<PhysicalProperties> getInputPropertiesList(PhysicalProperties require) {
        return requireLowestCostEntry(require).second;
    }

    /**
     * Same as {@link #getInputPropertiesList(PhysicalProperties)}, but returns an empty list
     * instead of failing when no cost was recorded for {@code require}.
     */
    public List<PhysicalProperties> getInputPropertiesListOrEmpty(PhysicalProperties require) {
        Pair<Cost, List<PhysicalProperties>> costAndChildRequire = lowestCostTable.get(require);
        return costAndChildRequire == null ? ImmutableList.of() : costAndChildRequire.second;
    }

    /**
     * Add a (outputProperties) -> (cost, childrenInputProperties) in lowestCostTable.
     * An existing entry for outputProperties is only replaced when the new cost is strictly lower.
     *
     * @return true if lowest cost table change.
     */
    public boolean updateLowestCostTable(PhysicalProperties outputProperties,
            List<PhysicalProperties> childrenInputProperties, Cost cost) {
        COST_STATE_TRACER.log(CostStateUpdateEvent.of(this, cost.getValue(), outputProperties));
        Pair<Cost, List<PhysicalProperties>> current = lowestCostTable.get(outputProperties);
        // keep the strict '>' so that an equal (or NaN) cost never replaces the recorded entry
        boolean improved = current == null || current.first.getValue() > cost.getValue();
        if (improved) {
            lowestCostTable.put(outputProperties, Pair.of(cost, childrenInputProperties));
        }
        return improved;
    }

    /**
     * get the lowest cost when satisfy property
     *
     * @param property property that needs to be satisfied
     * @return Lowest cost to satisfy that property
     * @throws IllegalStateException if no cost was recorded for {@code property}
     */
    public double getCostByProperties(PhysicalProperties property) {
        return requireLowestCostEntry(property).first.getValue();
    }

    /**
     * Same as {@link #getCostByProperties(PhysicalProperties)}, but returns the {@link Cost} object.
     *
     * @throws IllegalStateException if no cost was recorded for {@code property}
     */
    public Cost getCostValueByProperties(PhysicalProperties property) {
        return requireLowestCostEntry(property).first;
    }

    // Fetch the lowest cost entry of the given properties, failing when there is none.
    private Pair<Cost, List<PhysicalProperties>> requireLowestCostEntry(PhysicalProperties properties) {
        Pair<Cost, List<PhysicalProperties>> entry = lowestCostTable.get(properties);
        Preconditions.checkState(entry != null, "no lowest cost recorded for properties: %s", properties);
        return entry;
    }

    /**
     * Record that a request for {@code requiredProperties} is answered with {@code outputProperties}.
     */
    public void putOutputPropertiesMap(PhysicalProperties outputProperties,
            PhysicalProperties requiredProperties) {
        this.requestPropertiesMap.put(requiredProperties, outputProperties);
    }

    /**
     * Merge GroupExpression.
     * The owner group removes this expression first, then everything is merged into {@code target}.
     */
    public void mergeTo(GroupExpression target) {
        this.ownerGroup.removeGroupExpression(this);
        this.mergeToNotOwnerRemove(target);
    }

    /**
     * Merge GroupExpression, but owner don't remove this GroupExpression.
     * Afterwards this expression has no children and no owner group.
     */
    public void mergeToNotOwnerRemove(GroupExpression target) {
        // LowestCostTable
        lowestCostTable.forEach(
                (properties, pair) -> target.updateLowestCostTable(properties, pair.second, pair.first));
        // requestPropertiesMap
        // ATTN: when do merge, we should update target requestPropertiesMap
        //   ONLY IF the cost of source's request property lower than target one.
        //   Otherwise, the requestPropertiesMap will not sync with lowestCostTable.
        //   Then, we will get wrong output property when get the final plan.
        for (Map.Entry<PhysicalProperties, PhysicalProperties> entry : requestPropertiesMap.entrySet()) {
            PhysicalProperties request = entry.getKey();
            PhysicalProperties sourceOutput = entry.getValue();
            if (!target.requestPropertiesMap.containsKey(request)) {
                target.requestPropertiesMap.put(request, sourceOutput);
                continue;
            }
            // read the backing maps directly: getRequestPropertiesMap() copies the whole map on every call
            PhysicalProperties targetOutput = target.requestPropertiesMap.get(request);
            Pair<Cost, List<PhysicalProperties>> sourceBest = lowestCostTable.get(sourceOutput);
            Pair<Cost, List<PhysicalProperties>> targetBest = target.lowestCostTable.get(targetOutput);
            if (sourceBest != null && targetBest != null
                    && sourceBest.first.getValue() < targetBest.first.getValue()) {
                target.requestPropertiesMap.put(request, sourceOutput);
            }
        }
        // ruleMasks
        target.ruleMasks.or(this.ruleMasks);

        // clear
        this.children.forEach(child -> child.removeParentExpression(this));
        this.children.clear();
        this.ownerGroup = null;
    }

    public Cost getCost() {
        return cost;
    }

    public void setCost(Cost cost) {
        this.cost = cost;
    }

    /**
     * Two group expressions are equal when both their children and their plans are equal.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        if (o == null || getClass() != o.getClass()) {
            return false;
        }
        GroupExpression that = (GroupExpression) o;
        return children.equals(that.children) && plan.equals(that.plan);
    }

    @Override
    public int hashCode() {
        return Objects.hash(children, plan);
    }

    /**
     * Statistics of the child group at position {@code idx}.
     */
    public Statistics childStatistics(int idx) {
        return child(idx).getStatistics();
    }

    public void setEstOutputRowCount(double estOutputRowCount) {
        this.estOutputRowCount = estOutputRowCount;
    }

    /**
     * An immutable copy of the request properties map. A new copy is built on every call.
     */
    public Map<PhysicalProperties, PhysicalProperties> getRequestPropertiesMap() {
        return ImmutableMap.copyOf(requestPropertiesMap);
    }

    @Override
    public String toString() {
        DecimalFormat format = new DecimalFormat("#,###.##");
        StringBuilder builder = new StringBuilder("id:");
        builder.append(id.asInt());
        if (ownerGroup == null) {
            builder.append("OWNER GROUP IS NULL[]");
        } else {
            builder.append("#").append(ownerGroup.getGroupId().asInt());
        }
        builder.append(" cost=");
        if (cost != null) {
            builder.append(cost.getValue()).append(" ").append(cost);
        } else {
            builder.append("null");
        }
        builder.append(" estRows=").append(format.format(estOutputRowCount));
        String childGroupIds = Joiner.on(", ").join(
                children.stream().map(Group::getGroupId).collect(Collectors.toList()));
        builder.append(" children=[").append(childGroupIds).append(" ]");
        builder.append(" (plan=").append(plan.toString()).append(")");
        return builder.toString();
    }

    public ObjectId getId() {
        return id;
    }

    /**
     * the first child plan of clazz
     *
     * <p>The logical expressions of all child groups are searched first; only when none of them
     * matches are the physical expressions of the child groups searched.
     *
     * @param clazz the operator type, like join/aggregate
     * @return child operator of type clazz, if not found, return null
     */
    public Plan getFirstChildPlan(Class<?> clazz) {
        for (Group childGroup : children) {
            for (GroupExpression logical : childGroup.getLogicalExpressions()) {
                Plan candidate = logical.getPlan();
                if (clazz.isInstance(candidate)) {
                    return candidate;
                }
            }
        }
        // for dphyp
        for (Group childGroup : children) {
            for (GroupExpression physical : childGroup.getPhysicalExpressions()) {
                Plan candidate = physical.getPlan();
                if (clazz.isInstance(candidate)) {
                    return candidate;
                }
            }
        }
        return null;
    }
}