MemoStatsAndCostRecomputer.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.stats;

import org.apache.doris.common.Pair;
import org.apache.doris.nereids.CascadesContext;
import org.apache.doris.nereids.cost.Cost;
import org.apache.doris.nereids.cost.CostCalculator;
import org.apache.doris.nereids.memo.Group;
import org.apache.doris.nereids.memo.GroupExpression;
import org.apache.doris.nereids.properties.PhysicalProperties;
import org.apache.doris.nereids.trees.expressions.CTEId;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.algebra.Join;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer;
import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEConsumer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalCTEProducer;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.qe.ConnectContext;
import org.apache.doris.statistics.Statistics;

import com.google.common.collect.Lists;

import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Optional;
import java.util.Set;

/**
 * Re-estimates memo statistics (group row counts) and rebuilds physical cost state.
 *
 * <p>{@link #recompute} walks the memo from the root group and:
 * <ol>
 *   <li>seeds CTE producer statistics from the statistics currently stored on the groups;</li>
 *   <li>re-estimates group statistics bottom-up from the estimable logical expressions (or, for groups without
 *       logical expressions, from their physical expressions and enforcers), running two passes so CTE consumers
 *       and their ancestors can observe refreshed producer statistics;</li>
 *   <li>recomputes physical expression costs and the per-property lowest-cost plans bottom-up.</li>
 * </ol>
 * Whenever a group's statistics or a cost entry cannot be recomputed, the previous value is kept.
 */
public final class MemoStatsAndCostRecomputer {
    /**
     * When every candidate of a group is a logical project and the largest candidate row count is at least this
     * many times the smallest, the statistics matching the currently chosen physical project may be preserved
     * instead of the aggregated statistics (see {@code shouldPreserveChosenProjectStatistics}).
     */
    private static final double CHOSEN_PROJECT_STATS_DIVERGENCE_RATIO_THRESHOLD = 1_000D;
    private final CascadesContext cascadesContext;
    /** Latest statistics recorded per CTE producer; handed to the stats calculator when estimating consumers. */
    private final Map<CTEId, Statistics> cteIdToStats = new HashMap<>();
    private final LogicalExpressionRowCountSyncPolicy logicalExpressionRowCountSyncPolicy;

    private MemoStatsAndCostRecomputer(CascadesContext cascadesContext,
            LogicalExpressionRowCountSyncPolicy logicalExpressionRowCountSyncPolicy) {
        this.cascadesContext = cascadesContext;
        this.logicalExpressionRowCountSyncPolicy = logicalExpressionRowCountSyncPolicy;
    }

    /**
     * Recomputes statistics and costs for the memo rooted at {@code rootGroup}, keeping each logical
     * expression's individually estimated row count.
     *
     * @param rootGroup root group of the memo
     * @param physicalProperties properties requested at the root (not used by the current implementation)
     * @param cascadesContext optimizer context providing the connect context and session variables
     */
    public static void recompute(Group rootGroup, PhysicalProperties physicalProperties,
            CascadesContext cascadesContext) {
        recompute(rootGroup, physicalProperties, cascadesContext,
                LogicalExpressionRowCountSyncPolicy.KEEP_INDIVIDUAL_EXPRESSION_ROW_COUNT);
    }

    /**
     * Recomputes statistics and costs with configurable logical expression row count sync behavior.
     *
     * @param rootGroup root group of the memo
     * @param physicalProperties properties requested at the root (not used by the current implementation)
     * @param cascadesContext optimizer context providing the connect context and session variables
     * @param logicalExpressionRowCountSyncPolicy whether positive/non-finite logical expression row counts are
     *         overwritten with the group's final row count after re-estimation
     */
    public static void recompute(Group rootGroup, PhysicalProperties physicalProperties,
            CascadesContext cascadesContext,
            LogicalExpressionRowCountSyncPolicy logicalExpressionRowCountSyncPolicy) {
        MemoStatsAndCostRecomputer recomputer = new MemoStatsAndCostRecomputer(cascadesContext,
                logicalExpressionRowCountSyncPolicy);
        recomputer.seedProducerStats(rootGroup, new HashSet<>());
        recomputer.reestimateLogicalStatsBottomUp(rootGroup, new HashSet<>());
        // Run a second pass so CTE consumers and their ancestors can settle on producer stats refreshed above.
        recomputer.reestimateLogicalStatsBottomUp(rootGroup, new HashSet<>());
        recomputer.recomputePhysicalCostsBottomUp(rootGroup, new HashSet<>());
    }

    /** Records CTE producer statistics already present on the memo, visiting each group at most once. */
    private void seedProducerStats(Group group, Set<Group> visited) {
        if (!visited.add(group)) {
            return;
        }
        Statistics statistics = group.getStatistics();
        if (statistics != null) {
            recordProducerStats(group, statistics);
        }
        for (Group child : getTraversalChildren(group)) {
            seedProducerStats(child, visited);
        }
    }

