Coverage Report

Created: 2025-09-17 00:25

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
/root/doris/be/src/util/reservoir_sampler.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
// This file is copied from
18
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/ReservoirSampler.h
19
// and modified by Doris
20
21
#pragma once
22
23
#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <ios>
#include <istream>
#include <limits>
#include <new>
#include <ostream>
#include <sstream>
#include <string>
#include <type_traits>
#include <utility>

#include "vec/common/pod_array_fwd.h"
#include "vec/common/string_buffer.hpp"
#include "vec/core/types.h"
31
32
namespace doris {
33
34
/// Implements the Reservoir Sampling algorithm: incrementally selects a random subset of at most
/// sample_count elements from the inserted values, which can be used to compute approximate quantiles.
/// A call to `quantileInterpolated` takes O(sample_count log sample_count) if at least one `insert`
/// (or `merge`/`read`) happened since the previous call; otherwise it takes O(1).
/// That is, it makes sense to insert all values first, and then query quantiles without inserting.
38
/*
39
 * single stream/sequence (oneseq)
40
 */
41
42
/// Default LCG multiplier for a PCG state type.
/// The primary template is intentionally empty: only the unsigned integer
/// specializations generated below provide `multiplier()`, so an unsupported
/// state type fails to compile instead of silently getting a bad constant.
template <typename T>
struct default_multiplier {
    // Deliberately empty: there is no multiplier for an arbitrary type.
};

/// Default LCG increment for a PCG state type (same scheme as default_multiplier).
template <typename T>
struct default_increment {
    // Deliberately empty: there is no increment for an arbitrary type.
};

/// Defines the explicit specialization `what_kind<type>` with a
/// `static constexpr type kind()` returning `constant`.
/// Example: PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U) defines
/// default_increment<uint8_t>::increment() == 77.
#define PCG_DEFINE_CONSTANT(type, what, kind, constant)   \
    template <>                                           \
    struct what##_##kind<type> {                          \
        static constexpr type kind() { return constant; } \
    };

// Multipliers, one per supported state width.
PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U)
PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U)
PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U)
PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL)

// Increments, one per supported state width.
PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U)
PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U)
PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U)
PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL)
71
72
template <typename itype>
73
class oneseq_stream : public default_increment<itype> {
74
protected:
75
    static constexpr bool is_mcg = false;
76
77
    // Is never called, but is provided for symmetry with specific_stream
78
    void set_stream(...) { abort(); }
79
80
public:
81
    typedef itype state_type;
82
83
    static constexpr itype stream() { return default_increment<itype>::increment() >> 1; }
84
85
    static constexpr bool can_specify_stream = false;
86
87
    static constexpr size_t streams_pow2() { return 0U; }
88
89
protected:
90
    constexpr oneseq_stream() = default;
91
};
92
93
/*
94
 * no stream (mcg)
95
 */
96
97
/// Stream mixin for a multiplicative congruential generator (MCG):
/// there is no additive term, so `increment()` is always zero.
template <typename itype>
class no_stream {
public:
    using state_type = itype;

    static constexpr itype increment() { return 0; }

    static constexpr bool can_specify_stream = false;

    static constexpr size_t streams_pow2() { return 0U; }

protected:
    static constexpr bool is_mcg = true;

    constexpr no_stream() = default;

    // Never invoked; present only so this mixin offers the same interface as
    // mixins whose stream is selectable (e.g. a specific_stream).
    void set_stream(...) { abort(); }
};
117
118
template <typename xtype, typename itype, typename output_mixin, bool output_previous = true,
119
          typename stream_mixin = oneseq_stream<itype>,
120
          typename multiplier_mixin = default_multiplier<itype>>
