Coverage Report

Created: 2026-05-22 02:05

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_ema.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
// This file is adapted from
19
// https://github.com/ClickHouse/ClickHouse/blob/master/src/AggregateFunctions/AggregateFunctionExponentialMovingAverage.cpp
20
21
#pragma once
22
23
#include <cmath>
24
#include <memory>
25
26
#include "core/assert_cast.h"
27
#include "core/column/column_vector.h"
28
#include "core/data_type/data_type_number.h"
29
#include "core/types.h"
30
#include "exprs/aggregate/aggregate_function.h"
31
32
namespace doris {
33
class Arena;
34
class BufferReadable;
35
class BufferWritable;
36
class IColumn;
37
38
/**
39
 * Exponentially smoothed moving average over time.
40
 *
41
 * Each value corresponds to a timeunit index. The half_decay parameter is the
42
 * time lag at which exponential weights decay by one-half.
43
 *
44
 * State is a (value, time) pair representing the exponentially accumulated sum
45
 * at a reference time. To get the average, divide by sumWeights(half_decay).
46
 *
47
 * Formula:
48
 *   scale(dt, x) = 2^(-dt/x)
49
 *   sumWeights(x) = 1 / (1 - 2^(-1/x))
50
 *   add(v, t): merge current state with point (v, t)
51
 *   merge(a, b): move both to the later time, then sum values
52
 *   get():  value / sumWeights(half_decay)
53
 *
54
 * Usage: exponential_moving_average(half_decay, value, timeunit)
55
 *   - half_decay: constant double, the half-life period in timeunit units
56
 *   - value:      numeric column to average
57
 *   - timeunit:   numeric time index (not raw timestamp; use intDiv if needed)
58
 * Returns DOUBLE.
59
 */
60
struct ExponentialMovingAverageData {
61
    double value = 0.0;
62
    double time = 0.0;
63
    double half_decay = 0.0;
64
65
38
    static double scale(double time_passed, double hd) { return std::exp2(-time_passed / hd); }
66
67
10
    static double sum_weights(double hd) { return 1.0 / (1.0 - std::exp2(-1.0 / hd)); }
68
69
35
    void add(double new_value, double current_time, double hd) {
70
35
        half_decay = hd;
71
35
        ExponentialMovingAverageData other;
72
35
        other.value = new_value;
73
35
        other.time = current_time;
74
35
        merge_point(other, hd);
75
35
    }
76
77
40
    void merge_point(const ExponentialMovingAverageData& other, double hd) {
78
40
        if (time > other.time) {
79
1
            value = value + other.value * scale(time - other.time, hd);
80
39
        } else if (time < other.time) {
81
37
            value = other.value + value * scale(other.time - time, hd);
82
37
            time = other.time;
83
37
        } else {
84
2
            value = value + other.value;
85
2
        }
86
40
    }
87
88
6
    void merge(const ExponentialMovingAverageData& rhs) {
89
6
        double hd = half_decay != 0.0 ? half_decay : rhs.half_decay;
90
6
        if (hd == 0.0) {
91
1
            return;
92
1
        }
93
5
        half_decay = hd;
94
5
        merge_point(rhs, hd);
95
5
    }
96
97
11
    double get() const {
98
11
        if (half_decay == 0.0) {
99
1
            return 0.0;
100
1
        }
101
10
        return value / sum_weights(half_decay);
102
11
    }
103
104
6
    void write(BufferWritable& buf) const {
105
6
        buf.write_binary(value);
106
6
        buf.write_binary(time);
107
6
        buf.write_binary(half_decay);
108
6
    }
109
110
6
    void read(BufferReadable& buf) {
111
6
        buf.read_binary(value);
112
6
        buf.read_binary(time);
113
6
        buf.read_binary(half_decay);
114
6
    }
115
116
0
    void reset() {
117
0
        value = 0.0;
118
0
        time = 0.0;
119
0
        half_decay = 0.0;
120
0
    }
121
};
122
123
class AggregateFunctionExponentialMovingAverage final
124
        : public IAggregateFunctionDataHelper<ExponentialMovingAverageData,
125
                                              AggregateFunctionExponentialMovingAverage>,
126
          MultiExpression,
127
          NullableAggregateFunction {
128
public:
129
    AggregateFunctionExponentialMovingAverage(const DataTypes& argument_types_)
130
16
            : IAggregateFunctionDataHelper<ExponentialMovingAverageData,
131
16
                                           AggregateFunctionExponentialMovingAverage>(
132
16
                      argument_types_) {}
133
134
1
    String get_name() const override { return "exponential_moving_average"; }
135
136
32
    DataTypePtr get_return_type() const override { return std::make_shared<DataTypeFloat64>(); }
137
138
0
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
139
140
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
141
35
             Arena&) const override {
142
35
        const double half_decay =
143
35
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[0])
144
35
                        .get_data()[row_num];
145
35
        const double new_value =
146
35
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[1])
147
35
                        .get_data()[row_num];
148
35
        const double current_time =
149
35
                assert_cast<const ColumnFloat64&, TypeCheckOnRelease::DISABLE>(*columns[2])
150
35
                        .get_data()[row_num];
151
35
        this->data(place).add(new_value, current_time, half_decay);
152
35
    }
153
154
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
155
6
               Arena&) const override {
156
6
        this->data(place).merge(this->data(rhs));
157
6
    }
158
159
6
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
160
6
        this->data(place).write(buf);
161
6
    }
162
163
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
164
6
                     Arena&) const override {
165
6
        this->data(place).read(buf);
166
6
    }
167
168
11
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
169
11
        assert_cast<ColumnFloat64&>(to).get_data().push_back(this->data(place).get());
170
11
    }
171
};
172
173
} // namespace doris