Coverage Report

Created: 2026-04-01 15:21

next uncovered line (L), next uncovered region (R), next uncovered branch (B)
be/src/runtime/aws_msk_iam_auth.cpp
Line
Count
Source
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "runtime/aws_msk_iam_auth.h"
19
20
#include <aws/core/auth/AWSCredentials.h>
21
#include <aws/core/auth/AWSCredentialsProvider.h>
22
#include <aws/core/auth/AWSCredentialsProviderChain.h>
23
#include <aws/core/auth/STSCredentialsProvider.h>
24
#include <aws/core/platform/Environment.h>
25
#include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h>
26
#include <aws/sts/STSClient.h>
27
#include <aws/sts/model/AssumeRoleRequest.h>
28
#include <openssl/hmac.h>
29
#include <openssl/sha.h>
30
31
#include <algorithm>
32
#include <chrono>
33
#include <iomanip>
34
#include <sstream>
35
36
#include "common/logging.h"
37
38
namespace doris {
39
40
16
AwsMskIamAuth::AwsMskIamAuth(Config config) : _config(std::move(config)) {
41
16
    _credentials_provider = _create_credentials_provider();
42
16
}
43
44
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_provider_from_type(
45
2
        const std::string& provider_type) {
46
2
    std::string provider_upper = provider_type;
47
2
    std::transform(provider_upper.begin(), provider_upper.end(), provider_upper.begin(), ::toupper);
48
49
2
    if (provider_upper == "ENV" || provider_upper == "ENVIRONMENT") {
50
1
        return std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>();
51
1
    } else if (provider_upper == "INSTANCE_PROFILE" || provider_upper == "INSTANCEPROFILE") {
52
1
        return std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>();
53
1
    } else if (provider_upper == "CONTAINER" || provider_upper == "ECS") {
54
0
        return std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>(
55
0
                Aws::Environment::GetEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").c_str());
56
0
    } else if (provider_upper == "SYSTEM_PROPERTIES" || provider_upper == "SYSTEMPROPERTIES") {
57
0
        return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>();
58
0
    } else if (provider_upper == "WEB_IDENTITY" || provider_upper == "WEBIDENTITY" ||
59
0
               provider_upper == "WEB_IDENTITY_TOKEN_FILE") {
60
0
        return std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>();
61
0
    } else if (provider_upper == "ANONYMOUS") {
62
0
        return std::make_shared<Aws::Auth::AnonymousAWSCredentialsProvider>();
63
0
    } else if (provider_upper.empty() || provider_upper == "DEFAULT") {
64
0
        return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
65
0
    }
66
67
2
    LOG(WARNING) << "Unknown credentials provider type: " << provider_type
68
0
                 << ", falling back to default credentials provider chain";
69
0
    return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
70
2
}
71
72
std::shared_ptr<Aws::Auth::AWSCredentialsProvider>
73
6
AwsMskIamAuth::_create_assume_role_base_provider() {
74
6
    if (!_config.access_key.empty() && !_config.secret_key.empty()) {
75
1
        Aws::Auth::AWSCredentials base_credentials(_config.access_key, _config.secret_key);
76
1
        return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(base_credentials);
77
1
    }
78
79
5
    if (!_config.profile_name.empty()) {
80
1
        return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>(
81
1
                _config.profile_name.c_str());
82
1
    }
83
84
4
    if (!_config.credentials_provider.empty()) {
85
1
        return _create_provider_from_type(_config.credentials_provider);
86
1
    }
87
88
3
    return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
89
4
}
90
91
16
std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_credentials_provider() {
92
16
    if (!_config.role_arn.empty()) {
93
6
        Aws::Client::ClientConfiguration client_config;
94
6
        if (!_config.region.empty()) {
95
6
            client_config.region = _config.region;
96
6
        }
97
98
6
        auto base_provider = _create_assume_role_base_provider();
99
6
        LOG(INFO) << "Using AWS STS Assume Role: " << _config.role_arn;
100
101
6
        auto sts_client = std::make_shared<Aws::STS::STSClient>(base_provider, client_config);
102
103
6
        Aws::String external_id = _config.external_id.empty()
104
6
                                          ? Aws::String()
105
6
                                          : Aws::String(_config.external_id.c_str());
106
6
        return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>(
107
6
                _config.role_arn, Aws::String(), external_id,
108
6
                Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, sts_client);
109
6
    }
110
    // 2. Explicit AK/SK credentials (direct access)
111
10
    if (!_config.access_key.empty() && !_config.secret_key.empty()) {
112
1
        LOG(INFO) << "Using explicit AWS credentials (Access Key ID: "
113
1
                  << _config.access_key.substr(0, 4) << "****)";
114
115
1
        Aws::Auth::AWSCredentials credentials(_config.access_key, _config.secret_key);
116
117
1
        return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials);
118
1
    }
119
    // 3. AWS Profile (reads from ~/.aws/credentials)
120
9
    if (!_config.profile_name.empty()) {
121
1
        LOG(INFO) << "Using AWS Profile: " << _config.profile_name;
122
123
1
        return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>(
124
1
                _config.profile_name.c_str());
125
1
    }
126
    // 4. Custom Credentials Provider
127
8
    if (!_config.credentials_provider.empty()) {
128
1
        LOG(INFO) << "Using custom credentials provider: " << _config.credentials_provider;
129
1
        return _create_provider_from_type(_config.credentials_provider);
130
1
    }
131
    // No valid credentials configuration found
132
8
    LOG(ERROR) << "AWS MSK IAM authentication requires credentials. Please provide.";
133
7
    return nullptr;
134
8
}
135
136
5
Status AwsMskIamAuth::get_credentials(Aws::Auth::AWSCredentials* credentials) {
137
5
    std::lock_guard<std::mutex> lock(_mutex);
138
139
5
    if (!_credentials_provider) {
140
5
        return Status::InternalError("AWS credentials provider not initialized");
141
5
    }
142
143
    // Refresh if needed
144
0
    if (_should_refresh_credentials()) {
145
0
        _cached_credentials = _credentials_provider->GetAWSCredentials();
146
0
        if (_cached_credentials.GetAWSAccessKeyId().empty()) {
147
0
            return Status::InternalError("Failed to get AWS credentials");
148
0
        }
149
150
        // Set expiry time (assume 1 hour for instance profile, or use the credentials expiration)
151
0
        _credentials_expiry = std::chrono::system_clock::now() + std::chrono::hours(1);
152
153
0
        LOG(INFO) << "Refreshed AWS credentials for MSK IAM authentication";
154
0
    }
155
156
0
    *credentials = _cached_credentials;
157
0
    return Status::OK();
158
0
}
159
160
0
bool AwsMskIamAuth::_should_refresh_credentials() {
161
0
    auto now = std::chrono::system_clock::now();
162
0
    auto refresh_time =
163
0
            _credentials_expiry - std::chrono::milliseconds(_config.token_refresh_margin_ms);
164
0
    return now >= refresh_time || _cached_credentials.GetAWSAccessKeyId().empty();
165
0
}
166
167
Status AwsMskIamAuth::generate_token(const std::string& broker_hostname, std::string* token,
168
5
                                     int64_t* token_lifetime_ms) {
169
5
    Aws::Auth::AWSCredentials credentials;
170
5
    RETURN_IF_ERROR(get_credentials(&credentials));
171
172
0
    std::string timestamp = _get_timestamp();
173
0
    std::string date_stamp = _get_date_stamp(timestamp);
174
175
    // AWS MSK IAM token is a base64-encoded presigned URL
176
    // Reference: https://github.com/aws/aws-msk-iam-sasl-signer-python
177
178
    // Token expiry in seconds (900 seconds = 15 minutes, matching AWS MSK IAM signer reference)
179
0
    static constexpr int TOKEN_EXPIRY_SECONDS = 900;
180
181
    // Build the endpoint URL
182
0
    std::string endpoint_url = "https://kafka." + _config.region + ".amazonaws.com/";
183
184
    // Build credential scope
185
0
    std::string credential_scope =
186
0
            date_stamp + "/" + _config.region + "/kafka-cluster/aws4_request";
187
188
    // Build the canonical query string (sorted alphabetically)
189
    // IMPORTANT: All query parameters must be included in the signature calculation
190
    // Session Token must be in canonical query string if using temporary credentials
191
0
    std::stringstream canonical_query_ss;
192
0
    canonical_query_ss << "Action=kafka-cluster%3AConnect"; // URL-encoded :
193
194
    // Add Algorithm
195
0
    canonical_query_ss << "&X-Amz-Algorithm=AWS4-HMAC-SHA256";
196
197
    // Add Credential
198
0
    std::string credential = std::string(credentials.GetAWSAccessKeyId()) + "/" + credential_scope;
199
0
    canonical_query_ss << "&X-Amz-Credential=" << _url_encode(credential);
200
201
    // Add Date
202
0
    canonical_query_ss << "&X-Amz-Date=" << timestamp;
203
204
    // Add Expires
205
0
    canonical_query_ss << "&X-Amz-Expires=" << TOKEN_EXPIRY_SECONDS;
206
207
    // Add Security Token if present (MUST be before signature calculation)
208
0
    if (!credentials.GetSessionToken().empty()) {
209
0
        canonical_query_ss << "&X-Amz-Security-Token="
210
0
                           << _url_encode(std::string(credentials.GetSessionToken()));
211
0
    }
212
213
    // Add SignedHeaders
214
0
    canonical_query_ss << "&X-Amz-SignedHeaders=host";
215
216
0
    std::string canonical_query_string = canonical_query_ss.str();
217
218
    // Build the canonical headers
219
0
    std::string host = "kafka." + _config.region + ".amazonaws.com";
220
0
    std::string canonical_headers = "host:" + host + "\n";
221
0
    std::string signed_headers = "host";
222
223
    // Build the canonical request
224
0
    std::string method = "GET";
225
0
    std::string uri = "/";
226
0
    std::string payload_hash = _sha256("");
227
228
0
    std::string canonical_request = method + "\n" + uri + "\n" + canonical_query_string + "\n" +
229
0
                                    canonical_headers + "\n" + signed_headers + "\n" + payload_hash;
230
231
    // Build the string to sign
232
0
    std::string algorithm = "AWS4-HMAC-SHA256";
233
0
    std::string canonical_request_hash = _sha256(canonical_request);
234
0
    std::string string_to_sign =
235
0
            algorithm + "\n" + timestamp + "\n" + credential_scope + "\n" + canonical_request_hash;
236
237
    // Calculate signature
238
0
    std::string signing_key = _calculate_signing_key(std::string(credentials.GetAWSSecretKey()),
239
0
                                                     date_stamp, _config.region, "kafka-cluster");
240
0
    std::string signature = _hmac_sha256_hex(signing_key, string_to_sign);
241
242
    // Build the final presigned URL
243
    // All parameters are already in canonical_query_string, just add signature
244
    // Then add User-Agent AFTER signature (not part of signed content, matching reference impl)
245
0
    std::string signed_url = endpoint_url + "?" + canonical_query_string +
246
0
                             "&X-Amz-Signature=" + signature +
247
0
                             "&User-Agent=doris-msk-iam-auth%2F1.0";
248
249
    // Base64url encode the signed URL (without padding)
250
0
    *token = _base64url_encode(signed_url);
251
252
    // Token lifetime in milliseconds
253
0
    *token_lifetime_ms = TOKEN_EXPIRY_SECONDS * 1000;
254
255
0
    VLOG_DEBUG << "Generated AWS MSK IAM token for region: " << _config.region;
256
0
    return Status::OK();
257
5
}
258
259
0
std::string AwsMskIamAuth::_hmac_sha256_hex(const std::string& key, const std::string& data) {
260
0
    std::string raw = _hmac_sha256(key, data);
261
0
    std::stringstream ss;
262
0
    for (unsigned char c : raw) {
263
0
        ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c);
264
0
    }
265
0
    return ss.str();
266
0
}
267
268
0
std::string AwsMskIamAuth::_url_encode(const std::string& value) {
269
0
    std::ostringstream escaped;
270
0
    escaped.fill('0');
271
0
    escaped << std::hex;
272
273
0
    for (char c : value) {
274
        // Keep alphanumeric and other accepted characters intact
275
0
        if (isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '_' || c == '.' ||
276
0
            c == '~') {
277
0
            escaped << c;
278
0
        } else {
279
            // Any other characters are percent-encoded
280
0
            escaped << std::uppercase;
281
0
            escaped << '%' << std::setw(2) << static_cast<int>(static_cast<unsigned char>(c));
282
0
            escaped << std::nouppercase;
283
0
        }
284
0
    }
285
286
0
    return escaped.str();
287
0
}
288
289
0
std::string AwsMskIamAuth::_base64url_encode(const std::string& input) {
290
    // Standard base64 alphabet
291
0
    static const char* base64_chars =
292
0
            "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
293
294
0
    std::string result;
295
0
    result.reserve(((input.size() + 2) / 3) * 4);
296
297
0
    const unsigned char* bytes = reinterpret_cast<const unsigned char*>(input.c_str());
298
0
    size_t len = input.size();
299
300
0
    for (size_t i = 0; i < len; i += 3) {
301
0
        uint32_t n = static_cast<uint32_t>(bytes[i]) << 16;
302
0
        if (i + 1 < len) n |= static_cast<uint32_t>(bytes[i + 1]) << 8;
303
0
        if (i + 2 < len) n |= static_cast<uint32_t>(bytes[i + 2]);
304
305
0
        result += base64_chars[(n >> 18) & 0x3F];
306
0
        result += base64_chars[(n >> 12) & 0x3F];
307
0
        if (i + 1 < len) result += base64_chars[(n >> 6) & 0x3F];
308
0
        if (i + 2 < len) result += base64_chars[n & 0x3F];
309
0
    }
310
311
    // Convert to URL-safe base64 (replace + with -, / with _)
312
    // and remove padding (=)
313
0
    for (char& c : result) {
314
0
        if (c == '+')
315
0
            c = '-';
316
0
        else if (c == '/')
317
0
            c = '_';
318
0
    }
319
320
0
    return result;
321
0
}
322
323
std::string AwsMskIamAuth::_calculate_signing_key(const std::string& secret_key,
324
                                                  const std::string& date_stamp,
325
                                                  const std::string& region,
326
0
                                                  const std::string& service) {
327
0
    std::string k_secret = "AWS4" + secret_key;
328
0
    std::string k_date = _hmac_sha256(k_secret, date_stamp);
329
0
    std::string k_region = _hmac_sha256(k_date, region);
330
0
    std::string k_service = _hmac_sha256(k_region, service);
331
0
    std::string k_signing = _hmac_sha256(k_service, "aws4_request");
332
0
    return k_signing;
333
0
}
334
335
0
std::string AwsMskIamAuth::_hmac_sha256(const std::string& key, const std::string& data) {
336
0
    unsigned char digest[EVP_MAX_MD_SIZE];
337
0
    unsigned int digest_len = 0;
338
0
    HMAC(EVP_sha256(), key.c_str(), static_cast<int>(key.length()),
339
0
         reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), digest, &digest_len);