121
class engine : protected output_mixin, public stream_mixin, protected multiplier_mixin {
122
protected:
123
    itype state_;
124
125
    struct can_specify_stream_tag {};
126
    struct no_specifiable_stream_tag {};
127
128
    using stream_mixin::increment;
129
    using multiplier_mixin::multiplier;
130
131
public:
132
    typedef xtype result_type;
133
    typedef itype state_type;
134
135
    static constexpr size_t period_pow2() {
136
        return sizeof(state_type) * 8 - 2 * stream_mixin::is_mcg;
137
    }
138
139
    // It would be nice to use std::numeric_limits for these, but
140
    // we can't be sure that it'd be defined for the 128-bit types.
141
142
    static constexpr result_type min() { return result_type(0UL); }
143
144
0
    static constexpr result_type max() { return result_type(~result_type(0UL)); }
145
146
protected:
147
0
    itype bump(itype state) { return state * multiplier() + increment(); }
148
149
0
    itype base_generate() { return state_ = bump(state_); }
150
151
0
    itype base_generate0() {
152
0
        itype old_state = state_;
153
0
        state_ = bump(state_);
154
0
        return old_state;
155
0
    }
156
157
public:
158
0
    result_type operator()() {
159
0
        if (output_previous) {
160
0
            return this->output(base_generate0());
161
0
        } else {
162
0
            return this->output(base_generate());
163
0
        }
164
0
    }
165
166
    result_type operator()(result_type upper_bound) { return bounded_rand(*this, upper_bound); }
167
168
    engine(itype state = itype(0xcafef00dd15ea5e5ULL))
169
0
            : state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
170
        // Nothing else to do.
171
0
    }
172
173
    // This function may or may not exist.  It thus has to be a template
174
    // to use SFINAE; users don't have to worry about its template-ness.
175
176
    template <typename sm = stream_mixin>
177
    engine(itype state, typename sm::stream_state stream_seed)
178
            : stream_mixin(stream_seed),
179
              state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) {
180
        // Nothing else to do.
181
    }
182
183
    template <typename SeedSeq>
184
    engine(SeedSeq&& seedSeq,
185
           typename std::enable_if<!stream_mixin::can_specify_stream &&
186
                                           !std::is_convertible<SeedSeq, itype>::value &&
187
                                           !std::is_convertible<SeedSeq, engine>::value,
188
                                   no_specifiable_stream_tag>::type = {})
189
            : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq))) {
190
        // Nothing else to do.
191
    }
192
193
    template <typename SeedSeq>
194
    engine(SeedSeq&& seedSeq,
195
           typename std::enable_if<stream_mixin::can_specify_stream &&
196
                                           !std::is_convertible<SeedSeq, itype>::value &&
197
                                           !std::is_convertible<SeedSeq, engine>::value,
198
                                   can_specify_stream_tag>::type = {})
199
            : engine(generate_one<itype, 1, 2>(seedSeq), generate_one<itype, 0, 2>(seedSeq)) {
200
        // Nothing else to do.
201
    }
202
203
    template <typename... Args>
204
0
    void seed(Args&&... args) {
205
0
        new (this) engine(std::forward<Args>(args)...);
206
0
    }
207
208
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
209
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
210
              typename multiplier_mixin1>
211
    friend std::basic_ostream<CharT, Traits>& operator<<(
212
            std::basic_ostream<CharT, Traits>& out,
213
            const engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
214
0
                         multiplier_mixin1>& rng) {
215
0
        auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left);
216
0
        auto space = out.widen(' ');
217
0
        auto orig_fill = out.fill();
218
219
0
        out << rng.multiplier() << space << rng.increment() << space << rng.state_;
220
221
0
        out.flags(orig_flags);
222
0
        out.fill(orig_fill);
223
0
        return out;
224
0
    }
225
226
    template <typename CharT, typename Traits, typename xtype1, typename itype1,
227
              typename output_mixin1, bool output_previous1, typename stream_mixin1,
228
              typename multiplier_mixin1>
229
    friend std::basic_istream<CharT, Traits>& operator>>(
230
            std::basic_istream<CharT, Traits>& in,
231
            engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1,
232
0
                   multiplier_mixin1>& rng) {
233
0
        auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws);
234
235
0
        itype multiplier, increment, state;
236
0
        in >> multiplier >> increment >> state;
237
238
0
        if (!in.fail()) {
239
0
            bool good = true;
240
0
            if (multiplier != rng.multiplier()) {
241
0
                good = false;
242
0
            } else if (rng.can_specify_stream) {
243
0
                rng.set_stream(increment >> 1);
244
0
            } else if (increment != rng.increment()) {
245
0
                good = false;
246
0
            }
247
0
            if (good) {
248
0
                rng.state_ = state;
249
0
            } else {
250
0
                in.clear(std::ios::failbit);
251
0
            }
252
0
        }
