Coverage Report

Created: 2026-03-13 03:47

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/util/reservoir_sampler.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/ReservoirSampler.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <cmath>
24
#include <cstddef>
25
#include <functional>
26
#include <limits>
27
28
#include "core/pod_array_fwd.h"
29
#include "core/string_buffer.hpp"
30
#include "core/types.h"
31
32
namespace doris {
33
34
/// Implements the Reservoir Sampling algorithm: incrementally selects a random subset of at most sample_count elements from the values added so far.
35
/// Can approximately get quantiles.
36
/// A call to `quantileInterpolated` takes O(sample_count log sample_count) if `insert` was called at least once since the previous quantile call; otherwise it takes O(1).
37
/// That is, it makes sense to first add, then get quantiles without adding.
38
/*
39
 * single stream/sequence (oneseq)
40
 */
41
42
/// Primary template for the LCG multiplier constant of a state type.
/// Deliberately empty: only the integer widths specialised below provide `multiplier()`.
template <typename T>
struct default_multiplier {
    // Not defined for an arbitrary type
};

/// Primary template for the LCG increment constant of a state type.
/// Deliberately empty: only the integer widths specialised below provide `increment()`.
template <typename T>
struct default_increment {
    // Not defined for an arbitrary type
};

/// Emits the full specialisation `what##_##kind<type>` exposing a single
/// `static constexpr type kind()` that returns `constant`.
/// E.g. (uint8_t, default, multiplier, 141U) defines
/// `default_multiplier<uint8_t>::multiplier()`.
/// The expansion already ends with `;`, so invocations carry no trailing semicolon.
#define PCG_DEFINE_CONSTANT(type, what, kind, constant) \
    template <>                                         \
    struct what##_##kind<type> {                        \
        static constexpr type kind() {                  \
            return constant;                            \
        }                                               \
    };

// One multiplier/increment pair per supported state width.
PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U)
PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U)

PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U)
PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U)

PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U)
PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U)

PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL)
PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL)
71
72
template <typename itype>
73
class oneseq_stream : public default_increment<itype> {
74
protected:
75
    static constexpr bool is_mcg = false;
76
77
    // Is never called, but is provided for symmetry with specific_stream
78
    void set_stream(...) { abort(); }
79
80
public:
81
    typedef itype state_type;
82
83
    static constexpr itype stream() { return default_increment<itype>::increment() >> 1; }
84
85
    static constexpr bool can_specify_stream = false;
86
87
    static constexpr size_t streams_pow2() { return 0U; }
88
89
protected:
90
    constexpr oneseq_stream() = default;
91
};
92
93
/*
94
 * no stream (mcg)
95
 */
96
97
/// Stream policy for a multiplicative congruential generator (MCG): the additive
/// term of the state transition is always zero and no stream can be selected.
template <typename itype>
class no_stream {
protected:
    static constexpr bool is_mcg = true;

    // Is never called, but is provided for symmetry with specific_stream
    void set_stream(...) { abort(); }

public:
    typedef itype state_type;

    /// Additive constant of the state transition; always 0 for this policy.
    static constexpr itype increment() { return 0; }

    static constexpr bool can_specify_stream = false;

    /// log2 of the number of selectable streams; 0 because there is exactly one.
    static constexpr size_t streams_pow2() { return 0U; }

protected:
    // Only usable as a base class.
    constexpr no_stream() = default;
};
117
118
template <typename xtype, typename itype, typename output_mixin, bool output_previous = true,
119
          typename stream_mixin = oneseq_stream<itype>,
120
          typename multiplier_mixin = default_multiplier<itype>>
