be/src/util/reservoir_sampler.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/ReservoirSampler.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include <pdqsort.h> |
24 | | |
25 | | #include <cmath> |
26 | | #include <cstddef> |
27 | | #include <functional> |
28 | | #include <limits> |
29 | | |
30 | | #include "core/pod_array_fwd.h" |
31 | | #include "core/string_buffer.hpp" |
32 | | #include "core/types.h" |
33 | | |
34 | | namespace doris { |
35 | | |
36 | | /// Implementing the Reservoir Sampling algorithm. Incrementally selects from the added objects a random subset of the sample_count size. |
37 | | /// Can approximately get quantiles. |
38 | | /// Call `quantile` takes O(sample_count log sample_count), if after the previous call `quantile` there was at least one call `insert`. Otherwise O(1). |
39 | | /// That is, it makes sense to first add, then get quantiles without adding. |
40 | | /* |
41 | | * single stream/sequence (oneseq) |
42 | | */ |
43 | | |
44 | | template <typename T> |
45 | | struct default_multiplier { |
46 | | // Not defined for an arbitrary type |
47 | | }; |
48 | | |
49 | | template <typename T> |
50 | | struct default_increment { |
51 | | // Not defined for an arbitrary type |
52 | | }; |
53 | | |
54 | | #define PCG_DEFINE_CONSTANT(type, what, kind, constant) \ |
55 | | template <> \ |
56 | | struct what##_##kind<type> { \ |
57 | 895k | static constexpr type kind() { \ |
58 | 895k | return constant; \ |
59 | 895k | } \ _ZN5doris18default_multiplierImE10multiplierEv Line | Count | Source | 57 | 895k | static constexpr type kind() { \ | 58 | 895k | return constant; \ | 59 | 895k | } \ |
Unexecuted instantiation: _ZN5doris18default_multiplierIhE10multiplierEv Unexecuted instantiation: _ZN5doris17default_incrementIhE9incrementEv Unexecuted instantiation: _ZN5doris18default_multiplierItE10multiplierEv Unexecuted instantiation: _ZN5doris17default_incrementItE9incrementEv Unexecuted instantiation: _ZN5doris18default_multiplierIjE10multiplierEv Unexecuted instantiation: _ZN5doris17default_incrementIjE9incrementEv Unexecuted instantiation: _ZN5doris17default_incrementImE9incrementEv |
60 | | }; |
61 | | |
62 | | PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U) |
63 | | PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U) |
64 | | |
65 | | PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U) |
66 | | PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U) |
67 | | |
68 | | PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U) |
69 | | PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U) |
70 | | |
71 | | PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL) |
72 | | PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL) |
73 | | |
74 | | template <typename itype> |
75 | | class oneseq_stream : public default_increment<itype> { |
76 | | protected: |
77 | | static constexpr bool is_mcg = false; |
78 | | |
79 | | // Is never called, but is provided for symmetry with specific_stream |
80 | | void set_stream(...) { abort(); } |
81 | | |
82 | | public: |
83 | | typedef itype state_type; |
84 | | |
85 | | static constexpr itype stream() { return default_increment<itype>::increment() >> 1; } |
86 | | |
87 | | static constexpr bool can_specify_stream = false; |
88 | | |
89 | | static constexpr size_t streams_pow2() { return 0U; } |
90 | | |
91 | | protected: |
92 | | constexpr oneseq_stream() = default; |
93 | | }; |
94 | | |
95 | | /* |
96 | | * no stream (mcg) |
97 | | */ |
98 | | |
99 | | template <typename itype> |
100 | | class no_stream { |
101 | | protected: |
102 | | static constexpr bool is_mcg = true; |
103 | | |
104 | | // Is never called, but is provided for symmetry with specific_stream |
105 | 0 | void set_stream(...) { abort(); } |
106 | | |
107 | | public: |
108 | | typedef itype state_type; |
109 | | |
110 | 895k | static constexpr itype increment() { return 0; } |
111 | | |
112 | | static constexpr bool can_specify_stream = false; |
113 | | |
114 | | static constexpr size_t streams_pow2() { return 0U; } |
115 | | |
116 | | protected: |
117 | | constexpr no_stream() = default; |
118 | | }; |
119 | | |
120 | | template <typename xtype, typename itype, typename output_mixin, bool output_previous = true, |
121 | | typename stream_mixin = oneseq_stream<itype>, |
122 | | typename multiplier_mixin = default_multiplier<itype>> |
123 | | class engine : protected output_mixin, public stream_mixin, protected multiplier_mixin { |
124 | | protected: |
125 | | itype state_; |
126 | | |
127 | | struct can_specify_stream_tag {}; |
128 | | struct no_specifiable_stream_tag {}; |
129 | | |
130 | | using stream_mixin::increment; |
131 | | using multiplier_mixin::multiplier; |
132 | | |
133 | | public: |
134 | | typedef xtype result_type; |
135 | | typedef itype state_type; |
136 | | |
137 | | static constexpr size_t period_pow2() { |
138 | | return sizeof(state_type) * 8 - 2 * stream_mixin::is_mcg; |
139 | | } |
140 | | |
141 | | // It would be nice to use std::numeric_limits for these, but |
142 | | // we can't be sure that it'd be defined for the 128-bit types. |
143 | | |
144 | | static constexpr result_type min() { return result_type(0UL); } |
145 | | |
146 | 896k | static constexpr result_type max() { return result_type(~result_type(0UL)); } |
147 | | |
148 | | protected: |
149 | 895k | itype bump(itype state) { return state * multiplier() + increment(); } |
150 | | |
151 | 0 | itype base_generate() { return state_ = bump(state_); } |
152 | | |
153 | 896k | itype base_generate0() { |
154 | 896k | itype old_state = state_; |
155 | 896k | state_ = bump(state_); |
156 | 896k | return old_state; |
157 | 896k | } |
158 | | |
159 | | public: |
160 | 896k | result_type operator()() { |
161 | 896k | if (output_previous) { |
162 | 896k | return this->output(base_generate0()); |
163 | 18.4E | } else { |
164 | 18.4E | return this->output(base_generate()); |
165 | 18.4E | } |
166 | 896k | } |
167 | | |
168 | | result_type operator()(result_type upper_bound) { return bounded_rand(*this, upper_bound); } |
169 | | |
170 | | engine(itype state = itype(0xcafef00dd15ea5e5ULL)) |
171 | 99 | : state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) { |
172 | | // Nothing else to do. |
173 | 99 | } |
174 | | |
175 | | // This function may or may not exist. It thus has to be a template |
176 | | // to use SFINAE; users don't have to worry about its template-ness. |
177 | | |
178 | | template <typename sm = stream_mixin> |
179 | | engine(itype state, typename sm::stream_state stream_seed) |
180 | | : stream_mixin(stream_seed), |
181 | | state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) { |
182 | | // Nothing else to do. |
183 | | } |
184 | | |
185 | | template <typename SeedSeq> |
186 | | engine(SeedSeq&& seedSeq, |
187 | | typename std::enable_if<!stream_mixin::can_specify_stream && |
188 | | !std::is_convertible<SeedSeq, itype>::value && |
189 | | !std::is_convertible<SeedSeq, engine>::value, |
190 | | no_specifiable_stream_tag>::type = {}) |
191 | | : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq))) { |
192 | | // Nothing else to do. |
193 | | } |
194 | | |
195 | | template <typename SeedSeq> |
196 | | engine(SeedSeq&& seedSeq, |
197 | | typename std::enable_if<stream_mixin::can_specify_stream && |
198 | | !std::is_convertible<SeedSeq, itype>::value && |
199 | | !std::is_convertible<SeedSeq, engine>::value, |
200 | | can_specify_stream_tag>::type = {}) |
201 | | : engine(generate_one<itype, 1, 2>(seedSeq), generate_one<itype, 0, 2>(seedSeq)) { |
202 | | // Nothing else to do. |
203 | | } |
204 | | |
205 | | template <typename... Args> |
206 | 51 | void seed(Args&&... args) { |
207 | 51 | new (this) engine(std::forward<Args>(args)...); |
208 | 51 | } |
209 | | |
210 | | template <typename CharT, typename Traits, typename xtype1, typename itype1, |
211 | | typename output_mixin1, bool output_previous1, typename stream_mixin1, |
212 | | typename multiplier_mixin1> |
213 | | friend std::basic_ostream<CharT, Traits>& operator<<( |
214 | | std::basic_ostream<CharT, Traits>& out, |
215 | | const engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1, |
216 | 16 | multiplier_mixin1>& rng) { |
217 | 16 | auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left); |
218 | 16 | auto space = out.widen(' '); |
219 | 16 | auto orig_fill = out.fill(); |
220 | | |
221 | 16 | out << rng.multiplier() << space << rng.increment() << space << rng.state_; |
222 | | |
223 | 16 | out.flags(orig_flags); |
224 | 16 | out.fill(orig_fill); |
225 | 16 | return out; |
226 | 16 | } |
227 | | |
228 | | template <typename CharT, typename Traits, typename xtype1, typename itype1, |
229 | | typename output_mixin1, bool output_previous1, typename stream_mixin1, |
230 | | typename multiplier_mixin1> |
231 | | friend std::basic_istream<CharT, Traits>& operator>>( |
232 | | std::basic_istream<CharT, Traits>& in, |
233 | | engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1, |
234 | 16 | multiplier_mixin1>& rng) { |
235 | 16 | auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws); |
236 | | |
237 | 16 | itype multiplier, increment, state; |
238 | 16 | in >> multiplier >> increment >> state; |
239 | | |
240 | 16 | if (!in.fail()) { |
241 | 16 | bool good = true; |
242 | 16 | if (multiplier != rng.multiplier()) { |
243 | 0 | good = false; |
244 | 16 | } else if (rng.can_specify_stream) { |
245 | 0 | rng.set_stream(increment >> 1); |
246 | 16 | } else if (increment != rng.increment()) { |
247 | 0 | good = false; |
248 | 0 | } |
249 | 16 | if (good) { |
250 | 16 | rng.state_ = state; |
251 | 16 | } else { |
252 | 0 | in.clear(std::ios::failbit); |
253 | 0 | } |
254 | 16 | } |
255 | | |
256 | 16 | in.flags(orig_flags); |
257 | 16 | return in; |
258 | 16 | } |
259 | | }; |
260 | | |
261 | | #ifndef PCG_BITCOUNT_T |
262 | | typedef uint8_t bitcount_t; |
263 | | #else |
264 | | typedef PCG_BITCOUNT_T bitcount_t; |
265 | | #endif |
266 | | |
267 | | template <typename xtype, typename itype> |
268 | | struct xsh_rs_mixin { |
269 | 895k | static xtype output(itype internal) { |
270 | 895k | constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8); |
271 | 895k | constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8); |
272 | 895k | constexpr bitcount_t sparebits = bits - xtypebits; |
273 | 895k | constexpr bitcount_t opbits = sparebits - 5 >= 64 ? 5 |
274 | 895k | : sparebits - 4 >= 32 ? 4 |
275 | 895k | : sparebits - 3 >= 16 ? 3 |
276 | 895k | : sparebits - 2 >= 4 ? 2 |
277 | 895k | : sparebits - 1 >= 1 ? 1 |
278 | 895k | : 0; |
279 | 895k | constexpr bitcount_t mask = (1 << opbits) - 1; |
280 | 895k | constexpr bitcount_t maxrandshift = mask; |
281 | 895k | constexpr bitcount_t topspare = opbits; |
282 | 895k | constexpr bitcount_t bottomspare = sparebits - topspare; |
283 | 895k | constexpr bitcount_t xshift = topspare + (xtypebits + maxrandshift) / 2; |
284 | 18.4E | bitcount_t rshift = opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0; |
285 | 895k | internal ^= internal >> xshift; |
286 | 895k | auto result = xtype(internal >> (bottomspare - maxrandshift + rshift)); |
287 | 895k | return result; |
288 | 895k | } |
289 | | }; |
290 | | |
291 | | template <typename xtype, typename itype, template <typename XT, typename IT> class output_mixin, |
292 | | bool output_previous = (sizeof(itype) <= 8)> |
293 | | using mcg_base = |
294 | | engine<xtype, itype, output_mixin<xtype, itype>, output_previous, no_stream<itype>>; |
295 | | |
296 | | class ReservoirSampler { |
297 | | public: |
298 | 48 | explicit ReservoirSampler(size_t sample_count_ = 8192) : sample_count(sample_count_) { |
299 | 48 | rng.seed(123456); |
300 | 48 | } |
301 | | |
302 | 913k | void insert(const double v) { |
303 | 913k | if (std::isnan(v)) { |
304 | 0 | return; |
305 | 0 | } |
306 | | |
307 | 913k | sorted = false; |
308 | 913k | ++total_values; |
309 | 913k | if (samples.size() < sample_count) { |
310 | 16.4k | samples.push_back(v); |
311 | 896k | } else { |
312 | 896k | uint64_t rnd = gen_random(total_values); |
313 | 896k | if (rnd < sample_count) { |
314 | 65.3k | samples[rnd] = v; |
315 | 65.3k | } |
316 | 896k | } |
317 | 913k | } |
318 | | |
319 | 123 | void insert_many(const double* values, size_t size) { |
320 | 913k | for (size_t i = 0; i < size; ++i) { |
321 | 912k | insert(values[i]); |
322 | 912k | } |
323 | 123 | } |
324 | | |
325 | 3 | void clear() { |
326 | 3 | samples.clear(); |
327 | 3 | sorted = false; |
328 | 3 | total_values = 0; |
329 | 3 | rng.seed(123456); |
330 | 3 | } |
331 | | |
332 | 23 | double quantileInterpolated(double level) { |
333 | 23 | if (samples.empty()) { |
334 | 0 | return std::numeric_limits<double>::quiet_NaN(); |
335 | 0 | } |
336 | 23 | sortIfNeeded(); |
337 | | |
338 | 23 | double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1))); |
339 | | |
340 | | /// To get the value of a fractional index, we linearly interpolate between neighboring values. |
341 | 23 | auto left_index = static_cast<size_t>(index); |
342 | 23 | size_t right_index = left_index + 1; |
343 | 23 | if (right_index == samples.size()) { |
344 | 1 | return samples[left_index]; |
345 | 1 | } |
346 | | |
347 | 22 | double left_coef = right_index - index; |
348 | 22 | double right_coef = index - left_index; |
349 | 22 | return samples[left_index] * left_coef + samples[right_index] * right_coef; |
350 | 23 | } |
351 | | |
352 | 16 | void merge(const ReservoirSampler& b) { |
353 | 16 | if (sample_count != b.sample_count) { |
354 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, |
355 | 0 | "Cannot merge ReservoirSampler's with different sample_count"); |
356 | 0 | } |
357 | 16 | sorted = false; |
358 | | |
359 | 16 | if (b.total_values <= sample_count) { |
360 | 75 | for (double sample : b.samples) { |
361 | 75 | insert(sample); |
362 | 75 | } |
363 | 14 | } else if (total_values <= sample_count) { |
364 | 1 | Array from = std::move(samples); |
365 | 1 | samples.assign(b.samples.begin(), b.samples.end()); |
366 | 1 | total_values = b.total_values; |
367 | 1 | for (double i : from) { |
368 | 0 | insert(i); |
369 | 0 | } |
370 | 1 | } else { |
371 | | /// Replace every element in our reservoir to the b's reservoir |
372 | | /// with the probability of b.total_values / (a.total_values + b.total_values) |
373 | | /// Do it more roughly than true random sampling to save performance. |
374 | 1 | total_values += b.total_values; |
375 | | /// Will replace every frequency'th element in a to element from b. |
376 | 1 | double frequency = static_cast<double>(total_values) / b.total_values; |
377 | | /// When frequency is too low, replace just one random element with the corresponding probability. |
378 | 1 | if (frequency * 2 >= sample_count) { |
379 | 0 | uint64_t rnd = gen_random(static_cast<uint64_t>(frequency)); |
380 | 0 | if (rnd < sample_count) { |
381 | 0 | samples[rnd] = b.samples[rnd]; |
382 | 0 | } |
383 | 1 | } else { |
384 | 4.11k | for (double i = 0; i < sample_count; i += frequency) { |
385 | 4.11k | auto idx = static_cast<size_t>(i); |
386 | 4.11k | samples[idx] = b.samples[idx]; |
387 | 4.11k | } |
388 | 1 | } |
389 | 1 | } |
390 | 16 | } |
391 | | |
392 | 16 | void read(BufferReadable& buf) { |
393 | 16 | buf.read_binary(sample_count); |
394 | 16 | buf.read_binary(total_values); |
395 | | |
396 | 16 | size_t size = std::min(total_values, sample_count); |
397 | 16 | static constexpr size_t MAX_RESERVOIR_SIZE = 1024 * 1024 * 1024; // 1GB |
398 | 16 | if (UNLIKELY(size > MAX_RESERVOIR_SIZE)) { |
399 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Too large array size (maximum: {})", |
400 | 0 | MAX_RESERVOIR_SIZE); |
401 | 0 | } |
402 | | |
403 | 16 | std::string rng_string; |
404 | 16 | buf.read_binary(rng_string); |
405 | 16 | std::stringstream rng_buf(rng_string); |
406 | 16 | rng_buf >> rng; |
407 | | |
408 | 16 | samples.resize(size); |
409 | 16.4k | for (double& sample : samples) { |
410 | 16.4k | buf.read_binary(sample); |
411 | 16.4k | } |
412 | | |
413 | 16 | sorted = false; |
414 | 16 | } |
415 | | |
416 | 16 | void write(BufferWritable& buf) const { |
417 | 16 | buf.write_binary(sample_count); |
418 | 16 | buf.write_binary(total_values); |
419 | | |
420 | 16 | std::stringstream rng_buf; |
421 | 16 | rng_buf << rng; |
422 | 16 | buf.write_binary(rng_buf.str()); |
423 | | |
424 | 16.4k | for (size_t i = 0; i < std::min(sample_count, total_values); ++i) { |
425 | 16.4k | buf.write_binary(samples[i]); |
426 | 16.4k | } |
427 | 16 | } |
428 | | |
429 | | private: |
430 | | /// We allocate a little memory on the stack - to avoid allocations when there are many objects with a small number of elements. |
431 | | using Array = PODArrayWithStackMemory<double, 64>; |
432 | | using pcg32_fast = mcg_base<uint32_t, uint64_t, xsh_rs_mixin>; |
433 | | |
434 | | size_t sample_count; |
435 | | size_t total_values = 0; |
436 | | Array samples; |
437 | | pcg32_fast rng; |
438 | | bool sorted = false; |
439 | | |
440 | 896k | uint64_t gen_random(uint64_t limit) { |
441 | | /// With a large number of values, we will generate random numbers several times slower. |
442 | 896k | if (limit <= static_cast<uint64_t>(pcg32_fast::max())) { |
443 | 896k | return rng() % limit; |
444 | 896k | } else { |
445 | 396 | return (static_cast<uint64_t>(rng()) * |
446 | 396 | (static_cast<uint64_t>(pcg32_fast::max()) + 1ULL) + |
447 | 396 | static_cast<uint64_t>(rng())) % |
448 | 396 | limit; |
449 | 396 | } |
450 | 896k | } |
451 | | |
452 | 23 | void sortIfNeeded() { |
453 | 23 | if (sorted) { |
454 | 13 | return; |
455 | 13 | } |
456 | 10 | sorted = true; |
457 | 10 | pdqsort(samples.begin(), samples.end(), std::less<double>()); |
458 | 10 | } |
459 | | }; |
460 | | |
461 | | } // namespace doris |