Coverage Report

Created: 2026-06-17 07:44

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/util/reservoir_sampler.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/ReservoirSampler.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <pdqsort.h>
24
25
#include <cmath>
26
#include <cstddef>
27
#include <functional>
28
#include <limits>
29
30
#include "core/pod_array_fwd.h"
31
#include "core/string_buffer.hpp"
32
#include "core/types.h"
33
34
namespace doris {
35
36
/// Implementing the Reservoir Sampling algorithm. Incrementally selects from the added objects a random subset of the sample_count size.
37
/// Can approximately get quantiles.
38
/// A call to `quantileInterpolated` costs O(sample_count log sample_count) if `insert` was called at least once since the previous quantile call (the samples must be re-sorted); otherwise it is O(1).
39
/// That is, it makes sense to first add, then get quantiles without adding.
40
/*
41
 * single stream/sequence (oneseq)
42
 */
43
44
/// Primary template for the per-state-type multiplier constant.
/// It is deliberately empty: only the explicit specializations produced by
/// PCG_DEFINE_CONSTANT expose a static `multiplier()` member, so using an
/// unsupported state type fails to compile instead of picking a wrong constant.
template <typename T>
45
struct default_multiplier {
46
    // Not defined for an arbitrary type
47
};
48
49
/// Primary template for the per-state-type increment constant.
/// It is deliberately empty: only the explicit specializations produced by
/// PCG_DEFINE_CONSTANT expose a static `increment()` member.
template <typename T>
50
struct default_increment {
51
    // Not defined for an arbitrary type
52
};
53
54
// Emits the explicit specialization `what##_##kind<type>` (for example
// `default_multiplier<uint64_t>`) containing a single static constexpr member
// function named `kind` that returns `constant`.
#define PCG_DEFINE_CONSTANT(type, what, kind, constant) \
55
    template <>                                         \
56
    struct what##_##kind<type> {                        \
57
895k
        static constexpr type kind() {                  \
58
895k
            return constant;                            \
59
895k
        }                                               \
_ZN5doris18default_multiplierImE10multiplierEv
Line
Count
Source
57
895k
        static constexpr type kind() {                  \
58
895k
            return constant;                            \
59
895k
        }                                               \
Unexecuted instantiation: _ZN5doris18default_multiplierIhE10multiplierEv
Unexecuted instantiation: _ZN5doris17default_incrementIhE9incrementEv
Unexecuted instantiation: _ZN5doris18default_multiplierItE10multiplierEv
Unexecuted instantiation: _ZN5doris17default_incrementItE9incrementEv
Unexecuted instantiation: _ZN5doris18default_multiplierIjE10multiplierEv
Unexecuted instantiation: _ZN5doris17default_incrementIjE9incrementEv
Unexecuted instantiation: _ZN5doris17default_incrementImE9incrementEv
60
    };
61
62
// Multiplier / increment specializations for 8-, 16-, 32- and 64-bit state types.
// NOTE(review): presumably these are the reference constants from upstream
// pcg_random.hpp - verify against upstream before changing any value.
PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U)
63
PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U)
64
65
PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U)
66
PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U)
67
68
PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U)
69
PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U)
70
71
PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL)
72
PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL)
73
74
/// Stream mixin with one fixed sequence: the increment is inherited from
/// `default_increment<itype>` and cannot be chosen by the caller
/// (`can_specify_stream` is false, `is_mcg` is false).
template <typename itype>
75
class oneseq_stream : public default_increment<itype> {
76
protected:
77
    static constexpr bool is_mcg = false;
78
79
    // Is never called, but is provided for symmetry with specific_stream
    // (no `specific_stream` is defined in the visible part of this file).
    // Aborts the process if it is ever reached.
80
    void set_stream(...) { abort(); }
81
82
public:
83
    typedef itype state_type;
84
85
    // Stream identifier: the inherited increment shifted right by one bit.
    static constexpr itype stream() { return default_increment<itype>::increment() >> 1; }
86
87
    static constexpr bool can_specify_stream = false;
88
89
    static constexpr size_t streams_pow2() { return 0U; }
90
91
protected:
    // Only constructible as a base class.
92
    constexpr oneseq_stream() = default;
93
};
94
95
/*
96
 * no stream (mcg)
97
 */
98
99
/// Stream mixin without an additive term: `increment()` is always 0 and
/// `is_mcg` is true, so the engine's state update reduces to
/// `state * multiplier()`.
template <typename itype>
100
class no_stream {
101
protected:
102
    static constexpr bool is_mcg = true;
103
104
    // Is never called, but is provided for symmetry with specific_stream
    // (no `specific_stream` is defined in the visible part of this file).
    // Aborts the process if it is ever reached.
105
0
    void set_stream(...) { abort(); }
106
107
public:
108
    typedef itype state_type;
109
110
895k
    static constexpr itype increment() { return 0; }
111
112
    static constexpr bool can_specify_stream = false;
113
114
    static constexpr size_t streams_pow2() { return 0U; }
115
116
protected:
    // Only constructible as a base class.
117
    constexpr no_stream() = default;
118
};
119
120
template <typename xtype, typename itype, typename output_mixin, bool output_previous = true,
121
          typename stream_mixin = oneseq_stream<itype>,