340
0
    return {reinterpret_cast<char*>(digest), digest_len};
341
0
}
342
343
0
std::string AwsMskIamAuth::_sha256(const std::string& data) {
344
0
    unsigned char hash[SHA256_DIGEST_LENGTH];
345
0
    SHA256(reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), hash);
346
347
0
    std::stringstream ss;
348
0
    for (unsigned char i : hash) {
349
0
        ss << std::hex << std::setw(2) << std::setfill('0') << (int)i;
350
0
    }
351
0
    return ss.str();
352
0
}
353
354
0
std::string AwsMskIamAuth::_get_timestamp() {
355
0
    auto now = std::chrono::system_clock::now();
356
0
    auto time_t_now = std::chrono::system_clock::to_time_t(now);
357
0
    std::tm tm_now;
358
0
    gmtime_r(&time_t_now, &tm_now);
359
360
0
    std::stringstream ss;
361
0
    ss << std::put_time(&tm_now, "%Y%m%dT%H%M%SZ");
362
0
    return ss.str();
363
0
}
364
365
0
std::string AwsMskIamAuth::_get_date_stamp(const std::string& timestamp) {
366
    // Extract YYYYMMDD from YYYYMMDDTHHMMSSz
367
0
    return timestamp.substr(0, 8);
368
0
}
369
370
// AwsMskIamOAuthCallback implementation
371
372
namespace {
373
// Property keys for AWS MSK IAM authentication
374
constexpr const char* PROP_SECURITY_PROTOCOL = "security.protocol";
375
constexpr const char* PROP_SASL_MECHANISM = "sasl.mechanism";
376
constexpr const char* PROP_AWS_REGION = "aws.region";
377
constexpr const char* PROP_AWS_ACCESS_KEY = "aws.access_key";
378
constexpr const char* PROP_AWS_SECRET_KEY = "aws.secret_key";
379
constexpr const char* PROP_AWS_ROLE_ARN = "aws.role_arn";
380
constexpr const char* PROP_AWS_EXTERNAL_ID = "aws.external_id";
381
constexpr const char* PROP_AWS_PROFILE_NAME = "aws.profile_name";
382
constexpr const char* PROP_AWS_CREDENTIALS_PROVIDER = "aws.credentials_provider";
383
} // namespace
384
385
std::unique_ptr<AwsMskIamOAuthCallback> AwsMskIamOAuthCallback::create_from_properties(
386
        const std::unordered_map<std::string, std::string>& custom_properties,
387
79
        const std::string& brokers) {
388
79
    auto security_protocol_it = custom_properties.find(PROP_SECURITY_PROTOCOL);
389
79
    auto sasl_mechanism_it = custom_properties.find(PROP_SASL_MECHANISM);
390
79
    bool is_sasl_ssl = security_protocol_it != custom_properties.end() &&
391
79
                       security_protocol_it->second == "SASL_SSL";
392
79
    bool is_oauthbearer = sasl_mechanism_it != custom_properties.end() &&
393
79
                          sasl_mechanism_it->second == "OAUTHBEARER";
394
395
79
    if (!is_sasl_ssl || !is_oauthbearer) {
396
77
        return nullptr;
397
77
    }
398
399
    // Extract broker hostname for token generation.
400
2
    std::string broker_hostname = brokers;
401
    // If there are multiple brokers, we use the first one (Refrain : is this ok?)
402
2
    if (broker_hostname.find(',') != std::string::npos) {
403
0
        broker_hostname = broker_hostname.substr(0, broker_hostname.find(','));
404
0
    }
405
    // Remove port if present
406
2
    if (broker_hostname.find(':') != std::string::npos) {
407
2
        broker_hostname = broker_hostname.substr(0, broker_hostname.find(':'));
408
2
    }
409
410
2
    AwsMskIamAuth::Config auth_config;
411
412
2
    auto region_it = custom_properties.find(PROP_AWS_REGION);
413
2
    if (region_it != custom_properties.end()) {
414
2
        auth_config.region = region_it->second;
415
2
    }
416
417
2
    auto access_key_it = custom_properties.find(PROP_AWS_ACCESS_KEY);
418
2
    auto secret_key_it = custom_properties.find(PROP_AWS_SECRET_KEY);
419
2
    if (access_key_it != custom_properties.end() && secret_key_it != custom_properties.end()) {
420
0
        auth_config.access_key = access_key_it->second;
421
0
        auth_config.secret_key = secret_key_it->second;
422
0
        LOG(INFO) << "AWS MSK IAM: using explicit credentials (region: " << auth_config.region
423
0
                  << ")";
424
0
    }
425
426
2
    auto role_arn_it = custom_properties.find(PROP_AWS_ROLE_ARN);
427
2
    if (role_arn_it != custom_properties.end()) {
428
1
        auth_config.role_arn = role_arn_it->second;
429
1
        LOG(INFO) << "AWS MSK IAM: using role " << auth_config.role_arn
430
1
                  << " (region: " << auth_config.region << ")";
431
1
    }
432
433
2
    auto external_id_it = custom_properties.find(PROP_AWS_EXTERNAL_ID);
434
2
    if (external_id_it != custom_properties.end()) {
435
2
        auth_config.external_id = external_id_it->second;
436
2
        LOG(INFO) << "AWS MSK IAM: using external id with role assumption (region: "
437
2
                  << auth_config.region << ")";
438
2
    }
439
440
2
    auto profile_name_it = custom_properties.find(PROP_AWS_PROFILE_NAME);
441
2
    if (profile_name_it != custom_properties.end()) {
442
0
        auth_config.profile_name = profile_name_it->second;
443
0
        LOG(INFO) << "AWS MSK IAM: using profile " << auth_config.profile_name
444
0
                  << " (region: " << auth_config.region << ")";
445
0
    }
446
447
2
    auto credentials_provider_it = custom_properties.find(PROP_AWS_CREDENTIALS_PROVIDER);
448
2
    if (credentials_provider_it != custom_properties.end()) {
449
0
        auth_config.credentials_provider = credentials_provider_it->second;
450
0
        LOG(INFO) << "AWS MSK IAM: using credentials provider " << auth_config.credentials_provider
451
0
                  << " (region: " << auth_config.region << ")";
452
0
    }
453
454
2
    if (!auth_config.external_id.empty() && auth_config.role_arn.empty()) {
455
1
        LOG(ERROR) << "AWS MSK IAM authentication: 'aws.external_id' requires 'aws.role_arn'";
456
1
        return nullptr;
457
1
    }
458
459
    // Validate that at least one credential source is configured
460
1
    bool has_credentials = !auth_config.access_key.empty() || !auth_config.role_arn.empty() ||
461
1
                           !auth_config.profile_name.empty() ||
462
1
                           !auth_config.credentials_provider.empty();
463
464
1
    if (!has_credentials) {
465
0
        LOG(ERROR) << "AWS MSK IAM authentication enabled but no credentials configured. "
466
0
                   << "Please provide one of: access_key/secret_key, role_arn, profile_name, or "
467
0
                      "credentials_provider";
468
0
        return nullptr;
469
0
    }
470
471
1
    LOG(INFO) << "Enabling AWS MSK IAM authentication for broker: " << broker_hostname
472
1
              << ", region: " << auth_config.region;
473
474
1
    auto auth = std::make_shared<AwsMskIamAuth>(auth_config);
475
1
    return std::make_unique<AwsMskIamOAuthCallback>(std::move(auth), std::move(broker_hostname));
476
1
}
477
478
AwsMskIamOAuthCallback::AwsMskIamOAuthCallback(std::shared_ptr<AwsMskIamAuth> auth,
479
                                               std::string broker_hostname)
