Coverage Report

Created: 2026-05-13 12:59

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/udf/python/python_udaf_client.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <arrow/record_batch.h>
21
#include <arrow/status.h>
22
23
#include "udf/python/python_client.h"
24
25
namespace doris {
26
27
// Forward declaration so the PythonUDAFClientPtr alias can be declared before the class.
class PythonUDAFClient;
28
29
// Shared-ownership handle to a PythonUDAFClient; PythonUDAFClient::create() hands new
// clients back through a PythonUDAFClientPtr* out-parameter.
using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>;
30
31
// Fixed-size (30 bytes) binary metadata header for a UDAF request. It is sent to the
// Python server as the request's app_metadata (see _send_request).
// `packed` removes padding so the in-memory layout matches the byte sizes noted on each
// field; field order and widths are therefore part of the wire format - do not reorder.
// NOTE(review): if the struct is shipped as raw bytes, multi-byte fields are in host byte
// order - confirm the Python side decodes them with the same endianness.
32
struct __attribute__((packed)) UDAFMetadata {
33
    uint32_t meta_version;   // 4 bytes: metadata version (current version = 1)
34
    uint8_t operation;       // 1 byte: PythonUDAFClient::UDAFOperation value
35
    uint8_t is_single_place; // 1 byte: boolean (0 or 1, ACCUMULATE only)
36
    int64_t place_id;        // 8 bytes: aggregate state identifier (globally unique; 0 in multi-place ACCUMULATE)
37
    int64_t row_start;       // 8 bytes: start row index (ACCUMULATE only)
38
    int64_t row_end;         // 8 bytes: end row index (exclusive, ACCUMULATE only)
39
};
40
41
// Locks in the packed layout: 4 + 1 + 1 + 8 + 8 + 8 = 30 bytes.
static_assert(sizeof(UDAFMetadata) == 30, "UDAFMetadata size must be 30 bytes");
42
43
// Current UDAFMetadata format version (the value expected in UDAFMetadata::meta_version)
44
inline constexpr uint32_t UDAF_METADATA_VERSION = 1; // `inline`: one definition shared by every TU including this header
45
46
/**
47
 * Python UDAF Client
48
 *
49
 * Implements Snowflake-style UDAF pattern with the following methods:
50
 * - __init__(): Initialize aggregate state
51
 * - aggregate_state: Property that returns internal state
52
 * - accumulate(input): Add new input to aggregate state
53
 * - merge(other_state): Combine two intermediate states
54
 * - finish(): Generate final result from aggregate state
55
 *
56
 * Communication protocol with Python server (one request per operation; the
 * operation code travels in the UDAFMetadata sent as app_metadata):
57
 * 1. CREATE: Initialize UDAF class instance and get initial state
58
 * 2. ACCUMULATE: Send input data batch and get updated states
59
 * 3. SERIALIZE: Get serialized state for shuffle/merge
60
 * 4. MERGE: Combine serialized states
61
 * 5. FINALIZE: Get final result from state
62
 * 6. RESET: Reset state to initial value
63
 * 7. DESTROY: Clean up resources
 *
 * The list above is in lifecycle order; the codes actually sent are the
 * UDAFOperation values (CREATE = 0 ... DESTROY = 6), not these list numbers.
 * A client tracks at most one created place (_created_place_id), which
 * close() destroys; the destructor calls close().
64
 */
65
class PythonUDAFClient : public PythonClient {
66
public:
67
    // UDAF operation codes; sent as the 1-byte UDAFMetadata::operation field.
68
    enum class UDAFOperation : uint8_t {
69
        CREATE = 0,     // Create new aggregate state
70
        ACCUMULATE = 1, // Add input rows to state
71
        SERIALIZE = 2,  // Serialize state for shuffle
72
        MERGE = 3,      // Merge two states
73
        FINALIZE = 4,   // Get final result
74
        RESET = 5,      // Reset state
75
        DESTROY = 6     // Destroy state
76
    };
77
78
3.46k
    PythonUDAFClient() = default;
79
80
3.46k
    ~PythonUDAFClient() override {
81
        // Best-effort cleanup of the tracked place via close(). A destructor
        // cannot report failure, so a non-OK status is only logged.
82
3.46k
        auto st = close();
83
3.46k
        if (!st.ok()) {
84
            LOG(WARNING) << "Failed to close PythonUDAFClient in destructor: " << st.to_string();
85
0
        }
86
3.46k
    }
87
88
    /**
     * Factory: builds a PythonUDAFClient bound to the given process and data schema.
     * @param func_meta Function metadata
     * @param process Python process handle
     * @param data_schema Arrow schema for UDAF data
     * @param client Output: the new client (NOTE(review): presumably already
     *               initialized via init(); confirm in the .cpp)
     * @return Status
     */
    static Status create(const PythonUDFMeta& func_meta, ProcessPtr process,
89
                         const std::shared_ptr<arrow::Schema>& data_schema,
90
                         PythonUDAFClientPtr* client);
91
92
    /**
93
     * Initialize UDAF client with data schema
94
     * Sets _schema before running the common client initialization.
     * NOTE(review): not marked `override` and takes an extra data_schema
     * argument; if PythonClient::init has a different signature, this hides
     * it rather than overriding it.
95
     * @param func_meta Function metadata
96
     * @param process Python process handle
97
     * @param data_schema Arrow schema for UDAF data
98
     * @return Status
99
     */
100
    Status init(const PythonUDFMeta& func_meta, ProcessPtr process,
101
                const std::shared_ptr<arrow::Schema>& data_schema);
102
103
    /**
104
     * Create aggregate state for a place
105
     * @param place_id Unique identifier for the aggregate state
106
     * @return Status
107
     */
108
    Status create(int64_t place_id);
109
110
    /**
111
     * Accumulate input data into aggregate state
112
     *
113
     * For single-place mode (is_single_place=true):
114
     *   - input RecordBatch contains only data columns
115
     *   - All rows are accumulated to the same place_id
116
     *
117
     * For multi-place mode (is_single_place=false):
118
     *   - input RecordBatch MUST contain a "places" column (int64) as the last column
119
     *   - The "places" column indicates which place each row belongs to
120
     *   - place_id parameter is ignored (set to 0 by convention)
121
     *
     * The rows processed are the half-open range [row_start, row_end).
122
     * @param place_id Aggregate state identifier (used only in single-place mode)
123
     * @param is_single_place Whether all rows go to single place
124
     * @param input Input data batch (must contain "places" column if is_single_place=false)
125
     * @param row_start Start row index (inclusive)
126
     * @param row_end End row index (exclusive)
127
     * @return Status
128
     */
129
    Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input,
130
                      int64_t row_start, int64_t row_end);
131
132
    /**
133
     * Serialize aggregate state for shuffle/merge
134
     * @param place_id Aggregate state identifier
135
     * @param serialized_state Output serialized state
136
     * @return Status
137
     */
138
    Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state);
