Coverage Report

Created: 2026-03-13 05:13

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/udf/python/python_udaf_client.h
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#pragma once
19
20
#include <arrow/status.h>
21
22
#include "udf/python/python_client.h"
23
24
namespace doris {
25
26
class PythonUDAFClient;
27
28
using PythonUDAFClientPtr = std::shared_ptr<PythonUDAFClient>;
29
30
// Fixed-size binary metadata header for UDAF requests.
//
// The struct is packed (no padding between members), so its size is exactly
// 30 bytes: 4 + 1 + 1 + 8 + 8 + 8. The static_assert below locks that in.
//
// Every member carries a zero default so that a default-constructed header
// never contains indeterminate bytes; callers overwrite the fields that are
// relevant to the operation they are sending.
struct __attribute__((packed)) UDAFMetadata {
    uint32_t meta_version = 0;   // 4 bytes: metadata version (current version = 1)
    uint8_t operation = 0;       // 1 byte: UDAFOperation enum
    uint8_t is_single_place = 0; // 1 byte: boolean (0 or 1, ACCUMULATE only)
    int64_t place_id = 0;        // 8 bytes: aggregate state identifier (globally unique)
    int64_t row_start = 0;       // 8 bytes: start row index (ACCUMULATE only)
    int64_t row_end = 0;         // 8 bytes: end row index (exclusive, ACCUMULATE only)
};

static_assert(sizeof(UDAFMetadata) == 30, "UDAFMetadata size must be 30 bytes");
41
42
// Current metadata version constant.
// Declared `inline` so that every translation unit including this header
// refers to one shared definition instead of a per-TU internal-linkage copy.
inline constexpr uint32_t UDAF_METADATA_VERSION = 1;
44
45
/**
46
 * Python UDAF Client
47
 * 
48
 * Implements Snowflake-style UDAF pattern with the following methods:
49
 * - __init__(): Initialize aggregate state
50
 * - aggregate_state: Property that returns internal state
51
 * - accumulate(input): Add new input to aggregate state
52
 * - merge(other_state): Combine two intermediate states
53
 * - finish(): Generate final result from aggregate state
54
 * 
55
 * Communication protocol with Python server:
56
 * 1. CREATE: Initialize UDAF class instance and get initial state
57
 * 2. ACCUMULATE: Send input data batch and get updated states
58
 * 3. SERIALIZE: Get serialized state for shuffle/merge
59
 * 4. MERGE: Combine serialized states
60
 * 5. FINALIZE: Get final result from state
61
 * 6. RESET: Reset state to initial value
62
 * 7. DESTROY: Clean up resources
63
 */
64
class PythonUDAFClient : public PythonClient {
65
public:
66
    // UDAF operation types
67
    enum class UDAFOperation : uint8_t {
68
        CREATE = 0,     // Create new aggregate state
69
        ACCUMULATE = 1, // Add input rows to state
70
        SERIALIZE = 2,  // Serialize state for shuffle
71
        MERGE = 3,      // Merge two states
72
        FINALIZE = 4,   // Get final result
73
        RESET = 5,      // Reset state
74
        DESTROY = 6     // Destroy state
75
    };
76
77
3.11k
    PythonUDAFClient() = default;
78
79
3.10k
    ~PythonUDAFClient() override {
80
        // Clean up all remaining states on destruction
81
3.10k
        auto st = close();
82
3.10k
        if (!st.ok()) {
83
            LOG(WARNING) << "Failed to close PythonUDAFClient in destructor: " << st.to_string();
84
0
        }
85
3.10k
    }
86
87
    static Status create(const PythonUDFMeta& func_meta, ProcessPtr process,
88
                         const std::shared_ptr<arrow::Schema>& data_schema,
89
                         PythonUDAFClientPtr* client);
90
91
    /**
92
     * Initialize UDAF client with data schema
93
     * Overrides base class to set _schema before initialization
94
     * @param func_meta Function metadata
95
     * @param process Python process handle
96
     * @param data_schema Arrow schema for UDAF data
97
     * @return Status
98
     */
99
    Status init(const PythonUDFMeta& func_meta, ProcessPtr process,
100
                const std::shared_ptr<arrow::Schema>& data_schema);
101
102
    /**
103
     * Create aggregate state for a place
104
     * @param place_id Unique identifier for the aggregate state
105
     * @return Status
106
     */
107
    Status create(int64_t place_id);
108
109
    /**
110
     * Accumulate input data into aggregate state
111
     * 
112
     * For single-place mode (is_single_place=true):
113
     *   - input RecordBatch contains only data columns
114
     *   - All rows are accumulated to the same place_id
115
     * 
116
     * For multi-place mode (is_single_place=false):
117
     *   - input RecordBatch MUST contain a "places" column (int64) as the last column
118
     *   - The "places" column indicates which place each row belongs to
119
     *   - place_id parameter is ignored (set to 0 by convention)
120
     * 
121
     * @param place_id Aggregate state identifier (used only in single-place mode)
122
     * @param is_single_place Whether all rows go to single place
123
     * @param input Input data batch (must contain "places" column if is_single_place=false)
124
     * @param row_start Start row index
125
     * @param row_end End row index (exclusive)
126
     * @return Status
127
     */
128
    Status accumulate(int64_t place_id, bool is_single_place, const arrow::RecordBatch& input,
129
                      int64_t row_start, int64_t row_end);
130
131
    /**
132
     * Serialize aggregate state for shuffle/merge
133
     * @param place_id Aggregate state identifier
134
     * @param serialized_state Output serialized state
135
     * @return Status
136
     */
137
    Status serialize(int64_t place_id, std::shared_ptr<arrow::Buffer>* serialized_state);
138
139
    /**
140
     * Merge another serialized state into current state
141
     * @param place_id Target aggregate state identifier
142
     * @param serialized_state Serialized state to merge
143
     * @return Status
144
     */
145
    Status merge(int64_t place_id, const std::shared_ptr<arrow::Buffer>& serialized_state);
146
147
    /**
148
     * Get final result from aggregate state
149
     * @param place_id Aggregate state identifier
150
     * @param output Output result
151
     * @return Status
152
     */
153
    Status finalize(int64_t place_id, std::shared_ptr<arrow::RecordBatch>* output);
154
155
    /**
156
     * Reset aggregate state to initial value
157
     * @param place_id Aggregate state identifier
158
     * @return Status
159
     */
160
    Status reset(int64_t place_id);
161
162
    /**
163
     * Destroy aggregate state and free resources
164
     * @param place_id Aggregate state identifier
165
     * @return Status
166
     */
167
    Status destroy(int64_t place_id);
168
169
    /**
170
     * Close client connection and cleanup
171
     * Overrides base class to destroy the tracked place first
172
     * @return Status
173
     */
174
    Status close();
175
176
private:
177
    DISALLOW_COPY_AND_ASSIGN(PythonUDAFClient);
178
179
    /**
180
     * Send RecordBatch request to Python server with app_metadata
181
     * @param metadata UDAFMetadata structure (will be sent as app_metadata)
182
     * @param request_batch Request RecordBatch (contains data columns + binary_data column)
183
     * @param response_batch Output RecordBatch
184
     * @return Status
185
     */
186
    Status _send_request(const UDAFMetadata& metadata,
187
                         const std::shared_ptr<arrow::RecordBatch>& request_batch,
188
                         std::shared_ptr<arrow::RecordBatch>* response_batch);
189
190
    /**
191
     * Create request batch with data columns (for ACCUMULATE)
192
     * Appends NULL binary_data column to input data batch
193
     */
194
    Status _create_data_request_batch(const arrow::RecordBatch& input_data,
195
                                      std::shared_ptr<arrow::RecordBatch>* out);
196
197
    /**
198
     * Create request batch with binary data (for MERGE)
199
     * Creates NULL data columns + binary_data column
200
     */
201
    Status _create_binary_request_batch(const std::shared_ptr<arrow::Buffer>& binary_data,
202
                                        std::shared_ptr<arrow::RecordBatch>* out);
203
204
    /**
205
     * Get or create empty request batch (for CREATE/SERIALIZE/FINALIZE/RESET/DESTROY)
206
     * All columns are NULL. Cached after first creation for reuse.
207
     */
208
    Status _get_empty_request_batch(std::shared_ptr<arrow::RecordBatch>* out);
209
210
    // Arrow Flight schema: [argument_types..., places: int64, binary_data: binary]
211
    std::shared_ptr<arrow::Schema> _schema;
212
    std::shared_ptr<arrow::RecordBatch> _empty_request_batch;
213
    // Track created state for cleanup
214
    std::optional<int64_t> _created_place_id;
215
    // Thread safety: protect gRPC stream operations
216
    // CRITICAL: gRPC ClientReaderWriter does NOT support concurrent Write() calls
217
    // Even within same thread, multiple pipeline tasks may trigger concurrent operations
218
    // (e.g., normal accumulate() + cleanup destroy() during task finalization)
219
    mutable std::mutex _operation_mutex;
220
};
221
222
} // namespace doris