CostModelV2.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.cost;

import org.apache.doris.nereids.PlanContext;
import org.apache.doris.nereids.properties.DistributionSpec;
import org.apache.doris.nereids.properties.DistributionSpecReplicated;
import org.apache.doris.nereids.trees.expressions.Slot;
import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalJoin;
import org.apache.doris.nereids.trees.plans.physical.AbstractPhysicalSort;
import org.apache.doris.nereids.trees.plans.physical.PhysicalAssertNumRows;
import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
import org.apache.doris.nereids.trees.plans.physical.PhysicalEsScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFileScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalFilter;
import org.apache.doris.nereids.trees.plans.physical.PhysicalGenerate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalHashJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalJdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
import org.apache.doris.nereids.trees.plans.physical.PhysicalNestedLoopJoin;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOdbcScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalOlapScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalPartitionTopN;
import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
import org.apache.doris.nereids.trees.plans.physical.PhysicalRepeat;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSchemaScan;
import org.apache.doris.nereids.trees.plans.physical.PhysicalSetOperation;
import org.apache.doris.nereids.trees.plans.physical.PhysicalStorageLayerAggregate;
import org.apache.doris.nereids.trees.plans.physical.PhysicalUnion;
import org.apache.doris.nereids.trees.plans.physical.PhysicalWindow;
import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor;
import org.apache.doris.qe.SessionVariable;
import org.apache.doris.statistics.Statistics;

import com.google.common.base.Preconditions;

import java.util.stream.Collectors;

/**
 * Cost model (version 2) that produces a {@link CostV2} for each physical operator.
 *
 * <p>Every visitor method returns the operator's own cost as the triple
 * (startCost, runCost, memory) passed to the {@link CostV2} constructor; the costs of the
 * children are merged in afterwards through {@link #addChildCost(Plan, Cost, Cost, int)}.
 * CPU, IO and network components are weighted with {@code CostWeight.get(sessionVariable)}.
 */
class CostModelV2 extends PlanVisitor<Cost, PlanContext> {
    // Per-row unit costs used by the formulas below.
    // NOTE(review): these are mutable package-private statics; they look like constants, but they are
    // left non-final here because other code in the package may assign them -- confirm before tightening.
    static double HASH_COST = 1.0;
    static double PROBE_COST = 1.2;
    static double CMP_COST = 1.5;
    static double PUSH_DOWN_AGG_COST = 0.1;

    private final SessionVariable sessionVariable;

    CostModelV2(SessionVariable sessionVariable) {
        this.sessionVariable = sessionVariable;
    }

    /**
     * Merges the cost of the {@code index}-th child into the cost of {@code plan}.
     *
     * <ul>
     *   <li>{@link PhysicalLimit}: the plan cost is replaced by the child cost whose run cost is scaled by
     *       the limit ratio stored on the plan cost.</li>
     *   <li>{@link AbstractPhysicalJoin}: child 0 contributes its start/run cost unchanged; every other
     *       child contributes its run cost as start cost and no run cost.</li>
     *   <li>Any other plan: the child's start cost, run cost and memory are passed through unchanged.</li>
     * </ul>
     * When the last child ({@code index == plan.arity() - 1}) has been merged, {@code finish()} is called
     * on the resulting cost.
     *
     * @param plan the parent plan
     * @param planCost the cost accumulated for {@code plan} so far; must be a {@link CostV2}
     * @param childCost the cost of the child; must be a {@link CostV2}
     * @param index position of the child within {@code plan}
     * @return the updated plan cost
     * @throws IllegalArgumentException if either cost is not a {@link CostV2}
     */
    public static Cost addChildCost(Plan plan, Cost planCost, Cost childCost, int index) {
        Preconditions.checkArgument(childCost instanceof CostV2 && planCost instanceof CostV2,
                "CostModelV2 can only combine CostV2 instances");
        CostV2 planCostV2 = (CostV2) planCost;
        CostV2 childCostV2 = (CostV2) childCost;

        if (plan instanceof PhysicalLimit) {
            double limitedRunCost = childCostV2.getRunCost() * planCostV2.getLimitRation();
            planCostV2 = new CostV2(childCostV2.getStartCost(), limitedRunCost, childCostV2.getMemory());
        } else if (plan instanceof AbstractPhysicalJoin && index != 0) {
            // Non-first join child: its run cost is charged as start cost of the join.
            planCostV2.updateChildCost(childCostV2.getRunCost(), 0, childCostV2.getMemory());
        } else {
            planCostV2.updateChildCost(childCostV2.getStartCost(), childCostV2.getRunCost(),
                    childCostV2.getMemory());
        }

        boolean isLastChild = index == plan.arity() - 1;
        if (isLastChild) {
            planCostV2.finish();
        }
        return planCostV2;
    }

