Coverage Report

Created: 2026-03-16 12:03

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/exprs/aggregate/aggregate_function_covar.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <glog/logging.h>
21
22
#include <boost/iterator/iterator_facade.hpp>
23
#include <cstddef>
24
#include <cstdint>
25
#include <memory>
26
27
#include "core/assert_cast.h"
28
#include "core/column/column.h"
29
#include "core/column/column_nullable.h"
30
#include "core/data_type/data_type_decimal.h"
31
#include "core/data_type/data_type_number.h"
32
#include "core/types.h"
33
#include "exprs/aggregate/aggregate_function.h"
34
35
namespace doris {
36
#include "common/compile_check_begin.h"
37
38
class Arena;
39
class BufferReadable;
40
class BufferWritable;
41
template <PrimitiveType T>
42
class ColumnDecimal;
43
template <PrimitiveType T>
44
class ColumnVector;
45
46
template <PrimitiveType T>
47
struct BaseData {
48
764
    BaseData() = default;
49
766
    virtual ~BaseData() = default;
50
174
    static DataTypePtr get_return_type() { return std::make_shared<DataTypeFloat64>(); }
51
52
342
    void write(BufferWritable& buf) const {
53
342
        buf.write_binary(sum_x);
54
342
        buf.write_binary(sum_y);
55
342
        buf.write_binary(sum_xy);
56
342
        buf.write_binary(count);
57
342
    }
58
59
263
    void read(BufferReadable& buf) {
60
263
        buf.read_binary(sum_x);
61
263
        buf.read_binary(sum_y);
62
263
        buf.read_binary(sum_xy);
63
263
        buf.read_binary(count);
64
263
    }
65
66
76
    void reset() {
67
76
        sum_x = 0.0;
68
76
        sum_y = 0.0;
69
76
        sum_xy = 0.0;
70
76
        count = 0;
71
76
    }
72
73
    // Cov(X, Y) = E(XY) - E(X)E(Y)
74
63
    double get_pop_result() const {
75
63
        if (count == 1) {
76
17
            return 0.0;
77
17
        }
78
46
        return sum_xy / (double)count - sum_x * sum_y / ((double)count * (double)count);
79
63
    }
80
81
47
    double get_samp_result() const {
82
47
        return sum_xy / double(count - 1) -
83
47
               sum_x * sum_y / ((double)(count) * ((double)(count - 1)));
84
47
    }
85
86
264
    void merge(const BaseData& rhs) {
87
264
        if (rhs.count == 0) {
88
0
            return;
89
0
        }
90
264
        sum_x += rhs.sum_x;
91
264
        sum_y += rhs.sum_y;
92
264
        sum_xy += rhs.sum_xy;
93
264
        count += rhs.count;
94
264
    }
95
96
404
    void add(const IColumn* column_x, const IColumn* column_y, size_t row_num) {
97
404
        const auto& sources_x = assert_cast<const typename PrimitiveTypeTraits<T>::ColumnType&,
98
404
                                            TypeCheckOnRelease::DISABLE>(*column_x);
99
404
        double source_data_x = double(sources_x.get_data()[row_num]);
100
404
        const auto& sources_y = assert_cast<const typename PrimitiveTypeTraits<T>::ColumnType&,
101
404
                                            TypeCheckOnRelease::DISABLE>(*column_y);
102
404
        double source_data_y = double(sources_y.get_data()[row_num]);
103
104
404
        sum_x += source_data_x;
105
404
        sum_y += source_data_y;
106
404
        sum_xy += source_data_x * source_data_y;
107
404
        count += 1;
108
404
    }
109
110
    double sum_x {};
111
    double sum_y {};
112
    double sum_xy {};
113
    int64_t count {};
114
};
115
116
template <PrimitiveType T>
117
struct PopData : BaseData<T> {
118
53
    static const char* name() { return "covar"; }
119
120
63
    void insert_result_into(IColumn& to) const {
121
63
        auto& col = assert_cast<ColumnFloat64&>(to);
122
63
        col.get_data().push_back(this->get_pop_result());
123
63
    }
124
};
125
126
template <PrimitiveType T>
127
struct SampData : BaseData<T> {
128
37
    static const char* name() { return "covar_samp"; }
129
130
59
    void insert_result_into(IColumn& to) const {
131
59
        auto& col = assert_cast<ColumnFloat64&>(to);
132
59
        if (this->count == 1 || this->count == 0) {
133
12
            col.insert_default();
134
47
        } else {
135
47
            col.get_data().push_back(this->get_samp_result());
136
47
        }
137
59
    }
138
};
139
140
template <typename Data>
141
class AggregateFunctionSampCovariance
142
        : public IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>>,
143
          MultiExpression,
144
          NullableAggregateFunction {
145
public:
146
    AggregateFunctionSampCovariance(const DataTypes& argument_types_)
147
885
            : IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>>(
148
885
                      argument_types_) {}
_ZN5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEEC2ERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS9_EE
Line
Count
Source
147
445
            : IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>>(
148
445
                      argument_types_) {}
_ZN5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEEC2ERKSt6vectorISt10shared_ptrIKNS_9IDataTypeEESaIS9_EE
Line
Count
Source
147
440
            : IAggregateFunctionDataHelper<Data, AggregateFunctionSampCovariance<Data>>(
148
440
                      argument_types_) {}
149
150
90
    String get_name() const override { return Data::name(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE8get_nameB5cxx11Ev
Line
Count
Source
150
37
    String get_name() const override { return Data::name(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE8get_nameB5cxx11Ev
Line
Count
Source
150
53
    String get_name() const override { return Data::name(); }
151
152
174
    DataTypePtr get_return_type() const override { return Data::get_return_type(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE15get_return_typeEv
Line
Count
Source
152
81
    DataTypePtr get_return_type() const override { return Data::get_return_type(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE15get_return_typeEv
Line
Count
Source
152
93
    DataTypePtr get_return_type() const override { return Data::get_return_type(); }
153
154
    void add(AggregateDataPtr __restrict place, const IColumn** columns, ssize_t row_num,
155
404
             Arena&) const override {
156
404
        this->data(place).add(columns[0], columns[1], row_num);
157
404
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE3addEPcPPKNS_7IColumnElRNS_5ArenaE
Line
Count
Source
155
203
             Arena&) const override {
156
203
        this->data(place).add(columns[0], columns[1], row_num);
157
203
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE3addEPcPPKNS_7IColumnElRNS_5ArenaE
Line
Count
Source
155
201
             Arena&) const override {
156
201
        this->data(place).add(columns[0], columns[1], row_num);
157
201
    }
158
159
76
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE5resetEPc
Line
Count
Source
159
38
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE5resetEPc
Line
Count
Source
159
38
    void reset(AggregateDataPtr __restrict place) const override { this->data(place).reset(); }
160
161
    void merge(AggregateDataPtr __restrict place, ConstAggregateDataPtr rhs,
162
264
               Arena&) const override {
163
264
        this->data(place).merge(this->data(rhs));
164
264
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE5mergeEPcPKcRNS_5ArenaE
Line
Count
Source
162
129
               Arena&) const override {
163
129
        this->data(place).merge(this->data(rhs));
164
129
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE5mergeEPcPKcRNS_5ArenaE
Line
Count
Source
162
135
               Arena&) const override {
163
135
        this->data(place).merge(this->data(rhs));
164
135
    }
165
166
342
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
167
342
        this->data(place).write(buf);
168
342
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE9serializeEPKcRNS_14BufferWritableE
Line
Count
Source
166
168
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
167
168
        this->data(place).write(buf);
168
168
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE9serializeEPKcRNS_14BufferWritableE
Line
Count
Source
166
174
    void serialize(ConstAggregateDataPtr __restrict place, BufferWritable& buf) const override {
167
174
        this->data(place).write(buf);
168
174
    }
169
170
    void deserialize(AggregateDataPtr __restrict place, BufferReadable& buf,
171
264
                     Arena&) const override {
172
264
        this->data(place).read(buf);
173
264
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE11deserializeEPcRNS_14BufferReadableERNS_5ArenaE
Line
Count
Source
171
129
                     Arena&) const override {
172
129
        this->data(place).read(buf);
173
129
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE11deserializeEPcRNS_14BufferReadableERNS_5ArenaE
Line
Count
Source
171
135
                     Arena&) const override {
172
135
        this->data(place).read(buf);
173
135
    }
174
175
122
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
176
122
        this->data(place).insert_result_into(to);
177
122
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_8SampDataILNS_13PrimitiveTypeE9EEEE18insert_result_intoEPKcRNS_7IColumnE
Line
Count
Source
175
59
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
176
59
        this->data(place).insert_result_into(to);
177
59
    }
_ZNK5doris31AggregateFunctionSampCovarianceINS_7PopDataILNS_13PrimitiveTypeE9EEEE18insert_result_intoEPKcRNS_7IColumnE
Line
Count
Source
175
63
    void insert_result_into(ConstAggregateDataPtr __restrict place, IColumn& to) const override {
176
63
        this->data(place).insert_result_into(to);
177
63
    }
178
};
179
180
#include "common/compile_check_end.h"
181
} // namespace doris