122
          typename multiplier_mixin = default_multiplier<itype>>
123
class engine : protected output_mixin, public stream_mixin, protected multiplier_mixin {
124
protected:
125
    itype state_;
126
127
    struct can_specify_stream_tag {};
128
    struct no_specifiable_stream_tag {};
129
130
    using stream_mixin::increment;
131
    using multiplier_mixin::multiplier;
132
133
public:
134
    typedef xtype result_type;
135
    typedef itype state_type;
136
137
    static constexpr size_t period_pow2() {
138
        return sizeof(state_type) * 8 - 2 * stream_mixin::is_mcg;
139
    }
140
141
    // It would be nice to use std::numeric_limits for these, but
142
    // we can't be sure that it'd be defined for the 128-bit types.
143
144
    static constexpr result_type min() { return result_type(0UL); }
145
146
896k
    static constexpr result_type max() { return result_type(~result_type(0UL)); }
147
148
protected:
149
895k
    itype bump(itype state) { return state * multiplier() + increment(); }
150
151
0
    itype base_generate() { return state_ = bump(state_); }
152
153
896k
    itype base_generate0() {
154
896k
        itype old_state = state_;
155
896k
        state_ = bump(state_);
156
896k
        return old_state;
157
896k
    }
158
159
public:
160
896k
    result_type operator()() {
161
896k
        if (output_previous) {
162
896k
            return this->output(base_generate0());
163
18.4E
        } else {
164
18.4E
            return this->output(base_generate());
165
18.4E
        }
166
896k
    }
167
168
    result_type operator()(result_type upper_bound) { return bounded_rand(*this, upper_bound); }
169
170
    engine(itype state = itype(0xcafef00dd15ea5e5ULL))
171
99
            : state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
172
        // Nothing else to do.
173
99
    }
174
175
    // This function may or may not exist.  It thus has to be a template
176
    // to use SFINAE; users don't have to worry about its template-ness.
177
178
    template <typename sm = stream_mixin>
179
    engine(itype state, typename sm::stream_state stream_seed)
180
            : stream_mixin(stream_seed),
181
              state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
182
        // Nothing else to do.
183
    }
184
185
    template <typename SeedSeq>
186
    engine(SeedSeq&& seedSeq,
187
           typename std::enable_if<!stream_mixin::can_specify_stream &&
188
                                           !std::is_convertible<SeedSeq, itype>::value &&
189
                                           !std::is_convertible<SeedSeq, engine>::value,
190
                                   no_specifiable_stream_tag>::type = {})
191
            : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq))) {
192
        // Nothing else to do.
193
    }
194
195
    template <typename SeedSeq>
196
    engine(SeedSeq&& seedSeq,
197
           typename std::enable_if<stream_mixin::can_specify_stream &&
198
                                           !std::is_convertible<SeedSeq, itype>::value &&
199
                                           !std::is_convertible<SeedSeq, engine>::value,
200
                                   can_specify_stream_tag>::type = {})
201
            : engine(generate_one<itype, 1, 2>(seedSeq), generate_one<itype, 0, 2>(seedSeq)) {
202
        // Nothing else to do.
203
    }
204
205
    template <typename... Args>
206
51
    void seed(Args&&... args) {
207
51
        new (this) engine(std::forward<Args>(args)...);
208
51
    }
209
210
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
211
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
212
              typename multiplier_mixin1>
213
    friend std::basic_ostream<CharT, Traits>& operator<<(
214
            std::basic_ostream<CharT, Traits>& out,
215
            const engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
216
16
                         multiplier_mixin1>& rng) {
217
16
        auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left);
218
16
        auto space = out.widen(' ');
219
16
        auto orig_fill = out.fill();
220
221
16
        out << rng.multiplier() << space << rng.increment() << space << rng.state_;
222
223
16
        out.flags(orig_flags);
224
16
        out.fill(orig_fill);
225
16
        return out;
226
16
    }
227
228
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
229
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
230
              typename multiplier_mixin1>
231
    friend std::basic_istream<CharT, Traits>& operator>>(
232
            std::basic_istream<CharT, Traits>& in,
233
            engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
234
16
                   multiplier_mixin1>& rng) {
235
16
        auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws);
236
237
16
        itype multiplier, increment, state;
238
16
        in >> multiplier >> increment >> state;