121
class engine : protected output_mixin, public stream_mixin, protected multiplier_mixin {
122
protected:
123
    itype state_;
124
125
    struct can_specify_stream_tag {};
126
    struct no_specifiable_stream_tag {};
127
128
    using stream_mixin::increment;
129
    using multiplier_mixin::multiplier;
130
131
public:
132
    typedef xtype result_type;
133
    typedef itype state_type;
134
135
    static constexpr size_t period_pow2() {
136
        return sizeof(state_type) * 8 - 2 * stream_mixin::is_mcg;
137
    }
138
139
    // It would be nice to use std::numeric_limits for these, but
140
    // we can't be sure that it'd be defined for the 128-bit types.
141
142
    static constexpr result_type min() { return result_type(0UL); }
143
144
756k
    static constexpr result_type max() { return result_type(~result_type(0UL)); }
145
146
protected:
147
725k
    itype bump(itype state) { return state * multiplier() + increment(); }
148
149
0
    itype base_generate() { return state_ = bump(state_); }
150
151
730k
    itype base_generate0() {
152
730k
        itype old_state = state_;
153
730k
        state_ = bump(state_);
154
730k
        return old_state;
155
730k
    }
156
157
public:
158
732k
    result_type operator()() {
159
733k
        if (output_previous) {
160
733k
            return this->output(base_generate0());
161
18.4E
        } else {
162
18.4E
            return this->output(base_generate());
163
18.4E
        }
164
732k
    }
165
166
    result_type operator()(result_type upper_bound) { return bounded_rand(*this, upper_bound); }
167
168
    engine(itype state = itype(0xcafef00dd15ea5e5ULL))
169
169
            : state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
170
        // Nothing else to do.
171
169
    }
172
173
    // This function may or may not exist.  It thus has to be a template
174
    // to use SFINAE; users don't have to worry about its template-ness.
175
176
    template <typename sm = stream_mixin>
177
    engine(itype state, typename sm::stream_state stream_seed)
178
            : stream_mixin(stream_seed),
179
              state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
180
        // Nothing else to do.
181
    }
182
183
    template <typename SeedSeq>
184
    engine(SeedSeq&& seedSeq,
185
           typename std::enable_if<!stream_mixin::can_specify_stream &&
186
                                           !std::is_convertible<SeedSeq, itype>::value &&
187
                                           !std::is_convertible<SeedSeq, engine>::value,
188
                                   no_specifiable_stream_tag>::type = {})
189
            : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq))) {
190
        // Nothing else to do.
191
    }
192
193
    template <typename SeedSeq>
194
    engine(SeedSeq&& seedSeq,
195
           typename std::enable_if<stream_mixin::can_specify_stream &&
196
                                           !std::is_convertible<SeedSeq, itype>::value &&
197
                                           !std::is_convertible<SeedSeq, engine>::value,
198
                                   can_specify_stream_tag>::type = {})
199
            : engine(generate_one<itype, 1, 2>(seedSeq), generate_one<itype, 0, 2>(seedSeq)) {
200
        // Nothing else to do.
201
    }
202
203
    template <typename... Args>
204
86
    void seed(Args&&... args) {
205
86
        new (this) engine(std::forward<Args>(args)...);
206
86
    }
207
208
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
209
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
210
              typename multiplier_mixin1>
211
    friend std::basic_ostream<CharT, Traits>& operator<<(
212
            std::basic_ostream<CharT, Traits>& out,
213
            const engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
214
29
                         multiplier_mixin1>& rng) {
215
29
        auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left);
216
29
        auto space = out.widen(' ');
217
29
        auto orig_fill = out.fill();
218
219
29
        out << rng.multiplier() << space << rng.increment() << space << rng.state_;
220
221
29
        out.flags(orig_flags);
222
29
        out.fill(orig_fill);
223
29
        return out;
224
29
    }
225
226
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
227
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
228
              typename multiplier_mixin1>
229
    friend std::basic_istream<CharT, Traits>& operator>>(
230
            std::basic_istream<CharT, Traits>& in,
231
            engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
232
29
                   multiplier_mixin1>& rng) {
233
29
        auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws);
234
235
29
        itype multiplier, increment, state;
236
29
        in >> multiplier >> increment >> state;
237
238
29
        if (!in.fail()) {
239
29
            bool good = true;
240
29
            if (multiplier != rng.multiplier()) {
241
0
                good = false;
242
29
            } else if (rng.can_specify_stream) {
243
0
                rng.set_stream(increment >> 1);
244
29
            } else if (increment != rng.increment()) {
245
0
                good = false;
246
0
            }
247
29
            if (good) {
248
29
                rng.state_ = state;
249
29
            } else {
250
0
                in.clear(std::ios::failbit);
251
0
            }
252
29
        }
253
254
29
        in.flags(orig_flags);
