JoinCostEvaluation.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.planner;

import org.apache.doris.qe.ConnectContext;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * Evaluates the cost of broadcast and shuffle join to choose between the two.
 *
 * broadcast: send the rightChildFragment's output to each node executing the leftChildFragment; the cost across
 * all nodes is proportional to the total amount of data sent.
 * shuffle: also called partitioned join. Both inputs are hash-partitioned on the join key, and each node then
 * performs the join on its local partitions.
 *
 * NOTICE:
 * for now, only MysqlScanNode and OlapScanNode have a cardinality. OlapScanNode's cardinality is calculated
 * from row count and data size, and MysqlScanNode's cardinality is always 1. Every other scan node's
 * cardinality is -1. So if a join query contains any other kind of scan node, the join cost cannot be
 * calculated normally: both "broadcastCost" and "partitionCost" stay 0, and the planner falls back to a
 * SHUFFLE join.
 */
public class JoinCostEvaluation {
    private static final Logger LOG = LogManager.getLogger(JoinCostEvaluation.class);

    private final long rhsTreeCardinality;
    private final float rhsTreeAvgRowSize;
    private final int rhsTreeTupleIdNum;
    private final long lhsTreeCardinality;
    private final float lhsTreeAvgRowSize;
    private final int lhsTreeNumNodes;
    private long broadcastCost = 0;
    private long partitionCost = 0;

    JoinCostEvaluation(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
        PlanNode rhsTree = rightChildFragment.getPlanRoot();
        rhsTreeCardinality = rhsTree.getCardinality();
        rhsTreeAvgRowSize = rhsTree.getAvgRowSize();
        rhsTreeTupleIdNum = rhsTree.getTupleIds().size();
        PlanNode lhsTree = leftChildFragment.getPlanRoot();
        lhsTreeCardinality = lhsTree.getCardinality();
        lhsTreeAvgRowSize = lhsTree.getAvgRowSize();
        lhsTreeNumNodes = leftChildFragment.getNumNodes();

        String nodeOverview = buildNodeOverview(node, rightChildFragment, leftChildFragment);
        broadcastCost(nodeOverview);
        shuffleCost(nodeOverview);
    }

    private String buildNodeOverview(PlanNode node, PlanFragment rightChildFragment, PlanFragment leftChildFragment) {
        return "root node id=" + node.getId().toString() + ": " + node.planNodeName
                + " right fragment id=" + rightChildFragment.getFragmentId().toString()
                + " left fragment id=" + leftChildFragment.getFragmentId().toString();
    }

    private void broadcastCost(String nodeOverview) {
        if (rhsTreeCardinality != -1 && lhsTreeNumNodes != -1) {
            broadcastCost = Math.round((double) rhsTreeCardinality * rhsTreeAvgRowSize) * lhsTreeNumNodes;
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("nodeOverview: {}", nodeOverview);
            LOG.debug("broadcast: cost={}", broadcastCost);
            LOG.debug("rhs card={} row_size={} lhs nodes={}",
                    rhsTreeCardinality, rhsTreeAvgRowSize, lhsTreeNumNodes);
        }
    }
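
    // A worked example of the broadcast cost above (the numbers are illustrative assumptions, not from the
    // source): with rhs cardinality = 10000, rhs avg row size = 100 bytes, and 3 nodes executing the left
    // fragment, broadcastCost = round(10000 * 100) * 3 = 3,000,000, since the whole rhs output is sent to
    // every lhs node.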

    /**
     * repartition: both left- and rightChildFragment are partitioned on the join exprs
     * TODO: take existing partition of input fragments into account to avoid unnecessary repartitioning
     */
    private void shuffleCost(String nodeOverview) {
        if (lhsTreeCardinality != -1 && rhsTreeCardinality != -1) {
            partitionCost = Math.round(
                    (double) lhsTreeCardinality * lhsTreeAvgRowSize + (double) rhsTreeCardinality * rhsTreeAvgRowSize);
        }
        if (LOG.isDebugEnabled()) {
            LOG.debug("nodeOverview: {}", nodeOverview);
            LOG.debug("partition: cost={} ", partitionCost);
            LOG.debug("lhs card={} row_size={}", lhsTreeCardinality, lhsTreeAvgRowSize);
            LOG.debug("rhs card={} row_size={}", rhsTreeCardinality, rhsTreeAvgRowSize);
        }
    }
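
    // Continuing the illustrative numbers from the broadcast example: with lhs cardinality = 1,000,000 and
    // lhs avg row size = 50 bytes, partitionCost = round(1000000 * 50 + 10000 * 100) = 51,000,000, because
    // both inputs cross the network once; with such a small rhs, broadcast (3,000,000) wins the comparison.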

    /**
     * When broadcastCost and partitionCost are equal, there is no uniform standard for which join implementation
     * is better: some scenarios suit a broadcast join, and others suit a shuffle join. Therefore, a
     * SessionVariable lets users choose the join implementation they prefer in case of a tie.
     */
    public boolean isBroadcastCostSmaller() {
        String joinMethod = ConnectContext.get().getSessionVariable().getPreferJoinMethod();
        if (joinMethod.equalsIgnoreCase("broadcast")) {
            return broadcastCost <= partitionCost;
        } else {
            return broadcastCost < partitionCost;
        }
    }
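
    // A minimal sketch of how a planner caller might combine the two estimates; the surrounding names
    // (joinNode, rightFragment, leftFragment, memLimit) are assumptions for illustration only:
    //
    //   JoinCostEvaluation eval = new JoinCostEvaluation(joinNode, rightFragment, leftFragment);
    //   boolean doBroadcast = eval.isBroadcastCostSmaller()
    //           && eval.constructHashTableSpace() <= memLimit;
    //
    // Note that when both costs are 0 (see the class-level NOTICE), isBroadcastCostSmaller() returns false
    // unless the session variable behind getPreferJoinMethod() is set to "broadcast".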

    /**
     * Estimate the memory cost of constructing the hash table for a broadcast join.
     * The memory used by the hash table = ((cardinality/0.75[1]) * 8[2])[3] + (cardinality * avgRowSize)[4]
     * + (nodeArrayLen[5] * 16[6])[7] + (nodeArrayLen * tupleNum[8] * 8[9])[10]. It consists of four parts:
     * 1) All bucket pointers. 2) The data of all nodes. 3) The overhead of all nodes. 4) The tuple pointers
     * of all nodes.
     * - [1] Expansion factor of the number of hash table buckets;
     * - [2] The pointer length of each hash table bucket;
     * - [3] bucketPointerSpace: the memory used by all bucket pointers of the hash table;
     * - [4] rhsDataSize: the memory used by all nodes of the hash table, equal to the amount of data from the
     *       right table that participates in building the hash table;
     * - [5] The length of the node array stored by the hash table, which is larger than the actual cardinality.
     *       The initial value is 4096; each time the array fills up, it grows by half of its current length.
     *       The array length after each increment therefore follows the sequence:
     *          4096 = pow(3/2, 0) * 4096,
     *          6144 = pow(3/2, 1) * 4096,
     *          9216 = pow(3/2, 2) * 4096,
     *          13824 = pow(3/2, 3) * 4096,
     *       and it must finally satisfy len(node array) > cardinality,
     *       so the number of increments = int((ln(cardinality/4096) / ln(3/2)) + 1),
     *       and finally len(node array) = pow(3/2, int((ln(cardinality/4096) / ln(3/2)) + 1)) * 4096
     * - [6] The overhead length of each hash table node, consisting of the next-node pointer, the hash value,
     *       and a bool variable;
     * - [7] nodeOverheadSpace: the memory used by the overhead of all nodes in the hash table;
     * - [8] The number of tuples participating in the build;
     * - [9] The length of each tuple pointer;
     * - [10] nodeTuplePointerSpace: the memory used by the tuple pointers of all nodes in the hash table;
     */
    public long constructHashTableSpace() {
        // [3] pointers of all hash table buckets, with a 0.75 fill factor
        double bucketPointerSpace = ((double) rhsTreeCardinality / 0.75) * 8;
        // [5] closed form of the node array growth sequence described above
        double nodeArrayLen =
                Math.pow(1.5, (int) ((Math.log((double) rhsTreeCardinality / 4096) / Math.log(1.5)) + 1)) * 4096;
        // [7] per-node overhead and [10] per-node tuple pointers
        double nodeOverheadSpace = nodeArrayLen * 16;
        double nodeTuplePointerSpace = nodeArrayLen * rhsTreeTupleIdNum * 8;
        // sum of the four parts, scaled by the planner's global hash table overhead factor
        return Math.round((bucketPointerSpace + (double) rhsTreeCardinality * rhsTreeAvgRowSize
                + nodeOverheadSpace + nodeTuplePointerSpace) * PlannerContext.HASH_TBL_SPACE_OVERHEAD);
    }
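
    // A worked example of the estimate above (the numbers are illustrative assumptions, not from the source):
    // for rhsTreeCardinality = 10000, rhsTreeAvgRowSize = 100, and rhsTreeTupleIdNum = 1:
    //   nodeArrayLen          = pow(1.5, (int) (ln(10000 / 4096.0) / ln(1.5) + 1)) * 4096
    //                         = pow(1.5, 3) * 4096 = 13824, the first length in the sequence above 10000
    //   bucketPointerSpace    = 10000 / 0.75 * 8  ~= 106,667
    //   rhsDataSize           = 10000 * 100        = 1,000,000
    //   nodeOverheadSpace     = 13824 * 16         = 221,184
    //   nodeTuplePointerSpace = 13824 * 1 * 8      = 110,592
    // for a total of roughly 1,438,443 bytes before PlannerContext.HASH_TBL_SPACE_OVERHEAD is applied.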
}