    /** Fallback for operators without a dedicated formula: zero cost. */
    @Override
    public Cost visit(Plan plan, PlanContext context) {
        return CostV2.zero();
    }

    @Override
    public Cost visitPhysicalOlapScan(PhysicalOlapScan physicalOlapScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    @Override
    public Cost visitPhysicalSchemaScan(PhysicalSchemaScan physicalSchemaScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    /**
     * Cost of an aggregate pushed down to the storage layer: half of the weighted per-BE IO cost of the
     * output is charged as start cost; run cost and memory are zero.
     */
    @Override
    public Cost visitPhysicalStorageLayerAggregate(PhysicalStorageLayerAggregate storageLayerAggregate,
            PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        double ioCost = stats.computeSize(storageLayerAggregate.getOutput());
        double scanCost = CostWeight.get(sessionVariable).weightSum(0, ioCost, 0) / stats.getBENumber();

        // Note: the stats of this operator are the stats of the underlying relation, so a penalty is
        // applied to the scan cost. Adjusting the row count of the storage-layer aggregate might be a
        // better long-term solution.
        double startCost = scanCost / 2;
        return new CostV2(startCost, 0, 0);
    }

    @Override
    public Cost visitPhysicalFileScan(PhysicalFileScan physicalFileScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    /** Cost of a projection: row count times the expression cost of the projects, split across BEs. */
    @Override
    public Cost visitPhysicalProject(PhysicalProject<? extends Plan> physicalProject, PlanContext context) {
        Statistics statistics = context.getStatisticsWithCheck();
        double cpuCost = statistics.getRowCount() * ExprCostModel.calculateExprCost(physicalProject.getProjects());
        return perBackendCpuCost(cpuCost, statistics);
    }

    @Override
    public Cost visitPhysicalJdbcScan(PhysicalJdbcScan physicalJdbcScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    @Override
    public Cost visitPhysicalOdbcScan(PhysicalOdbcScan physicalOdbcScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    @Override
    public Cost visitPhysicalEsScan(PhysicalEsScan physicalEsScan, PlanContext context) {
        return calculateScanWithoutRF(context.getStatisticsWithCheck());
    }

    /**
     * Cost of a sort. A merge phase is charged {@code rows * CMP_COST * log(child BE number)}; any other
     * phase is charged {@code childRows * CMP_COST * log(rows) / BE number}. The same value is used as
     * start cost and run cost; memory is the size of the sort output.
     */
    @Override
    public Cost visitAbstractPhysicalSort(AbstractPhysicalSort<? extends Plan> sort, PlanContext context) {
        Statistics statistics = context.getStatisticsWithCheck();
        Statistics childStatistics = context.getChildStatistics(0);

        double runCost;
        if (sort.getSortPhase().isMerge()) {
            runCost = statistics.getRowCount() * CMP_COST * nonNegativeLog(childStatistics.getBENumber());
        } else {
            runCost = childStatistics.getRowCount() * CMP_COST * nonNegativeLog(statistics.getRowCount())
                    / statistics.getBENumber();
        }

        return new CostV2(runCost, runCost, statistics.computeSize(sort.getOutput()));
    }

    /**
     * Cost of a partition top-n, using the same formula as a non-merge sort.
     */
    @Override
    public Cost visitPhysicalPartitionTopN(PhysicalPartitionTopN<? extends Plan> partitionTopN, PlanContext context) {
        Statistics statistics = context.getStatisticsWithCheck();
        Statistics childStatistics = context.getChildStatistics(0);

        // The value is somewhat arbitrary: partitionTopN is generated in the rewrite phase and has only
        // one physical implementation, so this cost does not affect plan selection.
        double runCost = childStatistics.getRowCount() * CMP_COST * nonNegativeLog(statistics.getRowCount())
                / statistics.getBENumber();

        return new CostV2(runCost, runCost, statistics.computeSize(partitionTopN.getOutput()));
    }

    /**
     * Cost of a data exchange: the network cost of the child's output size (multiplied by the BE number
     * for a replicated distribution), weighted and split across BEs. Start cost and memory are zero.
     */
    @Override
    public Cost visitPhysicalDistribute(PhysicalDistribute<? extends Plan> distribute, PlanContext context) {
        Statistics childStatistics = context.getChildStatistics(0);
        double size = childStatistics.computeSize(distribute.getOutput());

        DistributionSpec spec = distribute.getDistributionSpec();
        // A replicated distribution is charged for one copy per BE.
        double transferredSize = spec instanceof DistributionSpecReplicated
                ? size * childStatistics.getBENumber()
                : size;
        double netCost = getNetCost(transferredSize);

        double runCost = CostWeight.get(sessionVariable).weightSum(0, 0, netCost) / childStatistics.getBENumber();
        return new CostV2(0, runCost, 0);
    }

    @Override
    public Cost visitPhysicalHashAggregate(PhysicalHashAggregate<? extends Plan> aggregate, PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        Statistics childStats = context.getChildStatistics(0);

        double exprCost = ExprCostModel.calculateExprCost(aggregate.getExpressions());
        return calculateAggregate(stats, childStats, exprCost);
    }

    /**
     * Cost of a hash join. Building the hash table over the right child is the start cost (multiplied by
     * the BE number for a broadcast join); probing with the left child plus evaluating the other join
     * conjuncts on the output rows is the run cost, split across BEs. Memory is the size of the right
     * child's output.
     */
    @Override
    public Cost visitPhysicalHashJoin(PhysicalHashJoin<? extends Plan, ? extends Plan> physicalHashJoin,
            PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        Statistics leftStats = context.getChildStatistics(0);
        Statistics rightStats = context.getChildStatistics(1);
        double otherExprCost = ExprCostModel.calculateExprCost(physicalHashJoin.getOtherJoinConjuncts());

        double buildTableCost = rightStats.getRowCount() * HASH_COST;
        if (context.isBroadcastJoin()) {
            buildTableCost *= stats.getBENumber();
        }
        double probeCost = leftStats.getRowCount() * PROBE_COST + stats.getRowCount() * otherExprCost;

        double startCost = weightedCpuCost(buildTableCost);
        double runCost = weightedCpuCost(probeCost) / stats.getBENumber();
        return new CostV2(startCost, runCost, rightStats.computeSize(physicalHashJoin.right().getOutput()));
    }

    /**
     * Cost of a nested loop join: the other join conjuncts are evaluated on every (left, right) row pair.
     * Start cost is zero; memory is the size of the right child's output.
     */
    @Override
    public Cost visitPhysicalNestedLoopJoin(PhysicalNestedLoopJoin<? extends Plan, ? extends Plan> nestedLoopJoin,
            PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        Statistics leftStats = context.getChildStatistics(0);
        Statistics rightStats = context.getChildStatistics(1);
        double otherExprCost = ExprCostModel.calculateExprCost(nestedLoopJoin.getOtherJoinConjuncts());

        // The nested loop join materializes its right child.
        double probeCost = leftStats.getRowCount() * rightStats.getRowCount() * otherExprCost;
        if (!context.isBroadcastJoin()) {
            // NOTE(review): together with the division below, a non-broadcast join is divided by the
            // BE number twice -- confirm this is intended.
            probeCost /= stats.getBENumber();
        }

        double runCost = weightedCpuCost(probeCost) / stats.getBENumber();
        return new CostV2(0, runCost, rightStats.computeSize(nestedLoopJoin.right().getOutput()));
    }

    @Override
    public Cost visitPhysicalAssertNumRows(PhysicalAssertNumRows<? extends Plan> assertNumRows, PlanContext context) {
        return new CostV2(0, 0, 0);
    }

    /**
     * Cost of a limit: zero own cost, plus a limit ratio that {@link #addChildCost} uses to scale the
     * child's run cost.
     *
     * <p>The ratio is {@code (limit + offset) / childRowCount}, capped at 1 so that a limit never makes
     * its child more expensive. When the child row count is not positive, the ratio falls back to 1
     * instead of producing Infinity/NaN.
     */
    @Override
    public Cost visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, PlanContext context) {
        CostV2 cost = new CostV2(0, 0, 0);
        // Sum in double so that a very large limit plus offset cannot overflow.
        double requestedRows = (double) limit.getLimit() + limit.getOffset();
        double childRowCount = context.getChildStatistics(0).getRowCount();
        double ratio = childRowCount > 0 ? Math.min(1.0, requestedRows / childRowCount) : 1.0;
        cost.setLimitRation(ratio);
        return cost;
    }

    /** Cost of a generate: row count times the expression cost of the generators, split across BEs. */
    @Override
    public Cost visitPhysicalGenerate(PhysicalGenerate<? extends Plan> generate, PlanContext context) {
        Statistics statistics = context.getStatisticsWithCheck();
        double exprCost = ExprCostModel.calculateExprCost(generate.getGenerators());
        double cpuCost = exprCost * statistics.getRowCount();
        return perBackendCpuCost(cpuCost, statistics);
    }

    @Override
    public Cost visitPhysicalRepeat(PhysicalRepeat<? extends Plan> repeat, PlanContext context) {
        // Repeat only expands the tuples according to the grouping sets, so no cost is charged.
        return new CostV2(0, 0, 0);
    }

    /** Cost of a window: row count times the expression cost of the window expressions, split across BEs. */
    @Override
    public Cost visitPhysicalWindow(PhysicalWindow<? extends Plan> window, PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        double exprCost = ExprCostModel.calculateExprCost(window.getWindowExpressions());
        double cpuCost = stats.getRowCount() * exprCost;
        return perBackendCpuCost(cpuCost, stats);
    }

    @Override
    public Cost visitPhysicalUnion(PhysicalUnion union, PlanContext context) {
        // Union all just concatenates the tuples of its children, so no cost is charged.
        return new CostV2(0, 0, 0);
    }

    /**
     * Cost of a set operation: hashing every row of every child is charged as start cost; memory is the
     * summed size of the children restricted to the operator's output. Run cost is zero.
     */
    @Override
    public Cost visitPhysicalSetOperation(PhysicalSetOperation setOperation, PlanContext context) {
        // Accumulate in double: an int accumulator would truncate fractions and saturate at
        // Integer.MAX_VALUE for large inputs.
        double rowCount = 0;
        double size = 0;
        for (Statistics childStats : context.getChildrenStatistics()) {
            rowCount += childStats.getRowCount();
            size += childStats.computeSize(setOperation.getOutput());
        }

        double startCost = weightedCpuCost(rowCount * HASH_COST);
        return new CostV2(startCost, 0, size);
    }

    /** Cost of a filter: row count times the expression cost of the predicates, split across BEs. */
    @Override
    public Cost visitPhysicalFilter(PhysicalFilter<? extends Plan> physicalFilter, PlanContext context) {
        Statistics stats = context.getStatisticsWithCheck();
        double exprCost = ExprCostModel.calculateExprCost(physicalFilter.getExpressions());
        double cpuCost = exprCost * stats.getRowCount();
        return perBackendCpuCost(cpuCost, stats);
    }

    /**
     * Cost of a scan without considering runtime filters. Currently always zero; {@code stats} is unused.
     */
    private CostV2 calculateScanWithoutRF(Statistics stats) {
        // TODO: consider runtimeFilter. The intended formula is
        //   io = stats.computeSize();
        //   runCost = CostWeight.get(sessionVariable).weightSum(0, io, 0) / stats.getBENumber();
        //   return new CostV2(0, runCost, 0);
        return (CostV2) CostV2.zero();
    }

    /**
     * Cost of a hash aggregate: hashing each child row and evaluating the aggregate expressions on it is
     * charged as start cost; memory is the size of all slot columns in {@code stats}. Run cost is zero.
     */
    private CostV2 calculateAggregate(Statistics stats, Statistics childStats, double exprCost) {
        // Building the hash table consumes the whole input.
        double buildCost = HASH_COST * childStats.getRowCount() + exprCost * childStats.getRowCount();
        double startCost = weightedCpuCost(buildCost);
        double memory = stats.computeSize(stats.columnStatistics().keySet().stream()
                .filter(Slot.class::isInstance)
                .map(Slot.class::cast)
                .collect(Collectors.toList()));
        return new CostV2(startCost, 0, memory);
    }

    /** Network cost of transferring {@code size}: the size divided by the transfer rate, rounded up. */
    private double getNetCost(double size) {
        // We assume the transfer rate is 4MB/s.
        // TODO: make this configurable through a session variable
        int transferRate = 4096 * 1024;
        return Math.ceil(size / transferRate);
    }

    /** Applies the session's cost weights to a pure CPU cost (no IO, no network). */
    private double weightedCpuCost(double cpuCost) {
        return CostWeight.get(sessionVariable).weightSum(cpuCost, 0, 0);
    }

    /** Builds the cost of a streaming CPU-only operator: weighted CPU cost split across the BEs. */
    private CostV2 perBackendCpuCost(double cpuCost, Statistics stats) {
        double runCost = weightedCpuCost(cpuCost) / stats.getBENumber();
        return new CostV2(0, runCost, 0);
    }

    /**
     * Natural logarithm with the argument clamped to at least 1, so the result is never negative or
     * {@code -Infinity} (which would turn {@code 0 * log(0)} into NaN).
     */
    private static double nonNegativeLog(double value) {
        return Math.log(Math.max(1.0, value));
    }
}