255
29
        return in;
256
29
    }
257
};
258
259
// Integer type used for bit counts and shift amounts; overridable via PCG_BITCOUNT_T.
#ifndef PCG_BITCOUNT_T
typedef uint8_t bitcount_t;
#else
typedef PCG_BITCOUNT_T bitcount_t;
#endif

/// Output function "xorshift high bits, then random shift": the top `opbits` bits of
/// the state choose how far the xorshifted state is shifted right before it is
/// truncated to `xtype`.
template <typename xtype, typename itype>
struct xsh_rs_mixin {
    /// Maps an internal state value to an output value.
    static xtype output(itype internal) {
        constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
        constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
        // State bits that are not needed to fill the result.
        constexpr bitcount_t sparebits = bits - xtypebits;
        // Number of top bits used to select the random shift, chosen from how many spare
        // bits exist. The subtractions are done in int, so a small `sparebits` simply
        // fails the comparison.
        constexpr bitcount_t opbits = sparebits - 5 >= 64   ? 5
                                      : sparebits - 4 >= 32 ? 4
                                      : sparebits - 3 >= 16 ? 3
                                      : sparebits - 2 >= 4  ? 2
                                      : sparebits - 1 >= 1  ? 1
                                                            : 0;
        constexpr bitcount_t mask = (1 << opbits) - 1;
        constexpr bitcount_t maxrandshift = mask;
        constexpr bitcount_t topspare = opbits;
        constexpr bitcount_t bottomspare = sparebits - topspare;
        constexpr bitcount_t xshift = topspare + (xtypebits + maxrandshift) / 2;
        // Random shift in [0, mask], read from the highest `opbits` bits of the state.
        bitcount_t rshift = opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
        internal ^= internal >> xshift;
        auto result = xtype(internal >> (bottomspare - maxrandshift + rshift));
        return result;
    }
};
288
289
template <typename xtype, typename itype, template <typename XT, typename IT> class output_mixin,
290
          bool output_previous = (sizeof(itype) <= 8)>
291
using mcg_base =
292
        engine<xtype, itype, output_mixin<xtype, itype>, output_previous, no_stream<itype>>;
293
294
class ReservoirSampler {
295
public:
296
83
    explicit ReservoirSampler(size_t sample_count_ = 8192) : sample_count(sample_count_) {
297
83
        rng.seed(123456);
298
83
    }
299
300
815k
    void insert(const double v) {
301
815k
        if (std::isnan(v)) {
302
0
            return;
303
0
        }
304
305
815k
        sorted = false;
306
815k
        ++total_values;
307
815k
        if (samples.size() < sample_count) {
308
31.8k
            samples.push_back(v);
309
783k
        } else {
310
783k
            uint64_t rnd = gen_random(total_values);
311
783k
            if (rnd < sample_count) {
312
118k
                samples[rnd] = v;
313
118k
            }
314
783k
        }
315
815k
    }
316
317
3
    void clear() {
318
3
        samples.clear();
319
3
        sorted = false;
320
3
        total_values = 0;
321
3
        rng.seed(123456);
322
3
    }
323
324
23
    double quantileInterpolated(double level) {
325
23
        if (samples.empty()) {
326
0
            return std::numeric_limits<double>::quiet_NaN();
327
0
        }
328
23
        sortIfNeeded();
329
330
23
        double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1)));
331
332
        /// To get the value of a fractional index, we linearly interpolate between neighboring values.
333
23
        auto left_index = static_cast<size_t>(index);
334
23
        size_t right_index = left_index + 1;
335
23
        if (right_index == samples.size()) {
336
1
            return samples[left_index];
337
1
        }
338
339
22
        double left_coef = right_index - index;
340
22
        double right_coef = index - left_index;
341
22
        return samples[left_index] * left_coef + samples[right_index] * right_coef;
342
23
    }
343
344
29
    void merge(const ReservoirSampler& b) {
345
29
        if (sample_count != b.sample_count) {
346
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
347
0
                                   "Cannot merge ReservoirSampler's with different sample_count");
348
0
        }
349
29
        sorted = false;