253
254
0
        in.flags(orig_flags);
255
0
        return in;
256
0
    }
257
};
258
259
// Integer type used for bit counts and shift amounts; may be overridden by
// defining PCG_BITCOUNT_T before including this header.
#ifdef PCG_BITCOUNT_T
using bitcount_t = PCG_BITCOUNT_T;
#else
using bitcount_t = uint8_t;
#endif

/// PCG "XSH RS" output permutation: xorshift the high bits, then apply a
/// random right shift chosen by the top `opbits` bits of the state, and
/// truncate the result to `xtype`.
template <typename xtype, typename itype>
struct xsh_rs_mixin {
    static xtype output(itype internal) {
        constexpr bitcount_t state_bits = bitcount_t(sizeof(itype) * 8);
        constexpr bitcount_t out_bits = bitcount_t(sizeof(xtype) * 8);
        constexpr bitcount_t spare = state_bits - out_bits;

        // How many top state bits select the random shift; grows with the spare width.
        constexpr bitcount_t opbits = spare - 5 >= 64   ? 5
                                      : spare - 4 >= 32 ? 4
                                      : spare - 3 >= 16 ? 3
                                      : spare - 2 >= 4  ? 2
                                      : spare - 1 >= 1  ? 1
                                                        : 0;

        // The mask value doubles as the largest possible random shift.
        constexpr bitcount_t shift_mask = (1 << opbits) - 1;
        constexpr bitcount_t bottom_spare = spare - opbits;
        constexpr bitcount_t xor_shift = opbits + (out_bits + shift_mask) / 2;

        // The random shift is taken from the state *before* the xorshift.
        bitcount_t random_shift = 0;
        if (opbits) {
            random_shift = bitcount_t(internal >> (state_bits - opbits)) & shift_mask;
        }

        internal ^= internal >> xor_shift;
        return xtype(internal >> (bottom_spare - shift_mask + random_shift));
    }
};
288
289
template <typename xtype, typename itype, template <typename XT, typename IT> class output_mixin,
290
          bool output_previous = (sizeof(itype) <= 8)>
291
using mcg_base =
292
        engine<xtype, itype, output_mixin<xtype, itype>, output_previous, no_stream<itype>>;
293
294
class ReservoirSampler {
295
public:
296
0
    explicit ReservoirSampler(size_t sample_count_ = 8192) : sample_count(sample_count_) {
297
0
        rng.seed(123456);
298
0
    }
299
300
0
    void insert(const double v) {
301
0
        if (std::isnan(v)) {
302
0
            return;
303
0
        }
304
305
0
        sorted = false;
306
0
        ++total_values;
307
0
        if (samples.size() < sample_count) {
308
0
            samples.push_back(v);
309
0
        } else {
310
0
            uint64_t rnd = gen_random(total_values);
311
0
            if (rnd < sample_count) {
312
0
                samples[rnd] = v;
313
0
            }
314
0
        }
315
0
    }
316
317
0
    void clear() {
318
0
        samples.clear();
319
0
        sorted = false;
320
0
        total_values = 0;
321
0
        rng.seed(123456);
322
0
    }
323
324
0
    double quantileInterpolated(double level) {
325
0
        if (samples.empty()) {
326
0
            return std::numeric_limits<double>::quiet_NaN();
327
0
        }
328
0
        sortIfNeeded();
329
330
0
        double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1)));
331
332
        /// To get the value of a fractional index, we linearly interpolate between neighboring values.
333
0
        auto left_index = static_cast<size_t>(index);
334
0
        size_t right_index = left_index + 1;
335
0
        if (right_index == samples.size()) {
336
0
            return samples[left_index];
337
0
        }
338
339
0
        double left_coef = right_index - index;
340
0
        double right_coef = index - left_index;
341
0
        return samples[left_index] * left_coef + samples[right_index] * right_coef;
342
0
    }
343
344
0
    void merge(const ReservoirSampler& b) {
345
0
        if (sample_count != b.sample_count) {
346
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR,
347
0
                                   "Cannot merge ReservoirSampler's with different sample_count");
348
0
        }
349
0
        sorted = false;
