be/src/runtime/aws_msk_iam_auth.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "runtime/aws_msk_iam_auth.h" |
19 | | |
20 | | #include <aws/core/auth/AWSCredentials.h> |
21 | | #include <aws/core/auth/AWSCredentialsProvider.h> |
22 | | #include <aws/core/auth/AWSCredentialsProviderChain.h> |
23 | | #include <aws/core/auth/STSCredentialsProvider.h> |
24 | | #include <aws/core/platform/Environment.h> |
25 | | #include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h> |
26 | | #include <aws/sts/STSClient.h> |
27 | | #include <aws/sts/model/AssumeRoleRequest.h> |
28 | | #include <openssl/hmac.h> |
29 | | #include <openssl/sha.h> |
30 | | |
31 | | #include <algorithm> |
32 | | #include <chrono> |
33 | | #include <iomanip> |
34 | | #include <sstream> |
35 | | |
36 | | #include "common/logging.h" |
37 | | |
38 | | namespace doris { |
39 | | |
40 | 12 | AwsMskIamAuth::AwsMskIamAuth(Config config) : _config(std::move(config)) { |
41 | 12 | _credentials_provider = _create_credentials_provider(); |
42 | 12 | } |
43 | | |
44 | 12 | std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_credentials_provider() { |
45 | 12 | if (!_config.role_arn.empty() && !_config.access_key.empty() && !_config.secret_key.empty()) { |
46 | 1 | LOG(INFO) << "Using AWS STS Assume Role with explicit credentials (cross-account): " |
47 | 1 | << _config.role_arn << " (Access Key ID: " << _config.access_key.substr(0, 4) |
48 | 1 | << "****)"; |
49 | | |
50 | 1 | Aws::Client::ClientConfiguration client_config; |
51 | 1 | if (!_config.region.empty()) { |
52 | 1 | client_config.region = _config.region; |
53 | 1 | } |
54 | | |
55 | | // Use explicit AK/SK as base credentials to assume the role |
56 | 1 | Aws::Auth::AWSCredentials base_credentials(_config.access_key, _config.secret_key); |
57 | 1 | auto base_provider = |
58 | 1 | std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(base_credentials); |
59 | | |
60 | 1 | auto sts_client = std::make_shared<Aws::STS::STSClient>(base_provider, client_config); |
61 | | |
62 | 1 | return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>( |
63 | 1 | _config.role_arn, Aws::String(), /* external_id */ Aws::String(), |
64 | 1 | Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, sts_client); |
65 | 1 | } |
66 | | // 2. Explicit AK/SK credentials (direct access) |
67 | 11 | if (!_config.access_key.empty() && !_config.secret_key.empty()) { |
68 | 1 | LOG(INFO) << "Using explicit AWS credentials (Access Key ID: " |
69 | 1 | << _config.access_key.substr(0, 4) << "****)"; |
70 | | |
71 | 1 | Aws::Auth::AWSCredentials credentials(_config.access_key, _config.secret_key); |
72 | | |
73 | 1 | return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials); |
74 | 1 | } |
75 | | // 3. Assume Role with Instance Profile (for same-account access from within AWS) |
76 | 10 | if (!_config.role_arn.empty()) { |
77 | 1 | LOG(INFO) << "Using AWS STS Assume Role with Instance Profile: " << _config.role_arn; |
78 | | |
79 | 1 | Aws::Client::ClientConfiguration client_config; |
80 | 1 | if (!_config.region.empty()) { |
81 | 1 | client_config.region = _config.region; |
82 | 1 | } |
83 | | |
84 | 1 | auto sts_client = std::make_shared<Aws::STS::STSClient>( |
85 | 1 | std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>(), client_config); |
86 | | |
87 | 1 | return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>( |
88 | 1 | _config.role_arn, Aws::String(), /* external_id */ Aws::String(), |
89 | 1 | Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, sts_client); |
90 | 1 | } |
91 | | // 4. AWS Profile (reads from ~/.aws/credentials) |
92 | 9 | if (!_config.profile_name.empty()) { |
93 | 1 | LOG(INFO) << "Using AWS Profile: " << _config.profile_name; |
94 | | |
95 | 1 | return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>( |
96 | 1 | _config.profile_name.c_str()); |
97 | 1 | } |
98 | | // 5. Custom Credentials Provider |
99 | 8 | if (!_config.credentials_provider.empty()) { |
100 | 1 | LOG(INFO) << "Using custom credentials provider: " << _config.credentials_provider; |
101 | | |
102 | | // Parse credentials provider type string |
103 | 1 | std::string provider_upper = _config.credentials_provider; |
104 | 1 | std::transform(provider_upper.begin(), provider_upper.end(), provider_upper.begin(), |
105 | 1 | ::toupper); |
106 | | |
107 | 1 | if (provider_upper == "ENV" || provider_upper == "ENVIRONMENT") { |
108 | 0 | return std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>(); |
109 | 1 | } else if (provider_upper == "INSTANCE_PROFILE" || provider_upper == "INSTANCEPROFILE") { |
110 | 1 | return std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>(); |
111 | 1 | } else if (provider_upper == "CONTAINER" || provider_upper == "ECS") { |
112 | 0 | return std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>( |
113 | 0 | Aws::Environment::GetEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").c_str()); |
114 | 0 | } else if (provider_upper == "DEFAULT") { |
115 | 0 | return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>(); |
116 | 0 | } else { |
117 | 0 | LOG(WARNING) << "Unknown credentials provider type: " << _config.credentials_provider |
118 | 0 | << ", falling back to default credentials provider chain"; |
119 | 0 | return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>(); |
120 | 0 | } |
121 | 1 | } |
122 | | // No valid credentials configuration found |
123 | 8 | LOG(ERROR) << "AWS MSK IAM authentication requires credentials. Please provide."; |
124 | 7 | return nullptr; |
125 | 8 | } |
126 | | |
127 | 5 | Status AwsMskIamAuth::get_credentials(Aws::Auth::AWSCredentials* credentials) { |
128 | 5 | std::lock_guard<std::mutex> lock(_mutex); |
129 | | |
130 | 5 | if (!_credentials_provider) { |
131 | 5 | return Status::InternalError("AWS credentials provider not initialized"); |
132 | 5 | } |
133 | | |
134 | | // Refresh if needed |
135 | 0 | if (_should_refresh_credentials()) { |
136 | 0 | _cached_credentials = _credentials_provider->GetAWSCredentials(); |
137 | 0 | if (_cached_credentials.GetAWSAccessKeyId().empty()) { |
138 | 0 | return Status::InternalError("Failed to get AWS credentials"); |
139 | 0 | } |
140 | | |
141 | | // Set expiry time (assume 1 hour for instance profile, or use the credentials expiration) |
142 | 0 | _credentials_expiry = std::chrono::system_clock::now() + std::chrono::hours(1); |
143 | |
|
144 | 0 | LOG(INFO) << "Refreshed AWS credentials for MSK IAM authentication"; |
145 | 0 | } |
146 | | |
147 | 0 | *credentials = _cached_credentials; |
148 | 0 | return Status::OK(); |
149 | 0 | } |
150 | | |
151 | 0 | bool AwsMskIamAuth::_should_refresh_credentials() { |
152 | 0 | auto now = std::chrono::system_clock::now(); |
153 | 0 | auto refresh_time = |
154 | 0 | _credentials_expiry - std::chrono::milliseconds(_config.token_refresh_margin_ms); |
155 | 0 | return now >= refresh_time || _cached_credentials.GetAWSAccessKeyId().empty(); |
156 | 0 | } |
157 | | |
158 | | Status AwsMskIamAuth::generate_token(const std::string& broker_hostname, std::string* token, |
159 | 5 | int64_t* token_lifetime_ms) { |
160 | 5 | Aws::Auth::AWSCredentials credentials; |
161 | 5 | RETURN_IF_ERROR(get_credentials(&credentials)); |
162 | | |
163 | 0 | std::string timestamp = _get_timestamp(); |
164 | 0 | std::string date_stamp = _get_date_stamp(timestamp); |
165 | | |
166 | | // AWS MSK IAM token is a base64-encoded presigned URL |
167 | | // Reference: https://github.com/aws/aws-msk-iam-sasl-signer-python |
168 | | |
169 | | // Token expiry in seconds (900 seconds = 15 minutes, matching AWS MSK IAM signer reference) |
170 | 0 | static constexpr int TOKEN_EXPIRY_SECONDS = 900; |
171 | | |
172 | | // Build the endpoint URL |
173 | 0 | std::string endpoint_url = "https://kafka." + _config.region + ".amazonaws.com/"; |
174 | | |
175 | | // Build credential scope |
176 | 0 | std::string credential_scope = |
177 | 0 | date_stamp + "/" + _config.region + "/kafka-cluster/aws4_request"; |
178 | | |
179 | | // Build the canonical query string (sorted alphabetically) |
180 | | // IMPORTANT: All query parameters must be included in the signature calculation |
181 | | // Session Token must be in canonical query string if using temporary credentials |
182 | 0 | std::stringstream canonical_query_ss; |
183 | 0 | canonical_query_ss << "Action=kafka-cluster%3AConnect"; // URL-encoded : |
184 | | |
185 | | // Add Algorithm |
186 | 0 | canonical_query_ss << "&X-Amz-Algorithm=AWS4-HMAC-SHA256"; |
187 | | |
188 | | // Add Credential |
189 | 0 | std::string credential = std::string(credentials.GetAWSAccessKeyId()) + "/" + credential_scope; |
190 | 0 | canonical_query_ss << "&X-Amz-Credential=" << _url_encode(credential); |
191 | | |
192 | | // Add Date |
193 | 0 | canonical_query_ss << "&X-Amz-Date=" << timestamp; |
194 | | |
195 | | // Add Expires |
196 | 0 | canonical_query_ss << "&X-Amz-Expires=" << TOKEN_EXPIRY_SECONDS; |
197 | | |
198 | | // Add Security Token if present (MUST be before signature calculation) |
199 | 0 | if (!credentials.GetSessionToken().empty()) { |
200 | 0 | canonical_query_ss << "&X-Amz-Security-Token=" |
201 | 0 | << _url_encode(std::string(credentials.GetSessionToken())); |
202 | 0 | } |
203 | | |
204 | | // Add SignedHeaders |
205 | 0 | canonical_query_ss << "&X-Amz-SignedHeaders=host"; |
206 | |
|
207 | 0 | std::string canonical_query_string = canonical_query_ss.str(); |
208 | | |
209 | | // Build the canonical headers |
210 | 0 | std::string host = "kafka." + _config.region + ".amazonaws.com"; |
211 | 0 | std::string canonical_headers = "host:" + host + "\n"; |
212 | 0 | std::string signed_headers = "host"; |
213 | | |
214 | | // Build the canonical request |
215 | 0 | std::string method = "GET"; |
216 | 0 | std::string uri = "/"; |
217 | 0 | std::string payload_hash = _sha256(""); |
218 | |
|
219 | 0 | std::string canonical_request = method + "\n" + uri + "\n" + canonical_query_string + "\n" + |
220 | 0 | canonical_headers + "\n" + signed_headers + "\n" + payload_hash; |
221 | | |
222 | | // Build the string to sign |
223 | 0 | std::string algorithm = "AWS4-HMAC-SHA256"; |
224 | 0 | std::string canonical_request_hash = _sha256(canonical_request); |
225 | 0 | std::string string_to_sign = |
226 | 0 | algorithm + "\n" + timestamp + "\n" + credential_scope + "\n" + canonical_request_hash; |
227 | | |
228 | | // Calculate signature |
229 | 0 | std::string signing_key = _calculate_signing_key(std::string(credentials.GetAWSSecretKey()), |
230 | 0 | date_stamp, _config.region, "kafka-cluster"); |
231 | 0 | std::string signature = _hmac_sha256_hex(signing_key, string_to_sign); |
232 | | |
233 | | // Build the final presigned URL |
234 | | // All parameters are already in canonical_query_string, just add signature |
235 | | // Then add User-Agent AFTER signature (not part of signed content, matching reference impl) |
236 | 0 | std::string signed_url = endpoint_url + "?" + canonical_query_string + |
237 | 0 | "&X-Amz-Signature=" + signature + |
238 | 0 | "&User-Agent=doris-msk-iam-auth%2F1.0"; |
239 | | |
240 | | // Base64url encode the signed URL (without padding) |
241 | 0 | *token = _base64url_encode(signed_url); |
242 | | |
243 | | // Token lifetime in milliseconds |
244 | 0 | *token_lifetime_ms = TOKEN_EXPIRY_SECONDS * 1000; |
245 | |
|
246 | 0 | VLOG_DEBUG << "Generated AWS MSK IAM token for region: " << _config.region; |
247 | 0 | return Status::OK(); |
248 | 5 | } |
249 | | |
250 | 0 | std::string AwsMskIamAuth::_hmac_sha256_hex(const std::string& key, const std::string& data) { |
251 | 0 | std::string raw = _hmac_sha256(key, data); |
252 | 0 | std::stringstream ss; |
253 | 0 | for (unsigned char c : raw) { |
254 | 0 | ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); |
255 | 0 | } |
256 | 0 | return ss.str(); |
257 | 0 | } |
258 | | |
259 | 0 | std::string AwsMskIamAuth::_url_encode(const std::string& value) { |
260 | 0 | std::ostringstream escaped; |
261 | 0 | escaped.fill('0'); |
262 | 0 | escaped << std::hex; |
263 | |
|
264 | 0 | for (char c : value) { |
265 | | // Keep alphanumeric and other accepted characters intact |
266 | 0 | if (isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '_' || c == '.' || |
267 | 0 | c == '~') { |
268 | 0 | escaped << c; |
269 | 0 | } else { |
270 | | // Any other characters are percent-encoded |
271 | 0 | escaped << std::uppercase; |
272 | 0 | escaped << '%' << std::setw(2) << static_cast<int>(static_cast<unsigned char>(c)); |
273 | 0 | escaped << std::nouppercase; |
274 | 0 | } |
275 | 0 | } |
276 | |
|
277 | 0 | return escaped.str(); |
278 | 0 | } |
279 | | |
280 | 0 | std::string AwsMskIamAuth::_base64url_encode(const std::string& input) { |
281 | | // Standard base64 alphabet |
282 | 0 | static const char* base64_chars = |
283 | 0 | "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; |
284 | |
|
285 | 0 | std::string result; |
286 | 0 | result.reserve(((input.size() + 2) / 3) * 4); |
287 | |
|
288 | 0 | const unsigned char* bytes = reinterpret_cast<const unsigned char*>(input.c_str()); |
289 | 0 | size_t len = input.size(); |
290 | |
|
291 | 0 | for (size_t i = 0; i < len; i += 3) { |
292 | 0 | uint32_t n = static_cast<uint32_t>(bytes[i]) << 16; |
293 | 0 | if (i + 1 < len) n |= static_cast<uint32_t>(bytes[i + 1]) << 8; |
294 | 0 | if (i + 2 < len) n |= static_cast<uint32_t>(bytes[i + 2]); |
295 | |
|
296 | 0 | result += base64_chars[(n >> 18) & 0x3F]; |
297 | 0 | result += base64_chars[(n >> 12) & 0x3F]; |
298 | 0 | if (i + 1 < len) result += base64_chars[(n >> 6) & 0x3F]; |
299 | 0 | if (i + 2 < len) result += base64_chars[n & 0x3F]; |
300 | 0 | } |
301 | | |
302 | | // Convert to URL-safe base64 (replace + with -, / with _) |
303 | | // and remove padding (=) |
304 | 0 | for (char& c : result) { |
305 | 0 | if (c == '+') |
306 | 0 | c = '-'; |
307 | 0 | else if (c == '/') |
308 | 0 | c = '_'; |
309 | 0 | } |
310 | |
|
311 | 0 | return result; |
312 | 0 | } |
313 | | |
314 | | std::string AwsMskIamAuth::_calculate_signing_key(const std::string& secret_key, |
315 | | const std::string& date_stamp, |
316 | | const std::string& region, |
317 | 0 | const std::string& service) { |
318 | 0 | std::string k_secret = "AWS4" + secret_key; |
319 | 0 | std::string k_date = _hmac_sha256(k_secret, date_stamp); |
320 | 0 | std::string k_region = _hmac_sha256(k_date, region); |
321 | 0 | std::string k_service = _hmac_sha256(k_region, service); |
322 | 0 | std::string k_signing = _hmac_sha256(k_service, "aws4_request"); |
323 | 0 | return k_signing; |
324 | 0 | } |
325 | | |
326 | 0 | std::string AwsMskIamAuth::_hmac_sha256(const std::string& key, const std::string& data) { |
327 | 0 | unsigned char digest[EVP_MAX_MD_SIZE]; |
328 | 0 | unsigned int digest_len = 0; |
329 | 0 | HMAC(EVP_sha256(), key.c_str(), static_cast<int>(key.length()), |
330 | 0 | reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), digest, &digest_len); |
331 | 0 | return {reinterpret_cast<char*>(digest), digest_len}; |
332 | 0 | } |
333 | | |
334 | 0 | std::string AwsMskIamAuth::_sha256(const std::string& data) { |
335 | 0 | unsigned char hash[SHA256_DIGEST_LENGTH]; |
336 | 0 | SHA256(reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), hash); |
337 | |
|
338 | 0 | std::stringstream ss; |
339 | 0 | for (unsigned char i : hash) { |
340 | 0 | ss << std::hex << std::setw(2) << std::setfill('0') << (int)i; |
341 | 0 | } |
342 | 0 | return ss.str(); |
343 | 0 | } |
344 | | |
345 | 0 | std::string AwsMskIamAuth::_get_timestamp() { |
346 | 0 | auto now = std::chrono::system_clock::now(); |
347 | 0 | auto time_t_now = std::chrono::system_clock::to_time_t(now); |
348 | 0 | std::tm tm_now; |
349 | 0 | gmtime_r(&time_t_now, &tm_now); |
350 | |
|
351 | 0 | std::stringstream ss; |
352 | 0 | ss << std::put_time(&tm_now, "%Y%m%dT%H%M%SZ"); |
353 | 0 | return ss.str(); |
354 | 0 | } |
355 | | |
356 | 0 | std::string AwsMskIamAuth::_get_date_stamp(const std::string& timestamp) { |
357 | | // Extract YYYYMMDD from YYYYMMDDTHHMMSSz |
358 | 0 | return timestamp.substr(0, 8); |
359 | 0 | } |
360 | | |
361 | | // AwsMskIamOAuthCallback implementation |
362 | | |
namespace {
// Keys looked up in `custom_properties` by AwsMskIamOAuthCallback::create_from_properties().

// Kafka client settings; IAM auth is enabled only for SASL_SSL + OAUTHBEARER.
constexpr char PROP_SECURITY_PROTOCOL[] = "security.protocol";
constexpr char PROP_SASL_MECHANISM[] = "sasl.mechanism";

// AWS settings copied into AwsMskIamAuth::Config.
constexpr char PROP_AWS_REGION[] = "aws.region";
constexpr char PROP_AWS_ACCESS_KEY[] = "aws.access_key";
constexpr char PROP_AWS_SECRET_KEY[] = "aws.secret_key";
constexpr char PROP_AWS_ROLE_ARN[] = "aws.role_arn";
constexpr char PROP_AWS_PROFILE_NAME[] = "aws.profile_name";
constexpr char PROP_AWS_CREDENTIALS_PROVIDER[] = "aws.credentials_provider";
} // namespace
374 | | |
375 | | std::unique_ptr<AwsMskIamOAuthCallback> AwsMskIamOAuthCallback::create_from_properties( |
376 | | const std::unordered_map<std::string, std::string>& custom_properties, |
377 | 78 | const std::string& brokers) { |
378 | 78 | auto security_protocol_it = custom_properties.find(PROP_SECURITY_PROTOCOL); |
379 | 78 | auto sasl_mechanism_it = custom_properties.find(PROP_SASL_MECHANISM); |
380 | 78 | bool is_sasl_ssl = security_protocol_it != custom_properties.end() && |
381 | 78 | security_protocol_it->second == "SASL_SSL"; |
382 | 78 | bool is_oauthbearer = sasl_mechanism_it != custom_properties.end() && |
383 | 78 | sasl_mechanism_it->second == "OAUTHBEARER"; |
384 | | |
385 | 78 | if (!is_sasl_ssl || !is_oauthbearer) { |
386 | 78 | return nullptr; |
387 | 78 | } |
388 | | |
389 | | // Extract broker hostname for token generation. |
390 | 0 | std::string broker_hostname = brokers; |
391 | | // If there are multiple brokers, we use the first one (Refrain : is this ok?) |
392 | 0 | if (broker_hostname.find(',') != std::string::npos) { |
393 | 0 | broker_hostname = broker_hostname.substr(0, broker_hostname.find(',')); |
394 | 0 | } |
395 | | // Remove port if present |
396 | 0 | if (broker_hostname.find(':') != std::string::npos) { |
397 | 0 | broker_hostname = broker_hostname.substr(0, broker_hostname.find(':')); |
398 | 0 | } |
399 | |
|
400 | 0 | AwsMskIamAuth::Config auth_config; |
401 | |
|
402 | 0 | auto region_it = custom_properties.find(PROP_AWS_REGION); |
403 | 0 | if (region_it != custom_properties.end()) { |
404 | 0 | auth_config.region = region_it->second; |
405 | 0 | } |
406 | |
|
407 | 0 | auto access_key_it = custom_properties.find(PROP_AWS_ACCESS_KEY); |
408 | 0 | auto secret_key_it = custom_properties.find(PROP_AWS_SECRET_KEY); |
409 | 0 | if (access_key_it != custom_properties.end() && secret_key_it != custom_properties.end()) { |
410 | 0 | auth_config.access_key = access_key_it->second; |
411 | 0 | auth_config.secret_key = secret_key_it->second; |
412 | 0 | LOG(INFO) << "AWS MSK IAM: using explicit credentials (region: " << auth_config.region |
413 | 0 | << ")"; |
414 | 0 | } |
415 | |
|
416 | 0 | auto role_arn_it = custom_properties.find(PROP_AWS_ROLE_ARN); |
417 | 0 | if (role_arn_it != custom_properties.end()) { |
418 | 0 | auth_config.role_arn = role_arn_it->second; |
419 | 0 | LOG(INFO) << "AWS MSK IAM: using role " << auth_config.role_arn |
420 | 0 | << " (region: " << auth_config.region << ")"; |
421 | 0 | } |
422 | |
|
423 | 0 | auto profile_name_it = custom_properties.find(PROP_AWS_PROFILE_NAME); |
424 | 0 | if (profile_name_it != custom_properties.end()) { |
425 | 0 | auth_config.profile_name = profile_name_it->second; |
426 | 0 | LOG(INFO) << "AWS MSK IAM: using profile " << auth_config.profile_name |
427 | 0 | << " (region: " << auth_config.region << ")"; |
428 | 0 | } |
429 | |
|
430 | 0 | auto credentials_provider_it = custom_properties.find(PROP_AWS_CREDENTIALS_PROVIDER); |
431 | 0 | if (credentials_provider_it != custom_properties.end()) { |
432 | 0 | auth_config.credentials_provider = credentials_provider_it->second; |
433 | 0 | LOG(INFO) << "AWS MSK IAM: using credentials provider " << auth_config.credentials_provider |
434 | 0 | << " (region: " << auth_config.region << ")"; |
435 | 0 | } |
436 | | |
437 | | // Validate that at least one credential source is configured |
438 | 0 | bool has_credentials = !auth_config.access_key.empty() || !auth_config.role_arn.empty() || |
439 | 0 | !auth_config.profile_name.empty() || |
440 | 0 | !auth_config.credentials_provider.empty(); |
441 | |
|
442 | 0 | if (!has_credentials) { |
443 | 0 | LOG(ERROR) << "AWS MSK IAM authentication enabled but no credentials configured. " |
444 | 0 | << "Please provide one of: access_key/secret_key, role_arn, profile_name, or " |
445 | 0 | "credentials_provider"; |
446 | 0 | return nullptr; |
447 | 0 | } |
448 | | |
449 | 0 | LOG(INFO) << "Enabling AWS MSK IAM authentication for broker: " << broker_hostname |
450 | 0 | << ", region: " << auth_config.region; |
451 | |
|
452 | 0 | auto auth = std::make_shared<AwsMskIamAuth>(auth_config); |
453 | 0 | return std::make_unique<AwsMskIamOAuthCallback>(std::move(auth), std::move(broker_hostname)); |
454 | 0 | } |
455 | | |
456 | | AwsMskIamOAuthCallback::AwsMskIamOAuthCallback(std::shared_ptr<AwsMskIamAuth> auth, |
457 | | std::string broker_hostname) |
458 | 1 | : _auth(std::move(auth)), _broker_hostname(std::move(broker_hostname)) {} |
459 | | |
460 | 0 | Status AwsMskIamOAuthCallback::refresh_now(RdKafka::Handle* handle) { |
461 | 0 | std::string token; |
462 | 0 | int64_t token_lifetime_ms = 0; |
463 | |
|
464 | 0 | RETURN_IF_ERROR(_auth->generate_token(_broker_hostname, &token, &token_lifetime_ms)); |
465 | | |
466 | 0 | std::string principal = "doris-consumer"; |
467 | 0 | std::list<std::string> extensions; |
468 | 0 | std::string errstr; |
469 | |
|
470 | 0 | auto now = std::chrono::system_clock::now(); |
471 | 0 | auto now_ms = |
472 | 0 | std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count(); |
473 | 0 | int64_t token_expiry_ms = now_ms + token_lifetime_ms; |
474 | |
|
475 | 0 | auto err = handle->oauthbearer_set_token(token, token_expiry_ms, principal, extensions, errstr); |
476 | 0 | if (err != RdKafka::ERR_NO_ERROR) { |
477 | 0 | return Status::InternalError("Failed to set OAuth token: {}, detail: {}", |
478 | 0 | RdKafka::err2str(err), errstr); |
479 | 0 | } |
480 | | |
481 | 0 | LOG(INFO) << "Successfully set AWS MSK IAM OAuth token, lifetime: " << token_lifetime_ms |
482 | 0 | << "ms"; |
483 | 0 | return Status::OK(); |
484 | 0 | } |
485 | | |
486 | | void AwsMskIamOAuthCallback::oauthbearer_token_refresh_cb( |
487 | 0 | RdKafka::Handle* handle, const std::string& /*oauthbearer_config*/) { |
488 | 0 | Status st = refresh_now(handle); |
489 | 0 | if (!st.ok()) { |
490 | | LOG(WARNING) << st; |
491 | 0 | handle->oauthbearer_set_token_failure(st.to_string()); |
492 | 0 | } |
493 | 0 | } |
494 | | |
495 | | } // namespace doris |