139
140
    /**
141
     * Merge another serialized state into current state
142
     * @param place_id Target aggregate state identifier
143
     * @param serialized_state Serialized state to merge (e.g. produced by serialize())
144
     * @return Status
145
     */
146
    Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state);
147
148
    /**
149
     * Get final result from aggregate state
150
     * @param place_id Aggregate state identifier
151
     * @param output Output result
152
     * @return Status
153
     */
154
    Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output);
155
156
    /**
157
     * Reset aggregate state to initial value
158
     * @param place_id Aggregate state identifier
159
     * @return Status
160
     */
161
    Status reset(int64_t place_id);
162
163
    /**
164
     * Destroy aggregate state and free resources
165
     * @param place_id Aggregate state identifier
166
     * @return Status
167
     */
168
    Status destroy(int64_t place_id);
169
170
    /**
171
     * Close client connection and cleanup
172
     * Destroys the tracked place (if any) first; also called by the destructor.
     * NOTE(review): not marked `override`; confirm whether PythonClient::close
     * is virtual (if so, add `override`; if not, this hides it).
173
     * @return Status
174
     */
175
    Status close();
176
177
#ifdef BE_TEST
178
    // BE_TEST-only hook; presumably forwards to make_udaf_failure_status().
    static Status make_udaf_failure_status_for_test(
179
            const std::shared_ptr<arrow::RecordBatch>& response, const char* operation,
180
            int64_t place_id);
181
#endif
182
183
private:
184
    // Non-copyable: a copy would duplicate _created_place_id, and both copies'
    // close() would then try to destroy the same place.
    DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient);
185
186
    // Builds the error Status reported when `operation` fails for `place_id`.
    // NOTE(review): presumably derives the message from `response`; see the .cpp.
    static Status make_udaf_failure_status(const std::shared_ptr<arrow::RecordBatch>& response,
187
                                           const char* operation, int64_t place_id);
188
189
    /**
190
     * Send RecordBatch request to Python server with app_metadata
191
     * @param metadata UDAFMetadata structure (will be sent as app_metadata)
192
     * @param request_batch Request RecordBatch (contains data columns + binary_data column)
193
     * @param response_batch Output RecordBatch
194
     * @return Status
195
     */
196
    Status _send_request(const UDAFMetadata& metadata,
197
                         const std::shared_ptr<arrow::RecordBatch>& request_batch,
198
                         std::shared_ptr<arrow::RecordBatch>* response_batch);
199
200
    /**
201
     * Create request batch with data columns (for ACCUMULATE)
202
     * Appends NULL binary_data column to input data batch
203
     */
204
    Status _create_data_request_batch(const arrow::RecordBatch& input_data,
205
                                      std::shared_ptr<arrow::RecordBatch>* out);
206
207
    /**
208
     * Create request batch with binary data (for MERGE)
209
     * Creates NULL data columns + binary_data column
210
     */
211
    Status _create_binary_request_batch(const std::shared_ptr<arrow::Buffer>& binary_data,
212
                                        std::shared_ptr<arrow::RecordBatch>* out);
213
214
    /**
215
     * Get or create empty request batch (for CREATE/SERIALIZE/FINALIZE/RESET/DESTROY)
216
     * All columns are NULL. Cached after first creation for reuse.
217
     */
218
    Status _get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out);
219
220
    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
221
    std::shared_ptr<arrow::Schema> _schema;
222
    // Presumably the cache behind _get_empty_request_batch() (all-NULL request batch).
    std::shared_ptr<arrow::RecordBatch> _empty_request_batch;
223
    // Place tracked for cleanup (close() destroys it); at most one place per client.
    // NOTE(review): <optional> and <mutex> are not included directly by this header;
    // they arrive transitively (presumably via python_client.h) - consider including them.
224
    std::optional<int64_t> _created_place_id;
225
    // Serializes operations on the shared gRPC stream.
226
    // gRPC ClientReaderWriter does NOT support concurrent Write() calls, and
227
    // several pipeline tasks may drive this client concurrently
228
    // (e.g., normal accumulate() + cleanup destroy() during task finalization).
    // std::mutex is not recursive: code already holding it must not call another
    // method that locks it again on the same thread, or it will deadlock.
229
    mutable std::mutex _operation_mutex;
230
};
231
232
} // namespace doris