480
2
        : _auth(std::move(auth)), _broker_hostname(std::move(broker_hostname)) {}
481
482
0
Status AwsMskIamOAuthCallback::refresh_now(RdKafka::Handle* handle) {
483
0
    std::string token;
484
0
    int64_t token_lifetime_ms = 0;
485
486
0
    RETURN_IF_ERROR(_auth->generate_token(_broker_hostname, &token, &token_lifetime_ms));
487
488
0
    std::string principal = "doris-consumer";
489
0
    std::list<std::string> extensions;
490
0
    std::string errstr;
491
492
0
    auto now = std::chrono::system_clock::now();
493
0
    auto now_ms =
494
0
            std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count();
495
0
    int64_t token_expiry_ms = now_ms + token_lifetime_ms;
496
497
0
    auto err = handle->oauthbearer_set_token(token, token_expiry_ms, principal, extensions, errstr);
498
0
    if (err != RdKafka::ERR_NO_ERROR) {
499
0
        return Status::InternalError("Failed to set OAuth token: {}, detail: {}",
500
0
                                     RdKafka::err2str(err), errstr);
501
0
    }
502
503
0
    LOG(INFO) << "Successfully set AWS MSK IAM OAuth token, lifetime: " << token_lifetime_ms
504
0
              << "ms";
505
0
    return Status::OK();
506
0
}
507
508
void AwsMskIamOAuthCallback::oauthbearer_token_refresh_cb(
509
0
        RdKafka::Handle* handle, const std::string& /*oauthbearer_config*/) {
510
0
    Status st = refresh_now(handle);
511
0
    if (!st.ok()) {
512
        LOG(WARNING) << st;
513
0
        handle->oauthbearer_set_token_failure(st.to_string());
514
0
    }
515
0
}
516
517
} // namespace doris