Coverage Report

Created: 2026-04-15 19:01

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exec/sink/scale_writer_partitioning_exchanger.hpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <memory>
21
#include <vector>
22
23
#include "core/block/block.h"
24
#include "exec/connector/skewed_partition_rebalancer.h"
25
#include "exec/partitioner/partitioner.h"
26
27
namespace doris {
28
class ScaleWriterPartitioner final : public PartitionerBase {
29
public:
30
    ScaleWriterPartitioner(int channel_size, int partition_count, int task_count,
31
                           int task_bucket_count,
32
                           long min_partition_data_processed_rebalance_threshold,
33
                           long min_data_processed_rebalance_threshold)
34
0
            : PartitionerBase(partition_count),
35
0
              _channel_size(channel_size),
36
0
              _partition_rebalancer(partition_count, task_count, task_bucket_count,
37
0
                                    min_partition_data_processed_rebalance_threshold,
38
0
                                    min_data_processed_rebalance_threshold),
39
0
              _partition_row_counts(partition_count, 0),
40
0
              _partition_writer_ids(partition_count, -1),
41
0
              _partition_writer_indexes(partition_count, 0),
42
0
              _task_count(task_count),
43
0
              _task_bucket_count(task_bucket_count),
44
              _min_partition_data_processed_rebalance_threshold(
45
0
                      min_partition_data_processed_rebalance_threshold),
46
0
              _min_data_processed_rebalance_threshold(min_data_processed_rebalance_threshold) {
47
0
        _crc_partitioner =
48
0
                std::make_unique<Crc32HashPartitioner<ShuffleChannelIds>>(_partition_count);
49
0
    }
50
51
0
    ~ScaleWriterPartitioner() override = default;
52
53
0
    Status init(const std::vector<TExpr>& texprs) override {
54
0
        return _crc_partitioner->init(texprs);
55
0
    }
56
57
0
    Status prepare(RuntimeState* state, const RowDescriptor& row_desc) override {
58
0
        return _crc_partitioner->prepare(state, row_desc);
59
0
    }
60
61
0
    Status open(RuntimeState* state) override { return _crc_partitioner->open(state); }
62
63
0
    Status close(RuntimeState* state) override { return _crc_partitioner->close(state); }
64
65
0
    Status do_partitioning(RuntimeState* state, Block* block) const override {
66
0
        _hash_vals.resize(block->rows());
67
0
        for (int partition_id = 0; partition_id < _partition_row_counts.size(); partition_id++) {
68
0
            _partition_row_counts[partition_id] = 0;
69
0
            _partition_writer_ids[partition_id] = -1;
70
0
        }
71
72
0
        _partition_rebalancer.rebalance();
73
74
0
        RETURN_IF_ERROR(_crc_partitioner->do_partitioning(state, block));
75
0
        const auto& channel_ids = _crc_partitioner->get_channel_ids();
76
0
        for (size_t position = 0; position < block->rows(); position++) {
77
0
            auto partition_id = channel_ids[position];
78
0
            _partition_row_counts[partition_id] += 1;
79
80
            // Get writer id for this partition by looking at the scaling state
81
0
            int writer_id = _partition_writer_ids[partition_id];
82
0
            if (writer_id == -1) {
83
0
                writer_id = _get_next_writer_id(partition_id);
84
0
                _partition_writer_ids[partition_id] = writer_id;
85
0
            }
86
0
            _hash_vals[position] = writer_id;
87
0
        }
88
89
0
        for (int partition_id = 0; partition_id < _partition_row_counts.size(); partition_id++) {
90
0
            _partition_rebalancer.add_partition_row_count(partition_id,
91
0
                                                          _partition_row_counts[partition_id]);
92
0
        }
93
0
        _partition_rebalancer.add_data_processed(block->bytes());
94
95
0
        return Status::OK();
96
0
    }
97
98
0
    const std::vector<HashValType>& get_channel_ids() const override { return _hash_vals; }
99
100
0
    Status clone(RuntimeState* state, std::unique_ptr<PartitionerBase>& partitioner) override {
101
0
        partitioner = std::make_unique<ScaleWriterPartitioner>(
102
0
                _channel_size, (int)_partition_count, _task_count, _task_bucket_count,
103
0
                _min_partition_data_processed_rebalance_threshold,
104
0
                _min_data_processed_rebalance_threshold);
105
0
        return Status::OK();
106
0
    }
107
108
private:
109
0
    int _get_next_writer_id(HashValType partition_id) const {
110
0
        return _partition_rebalancer.get_task_id(partition_id,
111
0
                                                 _partition_writer_indexes[partition_id]++);
112
0
    }
113
114
    int _channel_size;
115
    std::unique_ptr<PartitionerBase> _crc_partitioner;
116
    mutable SkewedPartitionRebalancer _partition_rebalancer;
117
    mutable std::vector<int> _partition_row_counts;
118
    mutable std::vector<int> _partition_writer_ids;
119
    mutable std::vector<int> _partition_writer_indexes;
120
    mutable std::vector<HashValType> _hash_vals;
121
    const int _task_count;
122
    const int _task_bucket_count;
123
    const long _min_partition_data_processed_rebalance_threshold;
124
    const long _min_data_processed_rebalance_threshold;
125
};
126
} // namespace doris