239
240
16
        if (!in.fail()) {
241
16
            bool good = true;
242
16
            if (multiplier != rng.multiplier()) {
243
0
                good = false;
244
16
            } else if (rng.can_specify_stream) {
245
0
                rng.set_stream(increment >> 1);
246
16
            } else if (increment != rng.increment()) {
247
0
                good = false;
248
0
            }
249
16
            if (good) {
250
16
                rng.state_ = state;
251
16
            } else {
252
0
                in.clear(std::ios::failbit);
253
0
            }
254
16
        }
255
256
16
        in.flags(orig_flags);
257
16
        return in;
258
16
    }
259
};
260
261
// Integer type used for bit counts and shift amounts in the output mixin.
// Defaults to uint8_t; define PCG_BITCOUNT_T before this point to override it.
#ifndef PCG_BITCOUNT_T
262
typedef uint8_t bitcount_t;
263
#else
264
typedef PCG_BITCOUNT_T bitcount_t;
265
#endif
266
267
/// Output mixin: turns an itype state into an xtype result by xor-ing the
/// state with a right-shifted copy of itself and then shifting right by an
/// amount that depends on the top bits of the original state.
template <typename xtype, typename itype>
268
struct xsh_rs_mixin {
269
895k
    static xtype output(itype internal) {
270
895k
        constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8);
271
895k
        constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8);
272
895k
        // State bits that do not fit into the result.
        constexpr bitcount_t sparebits = bits - xtypebits;
273
895k
        // Number of top state bits used to select the variable shift (0..5),
        // chosen from how many spare bits are available.
        constexpr bitcount_t opbits = sparebits - 5 >= 64   ? 5
274
895k
                                      : sparebits - 4 >= 32 ? 4
275
895k
                                      : sparebits - 3 >= 16 ? 3
276
895k
                                      : sparebits - 2 >= 4  ? 2
277
895k
                                      : sparebits - 1 >= 1  ? 1
278
895k
                                                            : 0;
279
895k
        constexpr bitcount_t mask = (1 << opbits) - 1;
280
895k
        constexpr bitcount_t maxrandshift = mask;
281
895k
        constexpr bitcount_t topspare = opbits;
282
895k
        constexpr bitcount_t bottomspare = sparebits - topspare;
283
895k
        constexpr bitcount_t xshift = topspare + (xtypebits + maxrandshift) / 2;
284
18.4E
        // Variable shift: the top `opbits` bits of the state, read before the xor below.
        bitcount_t rshift = opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0;
285
895k
        internal ^= internal >> xshift;
286
895k
        // Fixed base shift plus the variable part, then truncation to xtype.
        auto result = xtype(internal >> (bottomspare - maxrandshift + rshift));
287
895k
        return result;
288
895k
    }
289
};
290
291
/// Convenience alias: an `engine` that uses `no_stream<itype>` (zero increment)
/// and the given output mixin template. `output_previous` defaults to true when
/// the state type is at most 8 bytes wide.
template <typename xtype, typename itype, template <typename XT, typename IT> class output_mixin,
292
          bool output_previous = (sizeof(itype) <= 8)>
293
using mcg_base =
294
        engine<xtype, itype, output_mixin<xtype, itype>, output_previous, no_stream<itype>>;
295
296
class ReservoirSampler {
297
public:
298
48
    explicit ReservoirSampler(size_t sample_count_ = 8192) : sample_count(sample_count_) {
299
48
        rng.seed(123456);
300
48
    }
301
302
913k
    void insert(const double v) {
303
913k
        if (std::isnan(v)) {
304
0
            return;
305
0
        }
306
307
913k
        sorted = false;
308
913k
        ++total_values;
309
913k
        if (samples.size() < sample_count) {
310
16.4k
            samples.push_back(v);
311
896k
        } else {
312
896k
            uint64_t rnd = gen_random(total_values);
313
896k
            if (rnd < sample_count) {
314
65.3k
                samples[rnd] = v;
315
65.3k
            }
316
896k
        }
317
913k
    }
318
319
123
    void insert_many(const double* values, size_t size) {
320
913k
        for (size_t i = 0; i < size; ++i) {
321
912k
            insert(values[i]);
322
912k
        }
323
123
    }
324
325
3
    void clear() {
326
3
        samples.clear();
327
3
        sorted = false;
328
3
        total_values = 0;
329
3
        rng.seed(123456);
330
3
    }
331
332
23
    double quantileInterpolated(double level) {
333
23
        if (samples.empty()) {
334
0
            return std::numeric_limits<double>::quiet_NaN();
335
0
        }
336
23
        sortIfNeeded();
337
338
23
        double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1)));
339
340
        /// To get the value of a fractional index, we linearly interpolate between neighboring values.
341
23
        auto left_index = static_cast<size_t>(index);
342
23
        size_t right_index = left_index + 1;