350
351
0
        if (b.total_values <= sample_count) {
352
0
            for (double sample : b.samples) {
353
0
                insert(sample);
354
0
            }
355
0
        } else if (total_values <= sample_count) {
356
0
            Array from = std::move(samples);
357
0
            samples.assign(b.samples.begin(), b.samples.end());
358
0
            total_values = b.total_values;
359
0
            for (double i : from) {
360
0
                insert(i);
361
0
            }
362
0
        } else {
363
            /// Replace every element in our reservoir to the b's reservoir
364
            /// with the probability of b.total_values / (a.total_values + b.total_values)
365
            /// Do it more roughly than true random sampling to save performance.
366
0
            total_values += b.total_values;
367
            /// Will replace every frequency'th element in a to element from b.
368
0
            double frequency = static_cast<double>(total_values) / b.total_values;
369
            /// When frequency is too low, replace just one random element with the corresponding probability.
370
0
            if (frequency * 2 >= sample_count) {
371
0
                uint64_t rnd = gen_random(static_cast<uint64_t>(frequency));
372
0
                if (rnd < sample_count) {
373
0
                    samples[rnd] = b.samples[rnd];
374
0
                }
375
0
            } else {
376
0
                for (double i = 0; i < sample_count; i += frequency) {
377
0
                    auto idx = static_cast<size_t>(i);
378
0
                    samples[idx] = b.samples[idx];
379
0
                }
380
0
            }
381
0
        }
382
0
    }
383
384
0
    void read(vectorized::BufferReadable& buf) {
385
0
        buf.read_binary(sample_count);
386
0
        buf.read_binary(total_values);
387
388
0
        size_t size = std::min(total_values, sample_count);
389
0
        static constexpr size_t MAX_RESERVOIR_SIZE = 1024 * 1024 * 1024; // 1GB
390
0
        if (UNLIKELY(size > MAX_RESERVOIR_SIZE)) {
391
0
            throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Too large array size (maximum: {})",
392
0
                                   MAX_RESERVOIR_SIZE);
393
0
        }
394
395
0
        std::string rng_string;
396
0
        buf.read_binary(rng_string);
397
0
        std::stringstream rng_buf(rng_string);
398
0
        rng_buf >> rng;
399
400
0
        samples.resize(size);
401
0
        for (double& sample : samples) {
402
0
            buf.read_binary(sample);
403
0
        }
404
405
0
        sorted = false;
406
0
    }
407
408
0
    void write(vectorized::BufferWritable& buf) const {
409
0
        buf.write_binary(sample_count);
410
0
        buf.write_binary(total_values);
411
412
0
        std::stringstream rng_buf;
413
0
        rng_buf << rng;
414
0
        buf.write_binary(rng_buf.str());
415
416
0
        for (size_t i = 0; i < std::min(sample_count, total_values); ++i) {
417
0
            buf.write_binary(samples[i]);
418
0
        }
419
0
    }
420
421
private:
422
    /// We allocate a little memory on the stack - to avoid allocations when there are many objects with a small number of elements.
423
    using Array = vectorized::PODArrayWithStackMemory<double, 64>;
424
    using pcg32_fast = mcg_base<uint32_t, uint64_t, xsh_rs_mixin>;
425
426
    size_t sample_count;
427
    size_t total_values = 0;
428
    Array samples;
429
    pcg32_fast rng;
430
    bool sorted = false;
431
432
0
    uint64_t gen_random(uint64_t limit) {
433
        /// With a large number of values, we will generate random numbers several times slower.
434
0
        if (limit <= static_cast<uint64_t>(pcg32_fast::max())) {
435
0
            return rng() % limit;
436
0
        } else {
437
0
            return (static_cast<uint64_t>(rng()) *
438
0
                            (static_cast<uint64_t>(pcg32_fast::max()) + 1ULL) +
439
0
                    static_cast<uint64_t>(rng())) %
440
0
                   limit;
441
0
        }
442
0
    }
443
444
0
    void sortIfNeeded() {
445
0
        if (sorted) {
446
0
            return;
447
0
        }
448
0
        sorted = true;
449
0
        std::sort(samples.begin(), samples.end(), std::less<double>());
450
0
    }
451
};
452
453
} // namespace doris