350
351
29
        if (b.total_values <= sample_count) {
352
75
            for (double sample : b.samples) {
353
75
                insert(sample);
354
75
            }
355
24
        } else if (total_values <= sample_count) {
356
1
            Array from = std::move(samples);
357
1
            samples.assign(b.samples.begin(), b.samples.end());
358
1
            total_values = b.total_values;
359
1
            for (double i : from) {
360
0
                insert(i);
361
0
            }
362
4
        } else {
363
            /// Replace every element in our reservoir to the b's reservoir
364
            /// with the probability of b.total_values / (a.total_values + b.total_values)
365
            /// Do it more roughly than true random sampling to save performance.
366
4
            total_values += b.total_values;
367
            /// Will replace every frequency'th element in a to element from b.
368
4
            double frequency = static_cast<double>(total_values) / b.total_values;
369
            /// When frequency is too low, replace just one random element with the corresponding probability.
370
4
            if (frequency * 2 >= sample_count) {
371
0
                uint64_t rnd = gen_random(static_cast<uint64_t>(frequency));
372
0
                if (rnd < sample_count) {
373
0
                    samples[rnd] = b.samples[rnd];
374
0
                }
375
4
            } else {
376
10.4k
                for (double i = 0; i < sample_count; i += frequency) {
377
10.4k
                    auto idx = static_cast<size_t>(i);
378
10.4k
                    samples[idx] = b.samples[idx];
379
10.4k
                }
380
4
            }
381
4
        }
382
29
    }
383
384
29
    void read(BufferReadable& buf) {
385
29
        buf.read_binary(sample_count);
386
29
        buf.read_binary(total_values);
387
388
29
        size_t size = std::min(total_values, sample_count);
389
29
        static constexpr size_t MAX_RESERVOIR_SIZE = 1024 * 1024 * 1024; // 1GB
390
29
        if (UNLIKELY(size > MAX_RESERVOIR_SIZE)) {
391
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Too large array size (maximum: {})",
392
0
                                   MAX_RESERVOIR_SIZE);
393
0
        }
394
395
29
        std::string rng_string;
396
29
        buf.read_binary(rng_string);
397
29
        std::stringstream rng_buf(rng_string);
398
29
        rng_buf >> rng;
399
400
29
        samples.resize(size);
401
41.0k
        for (double& sample : samples) {
402
41.0k
            buf.read_binary(sample);
403
41.0k
        }
404
405
29
        sorted = false;
406
29
    }
407
408
29
    void write(BufferWritable& buf) const {
409
29
        buf.write_binary(sample_count);
410
29
        buf.write_binary(total_values);
411
412
29
        std::stringstream rng_buf;
413
29
        rng_buf << rng;
414
29
        buf.write_binary(rng_buf.str());
415
416
41.0k
        for (size_t i = 0; i < std::min(sample_count, total_values); ++i) {
417
41.0k
            buf.write_binary(samples[i]);
418
41.0k
        }
419
29
    }
420
421
private:
422
    /// We allocate a little memory on the stack - to avoid allocations when there are many objects with a small number of elements.
423
    using Array = PODArrayWithStackMemory<double, 64>;
424
    using pcg32_fast = mcg_base<uint32_t, uint64_t, xsh_rs_mixin>;
425
426
    size_t sample_count;
427
    size_t total_values = 0;
428
    Array samples;
429
    pcg32_fast rng;
430
    bool sorted = false;
431
432
775k
    uint64_t gen_random(uint64_t limit) {
433
        /// With a large number of values, we will generate random numbers several times slower.
434
775k
        if (limit <= static_cast<uint64_t>(pcg32_fast::max())) {
435
749k
            return rng() % limit;
436
749k
        } else {
437
26.4k
            return (static_cast<uint64_t>(rng()) *
438
26.4k
                            (static_cast<uint64_t>(pcg32_fast::max()) + 1ULL) +
439
26.4k
                    static_cast<uint64_t>(rng())) %
440
26.4k
                   limit;
441
26.4k
        }
442
775k
    }
443
444
23
    void sortIfNeeded() {
445
23
        if (sorted) {
446
13
            return;
447
13
        }
448
10
        sorted = true;
449
10
        std::sort(samples.begin(), samples.end(), std::less<double>());
450
10
    }
451
};
452
453
} // namespace doris