343
23
        if (right_index == samples.size()) {
344
1
            return samples[left_index];
345
1
        }
346
347
22
        double left_coef = right_index - index;
348
22
        double right_coef = index - left_index;
349
22
        return samples[left_index] * left_coef + samples[right_index] * right_coef;
350
23
    }
351
352
16
    void merge(const ReservoirSampler& b) {
353
16
        if (sample_count != b.sample_count) {
354
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
355
0
                                   "Cannot merge ReservoirSampler's with different sample_count");
356
0
        }
357
16
        sorted = false;
358
359
16
        if (b.total_values <= sample_count) {
360
75
            for (double sample : b.samples) {
361
75
                insert(sample);
362
75
            }
363
14
        } else if (total_values <= sample_count) {
364
1
            Array from = std::move(samples);
365
1
            samples.assign(b.samples.begin(), b.samples.end());
366
1
            total_values = b.total_values;
367
1
            for (double i : from) {
368
0
                insert(i);
369
0
            }
370
1
        } else {
371
            /// Replace every element in our reservoir to the b's reservoir
372
            /// with the probability of b.total_values / (a.total_values + b.total_values)
373
            /// Do it more roughly than true random sampling to save performance.
374
1
            total_values += b.total_values;
375
            /// Will replace every frequency'th element in a to element from b.
376
1
            double frequency = static_cast<double>(total_values) / b.total_values;
377
            /// When frequency is too low, replace just one random element with the corresponding probability.
378
1
            if (frequency * 2 >= sample_count) {
379
0
                uint64_t rnd = gen_random(static_cast<uint64_t>(frequency));
380
0
                if (rnd < sample_count) {
381
0
                    samples[rnd] = b.samples[rnd];
382
0
                }
383
1
            } else {
384
4.11k
                for (double i = 0; i < sample_count; i += frequency) {
385
4.11k
                    auto idx = static_cast<size_t>(i);
386
4.11k
                    samples[idx] = b.samples[idx];
387
4.11k
                }
388
1
            }
389
1
        }
390
16
    }
391
392
16
    void read(BufferReadable& buf) {
393
16
        buf.read_binary(sample_count);
394
16
        buf.read_binary(total_values);
395
396
16
        size_t size = std::min(total_values, sample_count);
397
16
        static constexpr size_t MAX_RESERVOIR_SIZE = 1024 * 1024 * 1024; // 1GB
398
16
        if (UNLIKELY(size > MAX_RESERVOIR_SIZE)) {
399
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Too large array size (maximum: {})",
400
0
                                   MAX_RESERVOIR_SIZE);
401
0
        }
402
403
16
        std::string rng_string;
404
16
        buf.read_binary(rng_string);
405
16
        std::stringstream rng_buf(rng_string);
406
16
        rng_buf >> rng;
407
408
16
        samples.resize(size);
409
16.4k
        for (double& sample : samples) {
410
16.4k
            buf.read_binary(sample);
411
16.4k
        }
412
413
16
        sorted = false;
414
16
    }
415
416
16
    void write(BufferWritable& buf) const {
417
16
        buf.write_binary(sample_count);
418
16
        buf.write_binary(total_values);
419
420
16
        std::stringstream rng_buf;
421
16
        rng_buf << rng;
422
16
        buf.write_binary(rng_buf.str());
423
424
16.4k
        for (size_t i = 0; i < std::min(sample_count, total_values); ++i) {
425
16.4k
            buf.write_binary(samples[i]);
426
16.4k
        }
427
16
    }
428
429
private:
430
    /// We allocate a little memory on the stack - to avoid allocations when there are many objects with a small number of elements.
431
    using Array = PODArrayWithStackMemory<double, 64>;
432
    using pcg32_fast = mcg_base<uint32_t, uint64_t, xsh_rs_mixin>;
433
434
    size_t sample_count;
435
    size_t total_values = 0;
436
    Array samples;
437
    pcg32_fast rng;
438
    bool sorted = false;
439
440
896k
    uint64_t gen_random(uint64_t limit) {
441
        /// With a large number of values, we will generate random numbers several times slower.
442
896k
        if (limit <= static_cast<uint64_t>(pcg32_fast::max())) {
443
896k
            return rng() % limit;
444
896k
        } else {
445
396
            return (static_cast<uint64_t>(rng()) *
446
396
                            (static_cast<uint64_t>(pcg32_fast::max()) + 1ULL) +
447
396
                    static_cast<uint64_t>(rng())) %
448
396
                   limit;
449
396
        }
450
896k
    }
451
452
23
    void sortIfNeeded() {
453
23
        if (sorted) {
454
13
            return;
455
13
        }
456
10
        sorted = true;
457
10
        pdqsort(samples.begin(), samples.end(), std::less<double>());
458
10
    }
459
};
460
461
} // namespace doris