be/src/runtime/aws_msk_iam_auth.cpp
Line | Count | Source |
1 | | // Licensed to the Apache Software Foundation (ASF) under one |
2 | | // or more contributor license agreements. See the NOTICE file |
3 | | // distributed with this work for additional information |
4 | | // regarding copyright ownership. The ASF licenses this file |
5 | | // to you under the Apache License, Version 2.0 (the |
6 | | // "License"); you may not use this file except in compliance |
7 | | // with the License. You may obtain a copy of the License at |
8 | | // |
9 | | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | | // |
11 | | // Unless required by applicable law or agreed to in writing, |
12 | | // software distributed under the License is distributed on an |
13 | | // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
14 | | // KIND, either express or implied. See the License for the |
15 | | // specific language governing permissions and limitations |
16 | | // under the License. |
17 | | |
18 | | #include "runtime/aws_msk_iam_auth.h" |
19 | | |
20 | | #include <aws/core/auth/AWSCredentials.h> |
21 | | #include <aws/core/auth/AWSCredentialsProvider.h> |
22 | | #include <aws/core/auth/AWSCredentialsProviderChain.h> |
23 | | #include <aws/core/auth/STSCredentialsProvider.h> |
24 | | #include <aws/core/platform/Environment.h> |
25 | | #include <aws/identity-management/auth/STSAssumeRoleCredentialsProvider.h> |
26 | | #include <aws/sts/STSClient.h> |
27 | | #include <aws/sts/model/AssumeRoleRequest.h> |
28 | | #include <openssl/hmac.h> |
29 | | #include <openssl/sha.h> |
30 | | |
31 | | #include <algorithm> |
32 | | #include <chrono> |
33 | | #include <iomanip> |
34 | | #include <sstream> |
35 | | |
36 | | #include "common/logging.h" |
37 | | |
38 | | namespace doris { |
39 | | |
40 | 16 | AwsMskIamAuth::AwsMskIamAuth(Config config) : _config(std::move(config)) { |
41 | 16 | _credentials_provider = _create_credentials_provider(); |
42 | 16 | } |
43 | | |
44 | | std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_provider_from_type( |
45 | 2 | const std::string& provider_type) { |
46 | 2 | std::string provider_upper = provider_type; |
47 | 2 | std::transform(provider_upper.begin(), provider_upper.end(), provider_upper.begin(), ::toupper); |
48 | | |
49 | 2 | if (provider_upper == "ENV" || provider_upper == "ENVIRONMENT") { |
50 | 1 | return std::make_shared<Aws::Auth::EnvironmentAWSCredentialsProvider>(); |
51 | 1 | } else if (provider_upper == "INSTANCE_PROFILE" || provider_upper == "INSTANCEPROFILE") { |
52 | 1 | return std::make_shared<Aws::Auth::InstanceProfileCredentialsProvider>(); |
53 | 1 | } else if (provider_upper == "CONTAINER" || provider_upper == "ECS") { |
54 | 0 | return std::make_shared<Aws::Auth::TaskRoleCredentialsProvider>( |
55 | 0 | Aws::Environment::GetEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI").c_str()); |
56 | 0 | } else if (provider_upper == "SYSTEM_PROPERTIES" || provider_upper == "SYSTEMPROPERTIES") { |
57 | 0 | return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>(); |
58 | 0 | } else if (provider_upper == "WEB_IDENTITY" || provider_upper == "WEBIDENTITY" || |
59 | 0 | provider_upper == "WEB_IDENTITY_TOKEN_FILE") { |
60 | 0 | return std::make_shared<Aws::Auth::STSAssumeRoleWebIdentityCredentialsProvider>(); |
61 | 0 | } else if (provider_upper == "ANONYMOUS") { |
62 | 0 | return std::make_shared<Aws::Auth::AnonymousAWSCredentialsProvider>(); |
63 | 0 | } else if (provider_upper.empty() || provider_upper == "DEFAULT") { |
64 | 0 | return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>(); |
65 | 0 | } |
66 | | |
67 | 2 | LOG(WARNING) << "Unknown credentials provider type: " << provider_type |
68 | 0 | << ", falling back to default credentials provider chain"; |
69 | 0 | return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>(); |
70 | 2 | } |
71 | | |
72 | | std::shared_ptr<Aws::Auth::AWSCredentialsProvider> |
73 | 6 | AwsMskIamAuth::_create_assume_role_base_provider() { |
74 | 6 | if (!_config.access_key.empty() && !_config.secret_key.empty()) { |
75 | 1 | Aws::Auth::AWSCredentials base_credentials(_config.access_key, _config.secret_key); |
76 | 1 | return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(base_credentials); |
77 | 1 | } |
78 | | |
79 | 5 | if (!_config.profile_name.empty()) { |
80 | 1 | return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>( |
81 | 1 | _config.profile_name.c_str()); |
82 | 1 | } |
83 | | |
84 | 4 | if (!_config.credentials_provider.empty()) { |
85 | 1 | return _create_provider_from_type(_config.credentials_provider); |
86 | 1 | } |
87 | | |
88 | 3 | return std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>(); |
89 | 4 | } |
90 | | |
91 | 16 | std::shared_ptr<Aws::Auth::AWSCredentialsProvider> AwsMskIamAuth::_create_credentials_provider() { |
92 | 16 | if (!_config.role_arn.empty()) { |
93 | 6 | Aws::Client::ClientConfiguration client_config; |
94 | 6 | if (!_config.region.empty()) { |
95 | 6 | client_config.region = _config.region; |
96 | 6 | } |
97 | | |
98 | 6 | auto base_provider = _create_assume_role_base_provider(); |
99 | 6 | LOG(INFO) << "Using AWS STS Assume Role: " << _config.role_arn; |
100 | | |
101 | 6 | auto sts_client = std::make_shared<Aws::STS::STSClient>(base_provider, client_config); |
102 | | |
103 | 6 | Aws::String external_id = _config.external_id.empty() |
104 | 6 | ? Aws::String() |
105 | 6 | : Aws::String(_config.external_id.c_str()); |
106 | 6 | return std::make_shared<Aws::Auth::STSAssumeRoleCredentialsProvider>( |
107 | 6 | _config.role_arn, Aws::String(), external_id, |
108 | 6 | Aws::Auth::DEFAULT_CREDS_LOAD_FREQ_SECONDS, sts_client); |
109 | 6 | } |
110 | | // 2. Explicit AK/SK credentials (direct access) |
111 | 10 | if (!_config.access_key.empty() && !_config.secret_key.empty()) { |
112 | 1 | LOG(INFO) << "Using explicit AWS credentials (Access Key ID: " |
113 | 1 | << _config.access_key.substr(0, 4) << "****)"; |
114 | | |
115 | 1 | Aws::Auth::AWSCredentials credentials(_config.access_key, _config.secret_key); |
116 | | |
117 | 1 | return std::make_shared<Aws::Auth::SimpleAWSCredentialsProvider>(credentials); |
118 | 1 | } |
119 | | // 3. AWS Profile (reads from ~/.aws/credentials) |
120 | 9 | if (!_config.profile_name.empty()) { |
121 | 1 | LOG(INFO) << "Using AWS Profile: " << _config.profile_name; |
122 | | |
123 | 1 | return std::make_shared<Aws::Auth::ProfileConfigFileAWSCredentialsProvider>( |
124 | 1 | _config.profile_name.c_str()); |
125 | 1 | } |
126 | | // 4. Custom Credentials Provider |
127 | 8 | if (!_config.credentials_provider.empty()) { |
128 | 1 | LOG(INFO) << "Using custom credentials provider: " << _config.credentials_provider; |
129 | 1 | return _create_provider_from_type(_config.credentials_provider); |
130 | 1 | } |
131 | | // No valid credentials configuration found |
132 | 8 | LOG(ERROR) << "AWS MSK IAM authentication requires credentials. Please provide."; |
133 | 7 | return nullptr; |
134 | 8 | } |
135 | | |
136 | 5 | Status AwsMskIamAuth::get_credentials(Aws::Auth::AWSCredentials* credentials) { |
137 | 5 | std::lock_guard<std::mutex> lock(_mutex); |
138 | | |
139 | 5 | if (!_credentials_provider) { |
140 | 5 | return Status::InternalError("AWS credentials provider not initialized"); |
141 | 5 | } |
142 | | |
143 | | // Refresh if needed |
144 | 0 | if (_should_refresh_credentials()) { |
145 | 0 | _cached_credentials = _credentials_provider->GetAWSCredentials(); |
146 | 0 | if (_cached_credentials.GetAWSAccessKeyId().empty()) { |
147 | 0 | return Status::InternalError("Failed to get AWS credentials"); |
148 | 0 | } |
149 | | |
150 | | // Set expiry time (assume 1 hour for instance profile, or use the credentials expiration) |
151 | 0 | _credentials_expiry = std::chrono::system_clock::now() + std::chrono::hours(1); |
152 | |
|
153 | 0 | LOG(INFO) << "Refreshed AWS credentials for MSK IAM authentication"; |
154 | 0 | } |
155 | | |
156 | 0 | *credentials = _cached_credentials; |
157 | 0 | return Status::OK(); |
158 | 0 | } |
159 | | |
160 | 0 | bool AwsMskIamAuth::_should_refresh_credentials() { |
161 | 0 | auto now = std::chrono::system_clock::now(); |
162 | 0 | auto refresh_time = |
163 | 0 | _credentials_expiry - std::chrono::milliseconds(_config.token_refresh_margin_ms); |
164 | 0 | return now >= refresh_time || _cached_credentials.GetAWSAccessKeyId().empty(); |
165 | 0 | } |
166 | | |
167 | | Status AwsMskIamAuth::generate_token(const std::string& broker_hostname, std::string* token, |
168 | 5 | int64_t* token_lifetime_ms) { |
169 | 5 | Aws::Auth::AWSCredentials credentials; |
170 | 5 | RETURN_IF_ERROR(get_credentials(&credentials)); |
171 | | |
172 | 0 | std::string timestamp = _get_timestamp(); |
173 | 0 | std::string date_stamp = _get_date_stamp(timestamp); |
174 | | |
175 | | // AWS MSK IAM token is a base64-encoded presigned URL |
176 | | // Reference: https://github.com/aws/aws-msk-iam-sasl-signer-python |
177 | | |
178 | | // Token expiry in seconds (900 seconds = 15 minutes, matching AWS MSK IAM signer reference) |
179 | 0 | static constexpr int TOKEN_EXPIRY_SECONDS = 900; |
180 | | |
181 | | // Build the endpoint URL |
182 | 0 | std::string endpoint_url = "https://kafka." + _config.region + ".amazonaws.com/"; |
183 | | |
184 | | // Build credential scope |
185 | 0 | std::string credential_scope = |
186 | 0 | date_stamp + "/" + _config.region + "/kafka-cluster/aws4_request"; |
187 | | |
188 | | // Build the canonical query string (sorted alphabetically) |
189 | | // IMPORTANT: All query parameters must be included in the signature calculation |
190 | | // Session Token must be in canonical query string if using temporary credentials |
191 | 0 | std::stringstream canonical_query_ss; |
192 | 0 | canonical_query_ss << "Action=kafka-cluster%3AConnect"; // URL-encoded : |
193 | | |
194 | | // Add Algorithm |
195 | 0 | canonical_query_ss << "&X-Amz-Algorithm=AWS4-HMAC-SHA256"; |
196 | | |
197 | | // Add Credential |
198 | 0 | std::string credential = std::string(credentials.GetAWSAccessKeyId()) + "/" + credential_scope; |
199 | 0 | canonical_query_ss << "&X-Amz-Credential=" << _url_encode(credential); |
200 | | |
201 | | // Add Date |
202 | 0 | canonical_query_ss << "&X-Amz-Date=" << timestamp; |
203 | | |
204 | | // Add Expires |
205 | 0 | canonical_query_ss << "&X-Amz-Expires=" << TOKEN_EXPIRY_SECONDS; |
206 | | |
207 | | // Add Security Token if present (MUST be before signature calculation) |
208 | 0 | if (!credentials.GetSessionToken().empty()) { |
209 | 0 | canonical_query_ss << "&X-Amz-Security-Token=" |
210 | 0 | << _url_encode(std::string(credentials.GetSessionToken())); |
211 | 0 | } |
212 | | |
213 | | // Add SignedHeaders |
214 | 0 | canonical_query_ss << "&X-Amz-SignedHeaders=host"; |
215 | |
|
216 | 0 | std::string canonical_query_string = canonical_query_ss.str(); |
217 | | |
218 | | // Build the canonical headers |
219 | 0 | std::string host = "kafka." + _config.region + ".amazonaws.com"; |
220 | 0 | std::string canonical_headers = "host:" + host + "\n"; |
221 | 0 | std::string signed_headers = "host"; |
222 | | |
223 | | // Build the canonical request |
224 | 0 | std::string method = "GET"; |
225 | 0 | std::string uri = "/"; |
226 | 0 | std::string payload_hash = _sha256(""); |
227 | |
|
228 | 0 | std::string canonical_request = method + "\n" + uri + "\n" + canonical_query_string + "\n" + |
229 | 0 | canonical_headers + "\n" + signed_headers + "\n" + payload_hash; |
230 | | |
231 | | // Build the string to sign |
232 | 0 | std::string algorithm = "AWS4-HMAC-SHA256"; |
233 | 0 | std::string canonical_request_hash = _sha256(canonical_request); |
234 | 0 | std::string string_to_sign = |
235 | 0 | algorithm + "\n" + timestamp + "\n" + credential_scope + "\n" + canonical_request_hash; |
236 | | |
237 | | // Calculate signature |
238 | 0 | std::string signing_key = _calculate_signing_key(std::string(credentials.GetAWSSecretKey()), |
239 | 0 | date_stamp, _config.region, "kafka-cluster"); |
240 | 0 | std::string signature = _hmac_sha256_hex(signing_key, string_to_sign); |
241 | | |
242 | | // Build the final presigned URL |
243 | | // All parameters are already in canonical_query_string, just add signature |
244 | | // Then add User-Agent AFTER signature (not part of signed content, matching reference impl) |
245 | 0 | std::string signed_url = endpoint_url + "?" + canonical_query_string + |
246 | 0 | "&X-Amz-Signature=" + signature + |
247 | 0 | "&User-Agent=doris-msk-iam-auth%2F1.0"; |
248 | | |
249 | | // Base64url encode the signed URL (without padding) |
250 | 0 | *token = _base64url_encode(signed_url); |
251 | | |
252 | | // Token lifetime in milliseconds |
253 | 0 | *token_lifetime_ms = TOKEN_EXPIRY_SECONDS * 1000; |
254 | |
|
255 | 0 | VLOG_DEBUG << "Generated AWS MSK IAM token for region: " << _config.region; |
256 | 0 | return Status::OK(); |
257 | 5 | } |
258 | | |
259 | 0 | std::string AwsMskIamAuth::_hmac_sha256_hex(const std::string& key, const std::string& data) { |
260 | 0 | std::string raw = _hmac_sha256(key, data); |
261 | 0 | std::stringstream ss; |
262 | 0 | for (unsigned char c : raw) { |
263 | 0 | ss << std::hex << std::setw(2) << std::setfill('0') << static_cast<int>(c); |
264 | 0 | } |
265 | 0 | return ss.str(); |
266 | 0 | } |
267 | | |
268 | 0 | std::string AwsMskIamAuth::_url_encode(const std::string& value) { |
269 | 0 | std::ostringstream escaped; |
270 | 0 | escaped.fill('0'); |
271 | 0 | escaped << std::hex; |
272 | |
|
273 | 0 | for (char c : value) { |
274 | | // Keep alphanumeric and other accepted characters intact |
275 | 0 | if (isalnum(static_cast<unsigned char>(c)) || c == '-' || c == '_' || c == '.' || |
276 | 0 | c == '~') { |
277 | 0 | escaped << c; |
278 | 0 | } else { |
279 | | // Any other characters are percent-encoded |
280 | 0 | escaped << std::uppercase; |
281 | 0 | escaped << '%' << std::setw(2) << static_cast<int>(static_cast<unsigned char>(c)); |
282 | 0 | escaped << std::nouppercase; |
283 | 0 | } |
284 | 0 | } |
285 | |
|
286 | 0 | return escaped.str(); |
287 | 0 | } |
288 | | |
289 | 0 | std::string AwsMskIamAuth::_base64url_encode(const std::string& input) { |
290 | | // Standard base64 alphabet |
291 | 0 | static const char* base64_chars = |
292 | 0 | "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/"; |
293 | |
|
294 | 0 | std::string result; |
295 | 0 | result.reserve(((input.size() + 2) / 3) * 4); |
296 | |
|
297 | 0 | const unsigned char* bytes = reinterpret_cast<const unsigned char*>(input.c_str()); |
298 | 0 | size_t len = input.size(); |
299 | |
|
300 | 0 | for (size_t i = 0; i < len; i += 3) { |
301 | 0 | uint32_t n = static_cast<uint32_t>(bytes[i]) << 16; |
302 | 0 | if (i + 1 < len) n |= static_cast<uint32_t>(bytes[i + 1]) << 8; |
303 | 0 | if (i + 2 < len) n |= static_cast<uint32_t>(bytes[i + 2]); |
304 | |
|
305 | 0 | result += base64_chars[(n >> 18) & 0x3F]; |
306 | 0 | result += base64_chars[(n >> 12) & 0x3F]; |
307 | 0 | if (i + 1 < len) result += base64_chars[(n >> 6) & 0x3F]; |
308 | 0 | if (i + 2 < len) result += base64_chars[n & 0x3F]; |
309 | 0 | } |
310 | | |
311 | | // Convert to URL-safe base64 (replace + with -, / with _) |
312 | | // and remove padding (=) |
313 | 0 | for (char& c : result) { |
314 | 0 | if (c == '+') |
315 | 0 | c = '-'; |
316 | 0 | else if (c == '/') |
317 | 0 | c = '_'; |
318 | 0 | } |
319 | |
|
320 | 0 | return result; |
321 | 0 | } |
322 | | |
323 | | std::string AwsMskIamAuth::_calculate_signing_key(const std::string& secret_key, |
324 | | const std::string& date_stamp, |
325 | | const std::string& region, |
326 | 0 | const std::string& service) { |
327 | 0 | std::string k_secret = "AWS4" + secret_key; |
328 | 0 | std::string k_date = _hmac_sha256(k_secret, date_stamp); |
329 | 0 | std::string k_region = _hmac_sha256(k_date, region); |
330 | 0 | std::string k_service = _hmac_sha256(k_region, service); |
331 | 0 | std::string k_signing = _hmac_sha256(k_service, "aws4_request"); |
332 | 0 | return k_signing; |
333 | 0 | } |
334 | | |
335 | 0 | std::string AwsMskIamAuth::_hmac_sha256(const std::string& key, const std::string& data) { |
336 | 0 | unsigned char digest[EVP_MAX_MD_SIZE]; |
337 | 0 | unsigned int digest_len = 0; |
338 | 0 | HMAC(EVP_sha256(), key.c_str(), static_cast<int>(key.length()), |
339 | 0 | reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), digest, &digest_len); |
340 | 0 | return {reinterpret_cast<char*>(digest), digest_len}; |
341 | 0 | } |
342 | | |
343 | 0 | std::string AwsMskIamAuth::_sha256(const std::string& data) { |
344 | 0 | unsigned char hash[SHA256_DIGEST_LENGTH]; |
345 | 0 | SHA256(reinterpret_cast<const unsigned char*>(data.c_str()), data.length(), hash); |
346 | |
|
347 | 0 | std::stringstream ss; |
348 | 0 | for (unsigned char i : hash) { |
349 | 0 | ss << std::hex << std::setw(2) << std::setfill('0') << (int)i; |
350 | 0 | } |
351 | 0 | return ss.str(); |
352 | 0 | } |
353 | | |
354 | 0 | std::string AwsMskIamAuth::_get_timestamp() { |
355 | 0 | auto now = std::chrono::system_clock::now(); |
356 | 0 | auto time_t_now = std::chrono::system_clock::to_time_t(now); |
357 | 0 | std::tm tm_now; |
358 | 0 | gmtime_r(&time_t_now, &tm_now); |
359 | |
|
360 | 0 | std::stringstream ss; |
361 | 0 | ss << std::put_time(&tm_now, "%Y%m%dT%H%M%SZ"); |
362 | 0 | return ss.str(); |
363 | 0 | } |
364 | | |
365 | 0 | std::string AwsMskIamAuth::_get_date_stamp(const std::string& timestamp) { |
366 | | // Extract YYYYMMDD from YYYYMMDDTHHMMSSz |
367 | 0 | return timestamp.substr(0, 8); |
368 | 0 | } |
369 | | |
370 | | // AwsMskIamOAuthCallback implementation |
371 | | |
namespace {
// Keys looked up in the custom properties map when configuring AWS MSK IAM
// authentication.

// Transport / mechanism selectors.
constexpr char PROP_SECURITY_PROTOCOL[] = "security.protocol";
constexpr char PROP_SASL_MECHANISM[] = "sasl.mechanism";

// AWS settings.
constexpr char PROP_AWS_REGION[] = "aws.region";
constexpr char PROP_AWS_ACCESS_KEY[] = "aws.access_key";
constexpr char PROP_AWS_SECRET_KEY[] = "aws.secret_key";
constexpr char PROP_AWS_ROLE_ARN[] = "aws.role_arn";
constexpr char PROP_AWS_EXTERNAL_ID[] = "aws.external_id";
constexpr char PROP_AWS_PROFILE_NAME[] = "aws.profile_name";
constexpr char PROP_AWS_CREDENTIALS_PROVIDER[] = "aws.credentials_provider";
} // namespace
384 | | |
385 | | std::unique_ptr<AwsMskIamOAuthCallback> AwsMskIamOAuthCallback::create_from_properties( |
386 | | const std::unordered_map<std::string, std::string>& custom_properties, |
387 | 79 | const std::string& brokers) { |
388 | 79 | auto security_protocol_it = custom_properties.find(PROP_SECURITY_PROTOCOL); |
389 | 79 | auto sasl_mechanism_it = custom_properties.find(PROP_SASL_MECHANISM); |
390 | 79 | bool is_sasl_ssl = security_protocol_it != custom_properties.end() && |
391 | 79 | security_protocol_it->second == "SASL_SSL"; |
392 | 79 | bool is_oauthbearer = sasl_mechanism_it != custom_properties.end() && |
393 | 79 | sasl_mechanism_it->second == "OAUTHBEARER"; |
394 | | |
395 | 79 | if (!is_sasl_ssl || !is_oauthbearer) { |
396 | 77 | return nullptr; |
397 | 77 | } |
398 | | |
399 | | // Extract broker hostname for token generation. |
400 | 2 | std::string broker_hostname = brokers; |
401 | | // If there are multiple brokers, we use the first one (Refrain : is this ok?) |
402 | 2 | if (broker_hostname.find(',') != std::string::npos) { |
403 | 0 | broker_hostname = broker_hostname.substr(0, broker_hostname.find(',')); |
404 | 0 | } |
405 | | // Remove port if present |
406 | 2 | if (broker_hostname.find(':') != std::string::npos) { |
407 | 2 | broker_hostname = broker_hostname.substr(0, broker_hostname.find(':')); |
408 | 2 | } |
409 | | |
410 | 2 | AwsMskIamAuth::Config auth_config; |
411 | | |
412 | 2 | auto region_it = custom_properties.find(PROP_AWS_REGION); |
413 | 2 | if (region_it != custom_properties.end()) { |
414 | 2 | auth_config.region = region_it->second; |
415 | 2 | } |
416 | | |
417 | 2 | auto access_key_it = custom_properties.find(PROP_AWS_ACCESS_KEY); |
418 | 2 | auto secret_key_it = custom_properties.find(PROP_AWS_SECRET_KEY); |
419 | 2 | if (access_key_it != custom_properties.end() && secret_key_it != custom_properties.end()) { |
420 | 0 | auth_config.access_key = access_key_it->second; |
421 | 0 | auth_config.secret_key = secret_key_it->second; |
422 | 0 | LOG(INFO) << "AWS MSK IAM: using explicit credentials (region: " << auth_config.region |
423 | 0 | << ")"; |
424 | 0 | } |
425 | | |
426 | 2 | auto role_arn_it = custom_properties.find(PROP_AWS_ROLE_ARN); |
427 | 2 | if (role_arn_it != custom_properties.end()) { |
428 | 1 | auth_config.role_arn = role_arn_it->second; |
429 | 1 | LOG(INFO) << "AWS MSK IAM: using role " << auth_config.role_arn |
430 | 1 | << " (region: " << auth_config.region << ")"; |
431 | 1 | } |
432 | | |
433 | 2 | auto external_id_it = custom_properties.find(PROP_AWS_EXTERNAL_ID); |
434 | 2 | if (external_id_it != custom_properties.end()) { |
435 | 2 | auth_config.external_id = external_id_it->second; |
436 | 2 | LOG(INFO) << "AWS MSK IAM: using external id with role assumption (region: " |
437 | 2 | << auth_config.region << ")"; |
438 | 2 | } |
439 | | |
440 | 2 | auto profile_name_it = custom_properties.find(PROP_AWS_PROFILE_NAME); |
441 | 2 | if (profile_name_it != custom_properties.end()) { |
442 | 0 | auth_config.profile_name = profile_name_it->second; |
443 | 0 | LOG(INFO) << "AWS MSK IAM: using profile " << auth_config.profile_name |
444 | 0 | << " (region: " << auth_config.region << ")"; |
445 | 0 | } |
446 | | |
447 | 2 | auto credentials_provider_it = custom_properties.find(PROP_AWS_CREDENTIALS_PROVIDER); |
448 | 2 | if (credentials_provider_it != custom_properties.end()) { |
449 | 0 | auth_config.credentials_provider = credentials_provider_it->second; |
450 | 0 | LOG(INFO) << "AWS MSK IAM: using credentials provider " << auth_config.credentials_provider |
451 | 0 | << " (region: " << auth_config.region << ")"; |
452 | 0 | } |
453 | | |
454 | 2 | if (!auth_config.external_id.empty() && auth_config.role_arn.empty()) { |
455 | 1 | LOG(ERROR) << "AWS MSK IAM authentication: 'aws.external_id' requires 'aws.role_arn'"; |
456 | 1 | return nullptr; |
457 | 1 | } |
458 | | |
459 | | // Validate that at least one credential source is configured |
460 | 1 | bool has_credentials = !auth_config.access_key.empty() || !auth_config.role_arn.empty() || |
461 | 1 | !auth_config.profile_name.empty() || |
462 | 1 | !auth_config.credentials_provider.empty(); |
463 | | |
464 | 1 | if (!has_credentials) { |
465 | 0 | LOG(ERROR) << "AWS MSK IAM authentication enabled but no credentials configured. " |
466 | 0 | << "Please provide one of: access_key/secret_key, role_arn, profile_name, or " |
467 | 0 | "credentials_provider"; |
468 | 0 | return nullptr; |
469 | 0 | } |
470 | | |
471 | 1 | LOG(INFO) << "Enabling AWS MSK IAM authentication for broker: " << broker_hostname |
472 | 1 | << ", region: " << auth_config.region; |
473 | | |
474 | 1 | auto auth = std::make_shared<AwsMskIamAuth>(auth_config); |
475 | 1 | return std::make_unique<AwsMskIamOAuthCallback>(std::move(auth), std::move(broker_hostname)); |
476 | 1 | } |
477 | | |
478 | | AwsMskIamOAuthCallback::AwsMskIamOAuthCallback(std::shared_ptr<AwsMskIamAuth> auth, |
479 | | std::string broker_hostname) |
480 | 2 | : _auth(std::move(auth)), _broker_hostname(std::move(broker_hostname)) {} |
481 | | |
482 | 0 | Status AwsMskIamOAuthCallback::refresh_now(RdKafka::Handle* handle) { |
483 | 0 | std::string token; |
484 | 0 | int64_t token_lifetime_ms = 0; |
485 | |
|
486 | 0 | RETURN_IF_ERROR(_auth->generate_token(_broker_hostname, &token, &token_lifetime_ms)); |
487 | | |
488 | 0 | std::string principal = "doris-consumer"; |
489 | 0 | std::list<std::string> extensions; |
490 | 0 | std::string errstr; |
491 | |
|
492 | 0 | auto now = std::chrono::system_clock::now(); |
493 | 0 | auto now_ms = |
494 | 0 | std::chrono::duration_cast<std::chrono::milliseconds>(now.time_since_epoch()).count(); |
495 | 0 | int64_t token_expiry_ms = now_ms + token_lifetime_ms; |
496 | |
|
497 | 0 | auto err = handle->oauthbearer_set_token(token, token_expiry_ms, principal, extensions, errstr); |
498 | 0 | if (err != RdKafka::ERR_NO_ERROR) { |
499 | 0 | return Status::InternalError("Failed to set OAuth token: {}, detail: {}", |
500 | 0 | RdKafka::err2str(err), errstr); |
501 | 0 | } |
502 | | |
503 | 0 | LOG(INFO) << "Successfully set AWS MSK IAM OAuth token, lifetime: " << token_lifetime_ms |
504 | 0 | << "ms"; |
505 | 0 | return Status::OK(); |
506 | 0 | } |
507 | | |
508 | | void AwsMskIamOAuthCallback::oauthbearer_token_refresh_cb( |
509 | 0 | RdKafka::Handle* handle, const std::string& /*oauthbearer_config*/) { |
510 | 0 | Status st = refresh_now(handle); |
511 | 0 | if (!st.ok()) { |
512 | | LOG(WARNING) << st; |
513 | 0 | handle->oauthbearer_set_token_failure(st.to_string()); |
514 | 0 | } |
515 | 0 | } |
516 | | |
517 | | } // namespace doris |