/root/doris/be/src/util/reservoir_sampler.h
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | // This file is copied from |
18 | | // https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/ReservoirSampler.h |
19 | | // and modified by Doris |
20 | | |
21 | | #pragma once |
22 | | |
23 | | #include <cmath> |
24 | | #include <cstddef> |
25 | | #include <functional> |
26 | | #include <limits> |
27 | | |
28 | | #include "vec/common/pod_array_fwd.h" |
29 | | #include "vec/common/string_buffer.hpp" |
30 | | #include "vec/core/types.h" |
31 | | |
32 | | namespace doris { |
33 | | |
/// Implements the reservoir sampling algorithm: incrementally keeps a random subset of at most sample_count of the inserted values.
/// Can be used to compute approximate quantiles.
/// Calling `quantileInterpolated` takes O(sample_count log sample_count) if the sample changed (e.g. via `insert`) since the previous call; otherwise O(1).
/// That is, it makes sense to insert all values first and then query quantiles without inserting in between.
38 | | /* |
39 | | * single stream/sequence (oneseq) |
40 | | */ |
41 | | |
/// Primary template for the default LCG multiplier of a given state type.
/// Intentionally empty: only the specializations generated below provide `multiplier()`.
template <class StateT>
struct default_multiplier {};

/// Primary template for the default LCG increment of a given state type.
/// Intentionally empty: only the specializations generated below provide `increment()`.
template <class StateT>
struct default_increment {};

/// Emits the full specialization `what_kind<type>` whose only member is a static constexpr
/// function named `kind` that returns `constant`.
#define PCG_DEFINE_CONSTANT(type, what, kind, constant)     \
    template <>                                             \
    struct what##_##kind<type> {                            \
        static constexpr type kind() { return constant; }   \
    };

// Default multipliers, one per supported state width.
PCG_DEFINE_CONSTANT(uint8_t, default, multiplier, 141U)
PCG_DEFINE_CONSTANT(uint16_t, default, multiplier, 12829U)
PCG_DEFINE_CONSTANT(uint32_t, default, multiplier, 747796405U)
PCG_DEFINE_CONSTANT(uint64_t, default, multiplier, 6364136223846793005ULL)

// Default increments, one per supported state width.
PCG_DEFINE_CONSTANT(uint8_t, default, increment, 77U)
PCG_DEFINE_CONSTANT(uint16_t, default, increment, 47989U)
PCG_DEFINE_CONSTANT(uint32_t, default, increment, 2891336453U)
PCG_DEFINE_CONSTANT(uint64_t, default, increment, 1442695040888963407ULL)
71 | | |
72 | | template <typename itype> |
73 | | class oneseq_stream : public default_increment<itype> { |
74 | | protected: |
75 | | static constexpr bool is_mcg = false; |
76 | | |
77 | | // Is never called, but is provided for symmetry with specific_stream |
78 | | void set_stream(...) { abort(); } |
79 | | |
80 | | public: |
81 | | typedef itype state_type; |
82 | | |
83 | | static constexpr itype stream() { return default_increment<itype>::increment() >> 1; } |
84 | | |
85 | | static constexpr bool can_specify_stream = false; |
86 | | |
87 | | static constexpr size_t streams_pow2() { return 0U; } |
88 | | |
89 | | protected: |
90 | | constexpr oneseq_stream() = default; |
91 | | }; |
92 | | |
93 | | /* |
94 | | * no stream (mcg) |
95 | | */ |
96 | | |
/// Stream policy without an additive constant (MCG): `increment()` is always zero and the
/// caller cannot choose a stream.
template <typename itype>
class no_stream {
public:
    using state_type = itype;

    /// This policy does not let the caller pick a stream.
    static constexpr bool can_specify_stream = false;

    /// The additive constant of the generator; always 0 for this policy.
    static constexpr itype increment() { return 0; }

    static constexpr size_t streams_pow2() { return 0U; }

protected:
    static constexpr bool is_mcg = true;

    /// Only constructible as a base class.
    constexpr no_stream() = default;