    /** Re-estimates children before their parent (post-order), then refreshes the group's enforcer row counts. */
    private void reestimateLogicalStatsBottomUp(Group group, Set<Group> visited) {
        if (!visited.add(group)) {
            return;
        }
        for (Group child : getTraversalChildren(group)) {
            reestimateLogicalStatsBottomUp(child, visited);
        }
        reestimateCurrentGroup(group);
        refreshEnforcerRowCount(group);
    }

    /**
     * Re-estimates one group's statistics from its estimable logical expressions.
     *
     * <p>Each estimable expression is estimated in isolation (group statistics cleared first); valid results
     * become candidates that are filtered and aggregated by the session's aggregation policy. If no candidate is
     * produced, the original group statistics are restored. Groups with no logical expressions at all fall back to
     * {@link #reestimatePhysicalOnlyGroup(Group)}.
     */
    private void reestimateCurrentGroup(Group group) {
        List<GroupExpression> estimableExpressions = getEstimableLogicalExpressions(group);
        if (estimableExpressions.isEmpty()) {
            if (group.getLogicalExpressions().isEmpty()) {
                reestimatePhysicalOnlyGroup(group);
            }
            return;
        }
        Statistics originalStatistics = group.getStatistics();
        Map<GroupExpression, Statistics> candidateStatisticsByExpression = new LinkedHashMap<>();
        for (GroupExpression logicalExpression : estimableExpressions) {
            List<Statistics> originalChildStatistics = replaceChildStatisticsForLogicalEstimation(logicalExpression);
            // Clear the group's statistics so the estimate below is produced from this expression alone.
            group.setStatistics(null);
            try {
                estimateStats(logicalExpression);
            } finally {
                restoreChildStatistics(logicalExpression, originalChildStatistics);
            }
            Statistics estimatedStatistics = group.getStatistics();
            if (estimatedStatistics == null || !isValidCandidateStatistics(estimatedStatistics)) {
                continue;
            }
            logicalExpression.setEstOutputRowCount(estimatedStatistics.getRowCount());
            candidateStatisticsByExpression.put(logicalExpression, new Statistics(estimatedStatistics));
        }
        if (candidateStatisticsByExpression.isEmpty()) {
            group.setStatistics(originalStatistics);
            return;
        }
        LogicalRowCountAggregationPolicy aggregationPolicy = getLogicalRowCountAggregationPolicy();
        Map<GroupExpression, Statistics> selectedCandidateStatisticsByExpression = filterCandidateStatisticsByPolicy(
                aggregationPolicy, candidateStatisticsByExpression);
        List<Statistics> candidateStatistics = new ArrayList<>(selectedCandidateStatisticsByExpression.values());
        double aggregatedRowCount = aggregationPolicy.aggregate(candidateStatistics);
        Statistics updatedStatistics = resolveUpdatedGroupStatistics(group, selectedCandidateStatisticsByExpression,
                candidateStatistics, aggregatedRowCount, originalStatistics);
        group.setStatistics(updatedStatistics);
        repairInvalidLogicalExpressionRowCounts(group, aggregatedRowCount);
        refreshPhysicalExpressionRowCount(group, updatedStatistics.getRowCount());
        recordProducerStats(group, updatedStatistics);
        if (shouldSyncLogicalExpressionRowCount()) {
            syncLogicalExpressionRowCount(group, updatedStatistics.getRowCount());
        }
    }

    /**
     * Re-estimates a group that has no logical expressions from its physical expressions and enforcers.
     * Keeps the original statistics when no valid candidate can be produced.
     */
    private void reestimatePhysicalOnlyGroup(Group group) {
        List<GroupExpression> estimableExpressions = getEstimablePhysicalExpressions(group);
        if (estimableExpressions.isEmpty()) {
            return;
        }
        Statistics originalStatistics = group.getStatistics();
        Map<GroupExpression, Statistics> candidateStatisticsByExpression = new LinkedHashMap<>();
        for (GroupExpression physicalExpression : estimableExpressions) {
            group.setStatistics(null);
            estimateStats(physicalExpression);
            Statistics estimatedStatistics = group.getStatistics();
            if (estimatedStatistics == null || !isValidCandidateStatistics(estimatedStatistics)) {
                continue;
            }
            physicalExpression.setEstOutputRowCount(estimatedStatistics.getRowCount());
            candidateStatisticsByExpression.put(physicalExpression, new Statistics(estimatedStatistics));
        }
        if (candidateStatisticsByExpression.isEmpty()) {
            group.setStatistics(originalStatistics);
            return;
        }
        Statistics updatedStatistics = choosePhysicalOnlyGroupStatistics(group, candidateStatisticsByExpression,
                originalStatistics);
        group.setStatistics(updatedStatistics);
        refreshPhysicalExpressionRowCount(group, updatedStatistics.getRowCount());
        recordProducerStats(group, updatedStatistics);
    }

