StatsErrorEstimator.java

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

package org.apache.doris.nereids.stats;

import org.apache.doris.common.Pair;
import org.apache.doris.common.profile.ProfileManager;
import org.apache.doris.common.util.DebugUtil;
import org.apache.doris.nereids.trees.plans.AbstractPlan;
import org.apache.doris.persist.gson.GsonUtils;
import org.apache.doris.planner.PlanNode;
import org.apache.doris.planner.PlanNodeId;
import org.apache.doris.statistics.Statistics;
import org.apache.doris.thrift.TReportExecStatusParams;
import org.apache.doris.thrift.TRuntimeProfileNode;
import org.apache.doris.thrift.TUniqueId;

import com.google.gson.annotations.SerializedName;

import java.util.HashMap;
import java.util.Map;
import java.util.Map.Entry;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Used to estimate the bias of stats estimation.
 */
public class StatsErrorEstimator {

    /**
     * Extracts the plan node id from a runtime profile node name ({@code ... id=<digits> ...}).
     * The leading word boundary already rules out matches inside {@code dst_id=}; the negative
     * lookahead is kept as an explicit guard. Compiled once because it runs for every profile node
     * of every execution report.
     */
    private static final Pattern PLAN_NODE_ID_PATTERN = Pattern.compile("\\b(?!dst_id=)id=(\\d+)\\b");

    /** Captures the digits of the first {@code (<digits>)} group in a counter string. */
    private static final Pattern ROWS_RETURNED_PATTERN = Pattern.compile("\\((\\d+)\\)");

    /**
     * Legacy {@link PlanNode} id -> (estimated row count, accumulated exact row count).
     * The serialized name is left unchanged; renaming it would change the JSON produced by {@link #toJson()}.
     */
    @SerializedName("legacyPlanIdToPhysicalPlan")
    private Map<Integer, Pair<Double, Double>> legacyPlanIdStats;

    /** Most recently computed q-error, refreshed on each execution report. */
    @SerializedName("qError")
    private double qError;

    public StatsErrorEstimator() {
        legacyPlanIdStats = new HashMap<>();
    }

    /**
     * Invoked by PhysicalPlanTranslator, put the translated plan node and corresponding physical plan to estimator.
     * Records the physical plan's estimated row count under the legacy plan node id, with an exact row count of 0
     * that is later accumulated from execution reports. Plans without statistics are ignored.
     */
    public void updateLegacyPlanIdToPhysicalPlan(PlanNode planNode, AbstractPlan physicalPlan) {
        Statistics statistics = physicalPlan.getStats();
        if (statistics == null) {
            return;
        }
        legacyPlanIdStats.put(planNode.getId().asInt(), Pair.of(statistics.getRowCount(), 0.0));
    }

    /**
     * Computes the q-error over all tracked plan nodes, where b is the estimated and b' the exact row count.
     * <pre>
     *     q = \max_{i=1}^{n} \max(\frac{b'_i}{b_i}, \frac{b_i}{b'_i})
     * </pre>
     * A zero denominator is replaced by 1 to avoid division by zero.
     *
     * @return the maximum per-node error, or {@link Double#NEGATIVE_INFINITY} when no plan node is tracked
     */
    public double calculateQError() {
        double qError = Double.NEGATIVE_INFINITY;
        for (Entry<Integer, Pair<Double, Double>> entry : legacyPlanIdStats.entrySet()) {
            double exactReturnedRows = entry.getValue().second;
            double estimateReturnedRows = entry.getValue().first;
            qError = Math.max(qError,
                    Math.max(exactReturnedRows / oneIfZero(estimateReturnedRows),
                            estimateReturnedRows / oneIfZero(exactReturnedRows)));
        }
        return qError;
    }

    /**
     * Update exact returned rows incrementally, since there may be many execution instances of plan fragment.
     * After accumulating, recomputes the q-error and publishes this estimator to the profile manager.
     */
    public void updateExactReturnedRows(TReportExecStatusParams tReportExecStatusParams) {
        TUniqueId tUniqueId = tReportExecStatusParams.query_id;
        // A report without a profile carries no row counts; still refresh the q-error and profile below.
        // NOTE(review): presumably `profile` is an optional thrift field — confirm.
        if (tReportExecStatusParams.profile != null && tReportExecStatusParams.profile.nodes != null) {
            for (TRuntimeProfileNode runtimeProfileNode : tReportExecStatusParams.profile.nodes) {
                accumulateExactReturnedRows(runtimeProfileNode);
            }
        }
        this.qError = calculateQError();
        updateProfile(tUniqueId);
    }

    public void updateProfile(TUniqueId tUniqueId) {
        ProfileManager.getInstance()
                .setStatsErrorEstimator(DebugUtil.printId(tUniqueId), this);
    }

    /** Adds one profile node's "RowsReturned" counters to the exact row count of its plan node, if tracked. */
    private void accumulateExactReturnedRows(TRuntimeProfileNode runtimeProfileNode) {
        int planId = extractPlanNodeIdFromName(runtimeProfileNode.name);
        if (planId == -1) {
            return;
        }
        Pair<Double, Double> pair = legacyPlanIdStats.get(planId);
        if (pair == null) {
            return;
        }
        pair.second = pair.second + sumRowsReturned(runtimeProfileNode);
    }

    /** Sums the values of all counters named "RowsReturned" on the node; 0 when it has no counters. */
    private static double sumRowsReturned(TRuntimeProfileNode runtimeProfileNode) {
        if (runtimeProfileNode.counters == null) {
            return 0.0;
        }
        return runtimeProfileNode.counters.stream()
                .filter(p -> "RowsReturned".equals(p.name))
                .mapToDouble(p -> (double) p.getValue())
                .sum();
    }

    /**
     * TODO: The execution report from BE doesn't have any schema, so we have to use regex to extract the plan node id.
     *
     * @return the plan node id, or -1 when the name is null, has no id, or the id does not fit in an int
     */
    private int extractPlanNodeIdFromName(String name) {
        if (name == null) {
            return -1;
        }
        Matcher m = PLAN_NODE_ID_PATTERN.matcher(name);
        if (!m.find()) {
            return -1;
        }
        try {
            return Integer.parseInt(m.group(1));
        } catch (NumberFormatException ignored) {
            // The digit run overflows an int, so it cannot be a tracked plan node id.
            return -1;
        }
    }

    /**
     * Parses the digits of the first {@code (<digits>)} group in a counter string.
     * Not referenced elsewhere in this class.
     *
     * @return the parsed value, or 0.0 when the input is null or has no such group
     */
    private Double extractRowsReturned(String rowsReturnedStr) {
        if (rowsReturnedStr == null) {
            return 0.0;
        }
        Matcher m = ROWS_RETURNED_PATTERN.matcher(rowsReturnedStr);
        if (!m.find()) {
            return 0.0;
        }
        return Double.parseDouble(m.group(1));
    }

    /** Returns 1.0 in place of 0.0 so the value can safely be used as a divisor. */
    private double oneIfZero(double d) {
        return d == 0.0 ? 1.0 : d;
    }

    public double getQError() {
        return qError;
    }

    public String toJson() {
        return GsonUtils.GSON.toJson(this);
    }

    // For test only. Adds d to the exact row count of an already-tracked plan node.
    public void setExactReturnedRow(PlanNodeId planNodeId, Double d) {
        legacyPlanIdStats.get(planNodeId.asInt()).second += d;
    }
}