    // Not expected to be called; kept only for symmetry with specific_stream.
    void set_stream(...) { abort(); }
};
117 | | |
118 | | template <typename xtype, typename itype, typename output_mixin, bool output_previous = true, |
119 | | typename stream_mixin = oneseq_stream<itype>, |
120 | | typename multiplier_mixin = default_multiplier<itype>> |
121 | | class engine : protected output_mixin, public stream_mixin, protected multiplier_mixin { |
122 | | protected: |
123 | | itype state_; |
124 | | |
125 | | struct can_specify_stream_tag {}; |
126 | | struct no_specifiable_stream_tag {}; |
127 | | |
128 | | using stream_mixin::increment; |
129 | | using multiplier_mixin::multiplier; |
130 | | |
131 | | public: |
132 | | typedef xtype result_type; |
133 | | typedef itype state_type; |
134 | | |
135 | | static constexpr size_t period_pow2() { |
136 | | return sizeof(state_type) * 8 - 2 * stream_mixin::is_mcg; |
137 | | } |
138 | | |
139 | | // It would be nice to use std::numeric_limits for these, but |
140 | | // we can't be sure that it'd be defined for the 128-bit types. |
141 | | |
142 | | static constexpr result_type min() { return result_type(0UL); } |
143 | | |
144 | 0 | static constexpr result_type max() { return result_type(~result_type(0UL)); } |
145 | | |
146 | | protected: |
147 | 0 | itype bump(itype state) { return state * multiplier() + increment(); } |
148 | | |
149 | 0 | itype base_generate() { return state_ = bump(state_); } |
150 | | |
151 | 0 | itype base_generate0() { |
152 | 0 | itype old_state = state_; |
153 | 0 | state_ = bump(state_); |
154 | 0 | return old_state; |
155 | 0 | } |
156 | | |
157 | | public: |
158 | 0 | result_type operator()() { |
159 | 0 | if (output_previous) { |
160 | 0 | return this->output(base_generate0()); |
161 | 0 | } else { |
162 | 0 | return this->output(base_generate()); |
163 | 0 | } |
164 | 0 | } |
165 | | |
166 | | result_type operator()(result_type upper_bound) { return bounded_rand(*this, upper_bound); } |
167 | | |
168 | | engine(itype state = itype(0xcafef00dd15ea5e5ULL)) |
169 | 0 | : state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) { |
170 | | // Nothing else to do. |
171 | 0 | } |
172 | | |
173 | | // This function may or may not exist. It thus has to be a template |
174 | | // to use SFINAE; users don't have to worry about its template-ness. |
175 | | |
176 | | template <typename sm = stream_mixin> |
177 | | engine(itype state, typename sm::stream_state stream_seed) |
178 | | : stream_mixin(stream_seed), |
179 | | state_(this->is_mcg ? state | state_type(3U) : bump(state + this->increment())) { |
180 | | // Nothing else to do. |
181 | | } |
182 | | |
183 | | template <typename SeedSeq> |
184 | | engine(SeedSeq&& seedSeq, |
185 | | typename std::enable_if<!stream_mixin::can_specify_stream && |
186 | | !std::is_convertible<SeedSeq, itype>::value && |
187 | | !std::is_convertible<SeedSeq, engine>::value, |
188 | | no_specifiable_stream_tag>::type = {}) |
189 | | : engine(generate_one<itype>(std::forward<SeedSeq>(seedSeq))) { |
190 | | // Nothing else to do. |
191 | | } |
192 | | |
193 | | template <typename SeedSeq> |
194 | | engine(SeedSeq&& seedSeq, |
195 | | typename std::enable_if<stream_mixin::can_specify_stream && |
196 | | !std::is_convertible<SeedSeq, itype>::value && |
197 | | !std::is_convertible<SeedSeq, engine>::value, |
198 | | can_specify_stream_tag>::type = {}) |
199 | | : engine(generate_one<itype, 1, 2>(seedSeq), generate_one<itype, 0, 2>(seedSeq)) { |
200 | | // Nothing else to do. |
201 | | } |
202 | | |
203 | | template <typename... Args> |
204 | 0 | void seed(Args&&... args) { |
205 | 0 | new (this) engine(std::forward<Args>(args)...); |
206 | 0 | } |
207 | | |
208 | | template <typename CharT, typename Traits, typename xtype1, typename itype1, |
209 | | typename output_mixin1, bool output_previous1, typename stream_mixin1, |
210 | | typename multiplier_mixin1> |
211 | | friend std::basic_ostream<CharT, Traits>& operator<<( |
212 | | std::basic_ostream<CharT, Traits>& out, |
213 | | const engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1, |
214 | 0 | multiplier_mixin1>& rng) { |
215 | 0 | auto orig_flags = out.flags(std::ios_base::dec | std::ios_base::left); |
216 | 0 | auto space = out.widen(' '); |
217 | 0 | auto orig_fill = out.fill(); |
218 | |
|
219 | 0 | out << rng.multiplier() << space << rng.increment() << space << rng.state_; |
220 | |
|
221 | 0 | out.flags(orig_flags); |
222 | 0 | out.fill(orig_fill); |
223 | 0 | return out; |
224 | 0 | } |
225 | | |
226 | | template <typename CharT, typename Traits, typename xtype1, typename itype1, |
227 | | typename output_mixin1, bool output_previous1, typename stream_mixin1, |
228 | | typename multiplier_mixin1> |
229 | | friend std::basic_istream<CharT, Traits>& operator>>( |
230 | | std::basic_istream<CharT, Traits>& in, |
231 | | engine<xtype1, itype1, output_mixin1, output_previous1, stream_mixin1, |
232 | 0 | multiplier_mixin1>& rng) { |
233 | 0 | auto orig_flags = in.flags(std::ios_base::dec | std::ios_base::skipws); |
234 | |
|
235 | 0 | itype multiplier, increment, state; |
236 | 0 | in >> multiplier >> increment >> state; |
237 | |
|
238 | 0 | if (!in.fail()) { |
239 | 0 | bool good = true; |
240 | 0 | if (multiplier != rng.multiplier()) { |
241 | 0 | good = false; |
242 | 0 | } else if (rng.can_specify_stream) { |
243 | 0 | rng.set_stream(increment >> 1); |
244 | 0 | } else if (increment != rng.increment()) { |
245 | 0 | good = false; |
246 | 0 | } |
247 | 0 | if (good) { |
248 | 0 | rng.state_ = state; |
249 | 0 | } else { |
250 | 0 | in.clear(std::ios::failbit); |
251 | 0 | } |
252 | 0 | } |
253 | |
|
254 | 0 | in.flags(orig_flags); |
255 | 0 | return in; |
256 | 0 | } |
257 | | }; |
258 | | |
// Integer type used for bit counts and shift amounts; defaults to uint8_t unless the
// build defines PCG_BITCOUNT_T.
#ifdef PCG_BITCOUNT_T
using bitcount_t = PCG_BITCOUNT_T;
#else
using bitcount_t = uint8_t;
#endif
264 | | |
265 | | template <typename xtype, typename itype> |
266 | | struct xsh_rs_mixin { |
267 | 0 | static xtype output(itype internal) { |
268 | 0 | constexpr bitcount_t bits = bitcount_t(sizeof(itype) * 8); |
269 | 0 | constexpr bitcount_t xtypebits = bitcount_t(sizeof(xtype) * 8); |
270 | 0 | constexpr bitcount_t sparebits = bits - xtypebits; |
271 | 0 | constexpr bitcount_t opbits = sparebits - 5 >= 64 ? 5 |
272 | 0 | : sparebits - 4 >= 32 ? 4 |
273 | 0 | : sparebits - 3 >= 16 ? 3 |
274 | 0 | : sparebits - 2 >= 4 ? 2 |
275 | 0 | : sparebits - 1 >= 1 ? 1 |
276 | 0 | : 0; |
277 | 0 | constexpr bitcount_t mask = (1 << opbits) - 1; |
278 | 0 | constexpr bitcount_t maxrandshift = mask; |
279 | 0 | constexpr bitcount_t topspare = opbits; |
280 | 0 | constexpr bitcount_t bottomspare = sparebits - topspare; |
281 | 0 | constexpr bitcount_t xshift = topspare + (xtypebits + maxrandshift) / 2; |
282 | 0 | bitcount_t rshift = opbits ? bitcount_t(internal >> (bits - opbits)) & mask : 0; |
283 | 0 | internal ^= internal >> xshift; |
284 | 0 | auto result = xtype(internal >> (bottomspare - maxrandshift + rshift)); |
285 | 0 | return result; |
286 | 0 | } |
287 | | }; |
288 | | |
289 | | template <typename xtype, typename itype, template <typename XT, typename IT> class output_mixin, |
290 | | bool output_previous = (sizeof(itype) <= 8)> |
291 | | using mcg_base = |
292 | | engine<xtype, itype, output_mixin<xtype, itype>, output_previous, no_stream<itype>>; |
293 | | |
294 | | class ReservoirSampler { |
295 | | public: |
296 | 0 | explicit ReservoirSampler(size_t sample_count_ = 8192) : sample_count(sample_count_) { |
297 | 0 | rng.seed(123456); |
298 | 0 | } |
299 | | |
300 | 0 | void insert(const double v) { |
301 | 0 | if (std::isnan(v)) { |
302 | 0 | return; |
303 | 0 | } |
304 | | |
305 | 0 | sorted = false; |
306 | 0 | ++total_values; |
307 | 0 | if (samples.size() < sample_count) { |
308 | 0 | samples.push_back(v); |
309 | 0 | } else { |
310 | 0 | uint64_t rnd = gen_random(total_values); |
311 | 0 | if (rnd < sample_count) { |
312 | 0 | samples[rnd] = v; |
313 | 0 | } |
314 | 0 | } |
315 | 0 | } |
316 | | |
317 | 0 | void clear() { |
318 | 0 | samples.clear(); |
319 | 0 | sorted = false; |
320 | 0 | total_values = 0; |
321 | 0 | rng.seed(123456); |
322 | 0 | } |
323 | | |
324 | 0 | double quantileInterpolated(double level) { |
325 | 0 | if (samples.empty()) { |
326 | 0 | return std::numeric_limits<double>::quiet_NaN(); |
327 | 0 | } |
328 | 0 | sortIfNeeded(); |
329 | |
|
330 | 0 | double index = std::max(0., std::min(samples.size() - 1., level * (samples.size() - 1))); |
331 | | |
332 | | /// To get the value of a fractional index, we linearly interpolate between neighboring values. |
333 | 0 | auto left_index = static_cast<size_t>(index); |
334 | 0 | size_t right_index = left_index + 1; |
335 | 0 | if (right_index == samples.size()) { |
336 | 0 | return samples[left_index]; |
337 | 0 | } |
338 | | |
339 | 0 | double left_coef = right_index - index; |
340 | 0 | double right_coef = index - left_index; |
341 | 0 | return samples[left_index] * left_coef + samples[right_index] * right_coef; |
342 | 0 | } |
343 | | |
344 | 0 | void merge(const ReservoirSampler& b) { |
345 | 0 | if (sample_count != b.sample_count) { |
346 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, |
347 | 0 | "Cannot merge ReservoirSampler's with different sample_count"); |
348 | 0 | } |
349 | 0 | sorted = false; |
350 | |
|
351 | 0 | if (b.total_values <= sample_count) { |
352 | 0 | for (double sample : b.samples) { |
353 | 0 | insert(sample); |
354 | 0 | } |
355 | 0 | } else if (total_values <= sample_count) { |
356 | 0 | Array from = std::move(samples); |
357 | 0 | samples.assign(b.samples.begin(), b.samples.end()); |
358 | 0 | total_values = b.total_values; |
359 | 0 | for (double i : from) { |
360 | 0 | insert(i); |
361 | 0 | } |
362 | 0 | } else { |
363 | | /// Replace every element in our reservoir to the b's reservoir |
364 | | /// with the probability of b.total_values / (a.total_values + b.total_values) |
365 | | /// Do it more roughly than true random sampling to save performance. |
366 | 0 | total_values += b.total_values; |
367 | | /// Will replace every frequency'th element in a to element from b. |
368 | 0 | double frequency = static_cast<double>(total_values) / b.total_values; |
369 | | /// When frequency is too low, replace just one random element with the corresponding probability. |
370 | 0 | if (frequency * 2 >= sample_count) { |
371 | 0 | uint64_t rnd = gen_random(static_cast<uint64_t>(frequency)); |
372 | 0 | if (rnd < sample_count) { |
373 | 0 | samples[rnd] = b.samples[rnd]; |
374 | 0 | } |
375 | 0 | } else { |
376 | 0 | for (double i = 0; i < sample_count; i += frequency) { |
377 | 0 | auto idx = static_cast<size_t>(i); |
378 | 0 | samples[idx] = b.samples[idx]; |
379 | 0 | } |
380 | 0 | } |
381 | 0 | } |
382 | 0 | } |
383 | | |
384 | 0 | void read(vectorized::BufferReadable& buf) { |
385 | 0 | buf.read_binary(sample_count); |
386 | 0 | buf.read_binary(total_values); |
387 | |
|
388 | 0 | size_t size = std::min(total_values, sample_count); |
389 | 0 | static constexpr size_t MAX_RESERVOIR_SIZE = 1024 * 1024 * 1024; // 1GB |
390 | 0 | if (UNLIKELY(size > MAX_RESERVOIR_SIZE)) { |
391 | 0 | throw doris::Exception(ErrorCode::INTERNAL_ERROR, "Too large array size (maximum: {})", |
392 | 0 | MAX_RESERVOIR_SIZE); |
393 | 0 | } |
394 | | |
395 | 0 | std::string rng_string; |
396 | 0 | buf.read_binary(rng_string); |
397 | 0 | std::stringstream rng_buf(rng_string); |
398 | 0 | rng_buf >> rng; |
399 | |
|
400 | 0 | samples.resize(size); |
401 | 0 | for (double& sample : samples) { |
402 | 0 | buf.read_binary(sample); |
403 | 0 | } |
404 | |
|
405 | 0 | sorted = false; |
406 | 0 | } |
407 | | |
408 | 0 | void write(vectorized::BufferWritable& buf) const { |
409 | 0 | buf.write_binary(sample_count); |
410 | 0 | buf.write_binary(total_values); |
411 | |
|
412 | 0 | std::stringstream rng_buf; |
413 | 0 | rng_buf << rng; |
414 | 0 | buf.write_binary(rng_buf.str()); |
415 | |
|
416 | 0 | for (size_t i = 0; i < std::min(sample_count, total_values); ++i) { |
417 | 0 | buf.write_binary(samples[i]); |
418 | 0 | } |
419 | 0 | } |
420 | | |
421 | | private: |
422 | | /// We allocate a little memory on the stack - to avoid allocations when there are many objects with a small number of elements. |
423 | | using Array = vectorized::PODArrayWithStackMemory<double, 64>; |
424 | | using pcg32_fast = mcg_base<uint32_t, uint64_t, xsh_rs_mixin>; |
425 | | |
426 | | size_t sample_count; |
427 | | size_t total_values = 0; |
428 | | Array samples; |
429 | | pcg32_fast rng; |
430 | | bool sorted = false; |
431 | | |
432 | 0 | uint64_t gen_random(uint64_t limit) { |
433 | | /// With a large number of values, we will generate random numbers several times slower. |
434 | 0 | if (limit <= static_cast<uint64_t>(pcg32_fast::max())) { |
435 | 0 | return rng() % limit; |
436 | 0 | } else { |
437 | 0 | return (static_cast<uint64_t>(rng()) * |
438 | 0 | (static_cast<uint64_t>(pcg32_fast::max()) + 1ULL) + |
439 | 0 | static_cast<uint64_t>(rng())) % |
440 | 0 | limit; |
441 | 0 | } |
442 | 0 | } |
443 | | |
444 | 0 | void sortIfNeeded() { |
445 | 0 | if (sorted) { |
446 | 0 | return; |
447 | 0 | } |
448 | 0 | sorted = true; |
449 | 0 | std::sort(samples.begin(), samples.end(), std::less<double>()); |
450 | 0 | } |
451 | | }; |
452 | | |
453 | | } // namespace doris |