    /** A candidate is usable only when its row count is finite and non-negative. */
    private boolean isValidCandidateStatistics(Statistics statistics) {
        return Double.isFinite(statistics.getRowCount()) && statistics.getRowCount() >= 0;
    }

    /** Runs the stats calculator on one expression; the result is written to the expression's owner group. */
    private void estimateStats(GroupExpression groupExpression) {
        ConnectContext connectContext = cascadesContext.getConnectContext();
        StatsCalculator statsCalculator = new StatsCalculator(
                groupExpression,
                connectContext.getSessionVariable().getForbidUnknownColStats(),
                connectContext.getTotalColumnStatisticMap(),
                connectContext.getSessionVariable().isPlayNereidsDump(),
                cteIdToStats,
                cascadesContext);
        statsCalculator.estimate();
    }

    /**
     * Hook for temporarily replacing child statistics before a logical estimation. Currently a no-op that
     * returns an empty list, so {@link #restoreChildStatistics} restores nothing for non-leaf expressions.
     */
    private List<Statistics> replaceChildStatisticsForLogicalEstimation(GroupExpression logicalExpression) {
        return Collections.emptyList();
    }

    /** Restores child group statistics saved by the replace hook; ignores a list whose size does not match. */
    private void restoreChildStatistics(GroupExpression logicalExpression, List<Statistics> originalChildStatistics) {
        if (originalChildStatistics.size() != logicalExpression.arity()) {
            return;
        }
        for (int i = 0; i < logicalExpression.arity(); i++) {
            logicalExpression.child(i).setStatistics(originalChildStatistics.get(i));
        }
    }

    /**
     * Recomputes physical costs children-first. Groups without statistics or without any physical expression or
     * enforcer are skipped. Lowest-cost plans that cannot be recomputed are restored from a snapshot.
     */
    private void recomputePhysicalCostsBottomUp(Group group, Set<Group> visited) {
        if (!visited.add(group)) {
            return;
        }
        for (Group child : getTraversalChildren(group)) {
            recomputePhysicalCostsBottomUp(child, visited);
        }
        if (group.getStatistics() == null
                || (group.getPhysicalExpressions().isEmpty() && group.getEnforcers().isEmpty())) {
            // Nothing to re-cost: either no statistics (enforcer refresh would be a no-op) or no enforcers.
            return;
        }
        Map<PhysicalProperties, Pair<Cost, GroupExpression>> originalLowestCostPlans =
                snapshotLowestCostPlans(group);
        group.clearLowestCostPlans();
        for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
            recomputeGroupExpressionCost(group, physicalExpression);
        }
        refreshEnforcerRowCount(group);
        // Enforcers are costed after physical expressions because their child best plans live in this group.
        for (GroupExpression enforcer : group.getEnforcers().values()) {
            recomputeGroupExpressionCost(group, enforcer);
        }
        restoreMissingLowestCostPlans(group, originalLowestCostPlans);
    }

    /**
     * Recomputes the cost of one physical expression (or enforcer) for every output property it was previously
     * costed for, updating its lowest-cost table and the owner group's best plans.
     *
     * <p>Entries whose child best plans are unavailable keep their previous cost-table entry. The expression's
     * own cost becomes the cheapest recomputed node cost, or the original cost when nothing was recomputed.
     */
    private void recomputeGroupExpressionCost(Group ownerGroup, GroupExpression groupExpression) {
        if (ownerGroup.getStatistics() == null || !hasCompleteChildStatistics(groupExpression)) {
            return;
        }
        Cost originalCost = groupExpression.getCost();
        Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> originalLowestCostTable
                = new LinkedHashMap<>(groupExpression.getLowestCostTable());
        Map<PhysicalProperties, PhysicalProperties> originalRequestPropertiesMap
                = new LinkedHashMap<>(groupExpression.getRequestPropertiesMap());
        groupExpression.clearCostState();

        Cost bestNodeCost = null;
        for (Map.Entry<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> entry
                : originalLowestCostTable.entrySet()) {
            PhysicalProperties outputProperties = entry.getKey();
            List<PhysicalProperties> childInputProperties = entry.getValue().second;
            if (!hasAvailableChildBestPlan(groupExpression, childInputProperties)) {
                continue;
            }
            Cost nodeCost = CostCalculator.calculateCost(cascadesContext.getConnectContext(),
                    groupExpression, childInputProperties);
            Cost totalCost = nodeCost;
            for (int i = 0; i < childInputProperties.size(); i++) {
                Optional<Pair<Cost, GroupExpression>> childBestPlan = groupExpression.child(i)
                        .getLowestCostPlan(childInputProperties.get(i));
                if (!childBestPlan.isPresent()) {
                    // Defensive: hasAvailableChildBestPlan already checked every child.
                    totalCost = null;
                    break;
                }
                totalCost = CostCalculator.addChildCost(cascadesContext.getConnectContext(),
                        groupExpression.getPlan(), totalCost, childBestPlan.get().first, i);
            }
            if (totalCost == null) {
                continue;
            }
            groupExpression.updateLowestCostTable(
                    outputProperties, childInputProperties, totalCost);
            ownerGroup.setBestPlan(groupExpression, totalCost, outputProperties);
            if (bestNodeCost == null || nodeCost.getValue() < bestNodeCost.getValue()) {
                bestNodeCost = nodeCost;
            }
        }
        // Also re-registers the request-properties mappings, so no separate loop is needed here.
        restoreMissingExpressionCostState(groupExpression, originalLowestCostTable, originalRequestPropertiesMap);
        groupExpression.setCost(bestNodeCost != null ? bestNodeCost : originalCost);
    }

    /**
     * Puts back original lowest-cost-table entries that were not recomputed, then re-registers the original
     * request-properties mappings whose key is present in the resulting lowest cost table.
     */
    private void restoreMissingExpressionCostState(GroupExpression groupExpression,
            Map<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> originalLowestCostTable,
            Map<PhysicalProperties, PhysicalProperties> originalRequestPropertiesMap) {
        for (Map.Entry<PhysicalProperties, Pair<Cost, List<PhysicalProperties>>> entry
                : originalLowestCostTable.entrySet()) {
            if (!groupExpression.getLowestCostTable().containsKey(entry.getKey())) {
                groupExpression.updateLowestCostTable(entry.getKey(), entry.getValue().second, entry.getValue().first);
            }
        }
        // NOTE(review): the map key is passed as the second argument of putOutputPropertiesMap but is tested
        // against the lowest cost table, which is keyed by output properties. Confirm the key/value roles of
        // getRequestPropertiesMap; if the key is the request side, this filter may drop valid mappings.
        for (Map.Entry<PhysicalProperties, PhysicalProperties> entry : originalRequestPropertiesMap.entrySet()) {
            if (groupExpression.getLowestCostTable().containsKey(entry.getKey())) {
                groupExpression.putOutputPropertiesMap(entry.getValue(), entry.getKey());
            }
        }
    }

    /** Returns true when the child property list matches the arity and every child has a best plan for it. */
    private boolean hasAvailableChildBestPlan(GroupExpression groupExpression,
            List<PhysicalProperties> childInputProperties) {
        if (childInputProperties.size() != groupExpression.arity()) {
            return false;
        }
        for (int i = 0; i < childInputProperties.size(); i++) {
            if (!groupExpression.child(i)
                    .getLowestCostPlan(childInputProperties.get(i)).isPresent()) {
                return false;
            }
        }
        return true;
    }

    /** Overwrites positive or non-finite logical expression row counts with the group row count. */
    private void syncLogicalExpressionRowCount(Group group, double rowCount) {
        for (GroupExpression logicalExpression : group.getLogicalExpressions()) {
            if (logicalExpression.getEstOutputRowCount() > 0
                    || !Double.isFinite(logicalExpression.getEstOutputRowCount())) {
                logicalExpression.setEstOutputRowCount(rowCount);
            }
        }
    }

    /** Updates the estimated output row count of every physical expression in the group. */
    private void refreshPhysicalExpressionRowCount(Group group, double rowCount) {
        for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
            physicalExpression.setEstOutputRowCount(getPhysicalExpressionRowCount(physicalExpression, rowCount));
        }
    }

    /**
     * A unary physical project reports its child's row count when that count is valid; every other physical
     * expression reports the group row count.
     */
    private double getPhysicalExpressionRowCount(GroupExpression physicalExpression, double rowCount) {
        if (physicalExpression.getPlan() instanceof PhysicalProject && physicalExpression.arity() == 1) {
            Statistics childStatistics = physicalExpression.child(0).getStatistics();
            if (childStatistics != null && isValidCandidateStatistics(childStatistics)) {
                return childStatistics.getRowCount();
            }
        }
        return rowCount;
    }

    private boolean shouldSyncLogicalExpressionRowCount() {
        return logicalExpressionRowCountSyncPolicy
                == LogicalExpressionRowCountSyncPolicy.SYNC_WITH_GROUP_ROW_COUNT;
    }

    /**
     * For {@code TRUST_JOIN_COUNT} with at least two candidates, keeps only the candidates whose subtree contains
     * the maximum number of trustable joins; otherwise returns the input unchanged.
     */
    private Map<GroupExpression, Statistics> filterCandidateStatisticsByPolicy(
            LogicalRowCountAggregationPolicy aggregationPolicy,
            Map<GroupExpression, Statistics> candidateStatisticsByExpression) {
        if (aggregationPolicy != LogicalRowCountAggregationPolicy.TRUST_JOIN_COUNT
                || candidateStatisticsByExpression.size() < 2) {
            return candidateStatisticsByExpression;
        }
        int maxTrustJoinCount = Integer.MIN_VALUE;
        Map<GroupExpression, Statistics> selectedCandidateStatisticsByExpression = new LinkedHashMap<>();
        for (Map.Entry<GroupExpression, Statistics> entry : candidateStatisticsByExpression.entrySet()) {
            int trustJoinCount = countTrustJoins(entry.getKey(), new HashSet<>());
            if (trustJoinCount > maxTrustJoinCount) {
                selectedCandidateStatisticsByExpression.clear();
                maxTrustJoinCount = trustJoinCount;
            }
            if (trustJoinCount == maxTrustJoinCount) {
                selectedCandidateStatisticsByExpression.put(entry.getKey(), entry.getValue());
            }
        }
        return selectedCandidateStatisticsByExpression;
    }

    /**
     * Counts trustable joins in the subtree formed by this expression and, recursively, the first logical
     * expression of each child group. {@code visiting} guards against cycles on the current path.
     */
    private int countTrustJoins(GroupExpression groupExpression, Set<GroupExpression> visiting) {
        if (!visiting.add(groupExpression)) {
            return 0;
        }
        int trustJoinCount = isTrustJoin(groupExpression) ? 1 : 0;
        for (Group child : groupExpression.children()) {
            if (!child.getLogicalExpressions().isEmpty()) {
                trustJoinCount += countTrustJoins(child.getFirstLogicalExpression(), visiting);
            }
        }
        visiting.remove(groupExpression);
        return trustJoinCount;
    }

    /** A binary join whose child statistics exist and which has a trustable equal condition. */
    private boolean isTrustJoin(GroupExpression groupExpression) {
        if (groupExpression.arity() != 2 || !(groupExpression.getPlan() instanceof Join)) {
            return false;
        }
        Statistics leftStats = groupExpression.child(0).getStatistics();
        Statistics rightStats = groupExpression.child(1).getStatistics();
        if (leftStats == null || rightStats == null) {
            return false;
        }
        return JoinEstimation.hasTrustableEqualCondition(leftStats, rightStats,
                (Join) groupExpression.getPlan());
    }

    /** Reads the aggregation policy from the session, defaulting to AVERAGE when no session is available. */
    private LogicalRowCountAggregationPolicy getLogicalRowCountAggregationPolicy() {
        ConnectContext connectContext = cascadesContext == null ? null : cascadesContext.getConnectContext();
        if (connectContext == null || connectContext.getSessionVariable() == null) {
            return LogicalRowCountAggregationPolicy.AVERAGE;
        }
        return LogicalRowCountAggregationPolicy.fromSessionValue(
                connectContext.getSessionVariable().getMemoLogicalRowCountAggregationPolicy());
    }

    /** Copies the group row count onto every enforcer; no-op when the group has no statistics. */
    private void refreshEnforcerRowCount(Group group) {
        Statistics statistics = group.getStatistics();
        if (statistics == null) {
            return;
        }
        for (GroupExpression enforcer : group.getEnforcers().values()) {
            enforcer.setEstOutputRowCount(statistics.getRowCount());
        }
    }

    /** Stores a copy of {@code statistics} for every CTE producer (logical or physical) in the group. */
    private void recordProducerStats(Group group, Statistics statistics) {
        if (cascadesContext == null || statistics == null) {
            return;
        }
        for (GroupExpression logicalExpression : group.getLogicalExpressions()) {
            Plan plan = logicalExpression.getPlan();
            if (plan instanceof LogicalCTEProducer) {
                cteIdToStats.put(((LogicalCTEProducer<?>) plan).getCteId(), new Statistics(statistics));
            }
        }
        for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
            Plan plan = physicalExpression.getPlan();
            if (plan instanceof PhysicalCTEProducer) {
                cteIdToStats.put(((PhysicalCTEProducer<?>) plan).getCteId(), new Statistics(statistics));
            }
        }
    }

    /** Captures the group's current lowest-cost plan for each of its properties. */
    private Map<PhysicalProperties, Pair<Cost, GroupExpression>> snapshotLowestCostPlans(Group group) {
        Map<PhysicalProperties, Pair<Cost, GroupExpression>> snapshot = new LinkedHashMap<>();
        for (PhysicalProperties properties : group.getAllProperties()) {
            group.getLowestCostPlan(properties).ifPresent(plan -> snapshot.put(properties, plan));
        }
        return snapshot;
    }

    /** Puts back snapshot plans for properties that have no lowest-cost plan after recomputation. */
    private void restoreMissingLowestCostPlans(Group group,
            Map<PhysicalProperties, Pair<Cost, GroupExpression>> lowestCostPlans) {
        for (Map.Entry<PhysicalProperties, Pair<Cost, GroupExpression>> entry : lowestCostPlans.entrySet()) {
            if (!group.getLowestCostPlan(entry.getKey()).isPresent()) {
                group.putBestPlan(entry.getValue().second, entry.getValue().first, entry.getKey());
            }
        }
    }

    /**
     * Picks the candidate whose row count is closest to the aggregated row count (first one on ties), falling
     * back to a copy of the original statistics, then to empty statistics with the aggregated row count.
     */
    private Statistics chooseRepresentativeStatistics(List<Statistics> candidateStatistics,
            double aggregatedRowCount, Statistics originalStatistics) {
        Statistics bestMatch = null;
        double bestDistance = Double.POSITIVE_INFINITY;
        for (Statistics candidate : candidateStatistics) {
            double distance = Math.abs(candidate.getRowCount() - aggregatedRowCount);
            if (distance < bestDistance) {
                bestDistance = distance;
                bestMatch = candidate;
            }
        }
        if (bestMatch != null) {
            return bestMatch;
        }
        if (originalStatistics != null) {
            return new Statistics(originalStatistics);
        }
        return new Statistics(aggregatedRowCount, new HashMap<>());
    }

    /**
     * Returns the chosen project's statistics when they should be preserved; otherwise the representative
     * candidate's statistics rescaled to the aggregated row count.
     */
    private Statistics resolveUpdatedGroupStatistics(Group group,
            Map<GroupExpression, Statistics> candidateStatisticsByExpression,
            List<Statistics> candidateStatistics, double aggregatedRowCount,
            Statistics originalStatistics) {
        Statistics chosenProjectStatistics = resolveChosenProjectStatistics(group, candidateStatisticsByExpression,
                aggregatedRowCount);
        if (chosenProjectStatistics != null) {
            return chosenProjectStatistics;
        }
        Statistics representativeStatistics = chooseRepresentativeStatistics(
                candidateStatistics, aggregatedRowCount, originalStatistics);
        return representativeStatistics.withRowCountAndEnforceValid(aggregatedRowCount);
    }

    /**
     * Returns a copy of the candidate statistics whose expression has the same children as the physical
     * expression chosen for {@link PhysicalProperties#ANY}, or null when preservation does not apply.
     */
    private Statistics resolveChosenProjectStatistics(Group group,
            Map<GroupExpression, Statistics> candidateStatisticsByExpression,
            double aggregatedRowCount) {
        if (!shouldPreserveChosenProjectStatistics(group, candidateStatisticsByExpression, aggregatedRowCount)) {
            return null;
        }
        Optional<Pair<Cost, GroupExpression>> lowestCostPlan = group.getLowestCostPlan(PhysicalProperties.ANY);
        if (!lowestCostPlan.isPresent()) {
            return null;
        }
        GroupExpression chosenPhysicalExpression = lowestCostPlan.get().second;
        for (Map.Entry<GroupExpression, Statistics> entry : candidateStatisticsByExpression.entrySet()) {
            if (entry.getKey().children().equals(chosenPhysicalExpression.children())) {
                return new Statistics(entry.getValue());
            }
        }
        return null;
    }

    /**
     * Preservation applies when all of the following hold:
     * <ul>
     *   <li>there are at least two candidates and the aggregate is positive;</li>
     *   <li>the chosen plan is a {@link PhysicalProject} and every candidate is a {@link LogicalProject};</li>
     *   <li>every candidate row count is positive;</li>
     *   <li>the candidate matching the chosen plan's children is below the aggregate;</li>
     *   <li>the max/min candidate ratio reaches {@link #CHOSEN_PROJECT_STATS_DIVERGENCE_RATIO_THRESHOLD}.</li>
     * </ul>
     */
    private boolean shouldPreserveChosenProjectStatistics(Group group,
            Map<GroupExpression, Statistics> candidateStatisticsByExpression,
            double aggregatedRowCount) {
        if (candidateStatisticsByExpression.size() < 2 || !(aggregatedRowCount > 0)) {
            return false;
        }
        Optional<Pair<Cost, GroupExpression>> lowestCostPlan = group.getLowestCostPlan(PhysicalProperties.ANY);
        if (!lowestCostPlan.isPresent()) {
            return false;
        }
        GroupExpression chosenPhysicalExpression = lowestCostPlan.get().second;
        if (!(chosenPhysicalExpression.getPlan() instanceof PhysicalProject)) {
            return false;
        }
        for (GroupExpression logicalExpression : candidateStatisticsByExpression.keySet()) {
            if (!(logicalExpression.getPlan() instanceof LogicalProject)) {
                return false;
            }
        }
        double minRowCount = Double.POSITIVE_INFINITY;
        double maxRowCount = 0;
        double chosenRowCount = Double.NaN;
        for (Map.Entry<GroupExpression, Statistics> entry : candidateStatisticsByExpression.entrySet()) {
            double rowCount = entry.getValue().getRowCount();
            if (!(rowCount > 0)) {
                return false;
            }
            minRowCount = Math.min(minRowCount, rowCount);
            maxRowCount = Math.max(maxRowCount, rowCount);
            if (entry.getKey().children().equals(chosenPhysicalExpression.children())) {
                chosenRowCount = rowCount;
            }
        }
        if (!Double.isFinite(chosenRowCount) || chosenRowCount >= aggregatedRowCount) {
            return false;
        }
        return maxRowCount / minRowCount >= CHOSEN_PROJECT_STATS_DIVERGENCE_RATIO_THRESHOLD;
    }

    /**
     * Logical expressions that can be re-estimated. An expression qualifies when all of these hold:
     * <ul>
     *   <li>it has a usable row count;</li>
     *   <li>it does not read its own group's statistics (those are cleared during estimation);</li>
     *   <li>all children have statistics;</li>
     *   <li>any CTE it consumes has recorded producer statistics.</li>
     * </ul>
     */
    private List<GroupExpression> getEstimableLogicalExpressions(Group group) {
        List<GroupExpression> logicalExpressions = group.getLogicalExpressions();
        if (logicalExpressions.isEmpty()) {
            return Collections.emptyList();
        }
        List<GroupExpression> estimableExpressions = new ArrayList<>();
        for (GroupExpression logicalExpression : logicalExpressions) {
            if (hasEstimableLogicalExpressionRowCount(group, logicalExpression)
                    && !dependsOnOwnerGroupStatistics(group, logicalExpression)
                    && hasCompleteChildStatistics(logicalExpression)
                    && hasAvailableCteStatistics(logicalExpression)) {
                estimableExpressions.add(logicalExpression);
            }
        }
        return estimableExpressions;
    }

    /** Physical expressions followed by enforcers that can be re-estimated without the owner group's stats. */
    private List<GroupExpression> getEstimablePhysicalExpressions(Group group) {
        List<GroupExpression> estimableExpressions = new ArrayList<>();
        for (GroupExpression physicalExpression : group.getPhysicalExpressions()) {
            if (isEstimablePhysicalExpression(group, physicalExpression)) {
                estimableExpressions.add(physicalExpression);
            }
        }
        for (GroupExpression enforcer : group.getEnforcers().values()) {
            if (isEstimablePhysicalExpression(group, enforcer)) {
                estimableExpressions.add(enforcer);
            }
        }
        return estimableExpressions;
    }

    /**
     * Non-leaf expression that does not read the owner group's statistics, has complete child statistics,
     * and has any CTE producer statistics it needs.
     */
    private boolean isEstimablePhysicalExpression(Group ownerGroup, GroupExpression groupExpression) {
        return groupExpression.arity() > 0
                && !dependsOnOwnerGroupStatistics(ownerGroup, groupExpression)
                && hasCompleteChildStatistics(groupExpression)
                && hasAvailableCteStatistics(groupExpression);
    }

    /**
     * Prefers the statistics of the expression chosen for {@link PhysicalProperties#ANY}. Otherwise picks the
     * candidate closest to the original row count, or finally the first candidate.
     */
    private Statistics choosePhysicalOnlyGroupStatistics(Group group,
            Map<GroupExpression, Statistics> candidateStatisticsByExpression,
            Statistics originalStatistics) {
        Optional<Pair<Cost, GroupExpression>> lowestCostPlan = group.getLowestCostPlan(PhysicalProperties.ANY);
        if (lowestCostPlan.isPresent()) {
            Statistics chosenStatistics = candidateStatisticsByExpression.get(lowestCostPlan.get().second);
            if (chosenStatistics != null) {
                return new Statistics(chosenStatistics);
            }
        }
        if (originalStatistics != null && Double.isFinite(originalStatistics.getRowCount())) {
            Statistics bestMatch = null;
            double bestDistance = Double.POSITIVE_INFINITY;
            for (Statistics candidateStatistics : candidateStatisticsByExpression.values()) {
                double distance = Math.abs(candidateStatistics.getRowCount() - originalStatistics.getRowCount());
                if (distance < bestDistance) {
                    bestDistance = distance;
                    bestMatch = candidateStatistics;
                }
            }
            if (bestMatch != null) {
                return new Statistics(bestMatch);
            }
        }
        return new Statistics(candidateStatisticsByExpression.values().iterator().next());
    }

    /** Replaces non-finite or non-positive logical expression row counts with {@code rowCount}. */
    private void repairInvalidLogicalExpressionRowCounts(Group group, double rowCount) {
        for (GroupExpression logicalExpression : group.getLogicalExpressions()) {
            double expressionRowCount = logicalExpression.getEstOutputRowCount();
            if (!Double.isFinite(expressionRowCount) || expressionRowCount <= 0) {
                logicalExpression.setEstOutputRowCount(rowCount);
            }
        }
    }

    /**
     * Estimable when the expression already has a positive row count. Also estimable when individual row counts
     * are kept and the group has statistics.
     */
    private boolean hasEstimableLogicalExpressionRowCount(Group group, GroupExpression logicalExpression) {
        return logicalExpression.getEstOutputRowCount() > 0
                || (!shouldSyncLogicalExpressionRowCount() && group.getStatistics() != null);
    }

    /** CTE consumers are estimable only once their producer's statistics have been recorded. */
    private boolean hasAvailableCteStatistics(GroupExpression groupExpression) {
        Plan plan = groupExpression.getPlan();
        if (plan instanceof LogicalCTEConsumer) {
            return cteIdToStats.containsKey(((LogicalCTEConsumer) plan).getCteId());
        }
        if (plan instanceof PhysicalCTEConsumer) {
            return cteIdToStats.containsKey(((PhysicalCTEConsumer) plan).getCteId());
        }
        return true;
    }

    /** Children of all logical, physical and enforcer expressions, excluding the group itself (may repeat). */
    private List<Group> getTraversalChildren(Group group) {
        List<Group> children = Lists.newArrayList();
        addChildren(children, group.getLogicalExpressions(), group);
        addChildren(children, group.getPhysicalExpressions(), group);
        addChildren(children, group.getEnforcers().values(), group);
        return children;
    }

    private void addChildren(List<Group> children, Iterable<GroupExpression> groupExpressions,
            Group ownerGroup) {
        for (GroupExpression groupExpression : groupExpressions) {
            for (Group child : groupExpression.children()) {
                if (child != ownerGroup) {
                    children.add(child);
                }
            }
        }
    }

    private boolean hasCompleteChildStatistics(GroupExpression groupExpression) {
        for (Group child : groupExpression.children()) {
            if (child.getStatistics() == null) {
                return false;
            }
        }
        return true;
    }

    /** True when any child of the expression is (by identity) its owner group. */
    private boolean dependsOnOwnerGroupStatistics(Group ownerGroup, GroupExpression groupExpression) {
        for (Group child : groupExpression.children()) {
            if (child == ownerGroup) {
                return true;
            }
        }
        return false;
    }

    /**
     * Controls whether logical expressions' estimated row counts are synced to the group's final row count.
     */
    public enum LogicalExpressionRowCountSyncPolicy {
        /** Overwrite positive or non-finite logical expression row counts with the group row count. */
        SYNC_WITH_GROUP_ROW_COUNT,
        /** Keep each logical expression's individually estimated row count. */
        KEEP_INDIVIDUAL_EXPRESSION_ROW_COUNT
    }

    /** How candidate row counts of a group's logical expressions are combined into one group row count. */
    private enum LogicalRowCountAggregationPolicy {
        AVERAGE {
            @Override
            double aggregate(List<Statistics> candidateStatistics) {
                return candidateStatistics.stream()
                        .mapToDouble(Statistics::getRowCount)
                        .average()
                        .orElse(Double.NaN);
            }
        },
        MEDIAN {
            @Override
            double aggregate(List<Statistics> candidateStatistics) {
                double[] rowCounts = candidateStatistics.stream()
                        .mapToDouble(Statistics::getRowCount)
                        .sorted()
                        .toArray();
                if (rowCounts.length == 0) {
                    return Double.NaN;
                }
                int middle = rowCounts.length / 2;
                if ((rowCounts.length & 1) == 1) {
                    return rowCounts[middle];
                }
                return (rowCounts[middle - 1] + rowCounts[middle]) / 2;
            }
        },
        MIN {
            @Override
            double aggregate(List<Statistics> candidateStatistics) {
                return candidateStatistics.stream()
                        .mapToDouble(Statistics::getRowCount)
                        .min()
                        .orElse(Double.NaN);
            }
        },
        /** Median over the candidates with the most trustable joins (filtering happens before aggregation). */
        TRUST_JOIN_COUNT {
            @Override
            double aggregate(List<Statistics> candidateStatistics) {
                return MEDIAN.aggregate(candidateStatistics);
            }
        },
        MAX {
            @Override
            double aggregate(List<Statistics> candidateStatistics) {
                return candidateStatistics.stream()
                        .mapToDouble(Statistics::getRowCount)
                        .max()
                        .orElse(Double.NaN);
            }
        };

        /**
         * Parses the session variable value case-insensitively (locale-independent).
         *
         * @param value session value; null selects {@link #AVERAGE}
         * @return the matching policy
         * @throws IllegalArgumentException if the value names no known policy
         */
        static LogicalRowCountAggregationPolicy fromSessionValue(String value) {
            if (value == null) {
                return AVERAGE;
            }
            // Locale.ROOT avoids locale-specific case mapping (e.g. Turkish dotless i turning "MIN" into "mın").
            switch (value.toLowerCase(Locale.ROOT)) {
                case "average":
                    return AVERAGE;
                case "median":
                    return MEDIAN;
                case "min":
                    return MIN;
                case "max":
                    return MAX;
                case "trust_join_count":
                    return TRUST_JOIN_COUNT;
                default:
                    throw new IllegalArgumentException("Unknown logical row count aggregation policy: " + value);
            }
        }

        /** Combines candidate row counts; returns NaN for an empty list. */
        abstract double aggregate(List<Statistics> candidateStatistics);
    }
}