Coverage Report

Created: 2024-11-18 10:37

/root/doris/be/src/util/s3_util.cpp
Line
Count
Source (jump to first uncovered line)
1
// Licensed to the Apache Software Foundation (ASF) under one
2
// or more contributor license agreements.  See the NOTICE file
3
// distributed with this work for additional information
4
// regarding copyright ownership.  The ASF licenses this file
5
// to you under the Apache License, Version 2.0 (the
6
// "License"); you may not use this file except in compliance
7
// with the License.  You may obtain a copy of the License at
8
//
9
//   http://www.apache.org/licenses/LICENSE-2.0
10
//
11
// Unless required by applicable law or agreed to in writing,
12
// software distributed under the License is distributed on an
13
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14
// KIND, either express or implied.  See the License for the
15
// specific language governing permissions and limitations
16
// under the License.
17
18
#include "util/s3_util.h"

#include <aws/core/auth/AWSAuthSigner.h>
#include <aws/core/auth/AWSCredentials.h>
#include <aws/core/auth/AWSCredentialsProviderChain.h>
#include <aws/core/utils/logging/LogLevel.h>
#include <aws/core/utils/logging/LogSystemInterface.h>
#include <aws/core/utils/memory/stl/AWSStringStream.h>
#include <aws/s3/S3Client.h>
#include <util/string_util.h>

#include <atomic>
#include <cstdarg>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
#include <functional>
#include <ostream>
#include <string>
#include <system_error>
#include <utility>

#include "common/config.h"
#include "common/logging.h"
#include "runtime/exec_env.h"
#include "s3_uri.h"
#include "vec/exec/scan/scanner_scheduler.h"
41
42
namespace doris {
43
44
class DorisAWSLogger final : public Aws::Utils::Logging::LogSystemInterface {
45
public:
46
0
    DorisAWSLogger() : _log_level(Aws::Utils::Logging::LogLevel::Info) {}
47
0
    DorisAWSLogger(Aws::Utils::Logging::LogLevel log_level) : _log_level(log_level) {}
48
0
    ~DorisAWSLogger() final = default;
49
0
    Aws::Utils::Logging::LogLevel GetLogLevel() const final { return _log_level; }
50
    void Log(Aws::Utils::Logging::LogLevel log_level, const char* tag, const char* format_str,
51
0
             ...) final {
52
0
        _log_impl(log_level, tag, format_str);
53
0
    }
54
    void LogStream(Aws::Utils::Logging::LogLevel log_level, const char* tag,
55
0
                   const Aws::OStringStream& message_stream) final {
56
0
        _log_impl(log_level, tag, message_stream.str().c_str());
57
0
    }
58
59
0
    void Flush() final {}
60
61
private:
62
0
    void _log_impl(Aws::Utils::Logging::LogLevel log_level, const char* tag, const char* message) {
63
0
        switch (log_level) {
64
0
        case Aws::Utils::Logging::LogLevel::Off:
65
0
            break;
66
0
        case Aws::Utils::Logging::LogLevel::Fatal:
67
0
            LOG(FATAL) << "[" << tag << "] " << message;
68
0
            break;
69
0
        case Aws::Utils::Logging::LogLevel::Error:
70
0
            LOG(ERROR) << "[" << tag << "] " << message;
71
0
            break;
72
0
        case Aws::Utils::Logging::LogLevel::Warn:
73
0
            LOG(WARNING) << "[" << tag << "] " << message;
74
0
            break;
75
0
        case Aws::Utils::Logging::LogLevel::Info:
76
0
            LOG(INFO) << "[" << tag << "] " << message;
77
0
            break;
78
0
        case Aws::Utils::Logging::LogLevel::Debug:
79
0
            LOG(INFO) << "[" << tag << "] " << message;
80
0
            break;
81
0
        case Aws::Utils::Logging::LogLevel::Trace:
82
0
            LOG(INFO) << "[" << tag << "] " << message;
83
0
            break;
84
0
        default:
85
0
            break;
86
0
        }
87
0
    }
88
89
    std::atomic<Aws::Utils::Logging::LogLevel> _log_level;
90
};
91
92
// Property key read by convert_properties_to_s3_conf(): a value of "true"
// turns virtual addressing off for the resulting S3Conf.
static const std::string USE_PATH_STYLE {"use_path_style"};
93
94
1
// Initializes the AWS SDK for this factory: takes the SDK log level from
// config::aws_log_level, installs DorisAWSLogger as the SDK logger, calls
// Aws::InitAPI, and remembers the CA certificate path found at startup.
S3ClientFactory::S3ClientFactory() {
    const auto sdk_log_level = static_cast<Aws::Utils::Logging::LogLevel>(config::aws_log_level);

    _aws_options = Aws::SDKOptions {};
    _aws_options.loggingOptions.logLevel = sdk_log_level;
    _aws_options.loggingOptions.logger_create_fn = [sdk_log_level] {
        return std::make_shared<DorisAWSLogger>(sdk_log_level);
    };

    Aws::InitAPI(_aws_options);
    _ca_cert_file_path = get_valid_ca_cert_path();
}
105
106
2
string S3ClientFactory::get_valid_ca_cert_path() {
107
2
    vector<std::string> vec_ca_file_path = doris::split(config::ca_cert_file_paths, ";");
108
2
    vector<std::string>::iterator it = vec_ca_file_path.begin();
109
2
    for (; it != vec_ca_file_path.end(); ++it) {
110
2
        if (std::filesystem::exists(*it)) {
111
2
            return *it;
112
2
        }
113
2
    }
114
0
    return "";
115
2
}
116
117
1
// Shuts the AWS SDK down with the same options that were passed to
// Aws::InitAPI in the constructor.
S3ClientFactory::~S3ClientFactory() {
    Aws::ShutdownAPI(_aws_options);
}
120
121
3
// Returns the process-wide factory, constructed on first call.
S3ClientFactory& S3ClientFactory::instance() {
    static S3ClientFactory factory;
    return factory;
}
125
126
0
bool S3ClientFactory::is_s3_conf_valid(const std::map<std::string, std::string>& prop) {
127
0
    StringCaseMap<std::string> properties(prop.begin(), prop.end());
128
0
    if (properties.find(S3_ENDPOINT) == properties.end() ||
129
0
        properties.find(S3_REGION) == properties.end()) {
130
0
        DCHECK(false) << "aws properties is incorrect.";
131
0
        LOG(ERROR) << "aws properties is incorrect.";
132
0
        return false;
133
0
    }
134
0
    return true;
135
0
}
136
137
3
bool S3ClientFactory::is_s3_conf_valid(const S3Conf& s3_conf) {
138
3
    return !s3_conf.endpoint.empty();
139
3
}
140
141
3
std::shared_ptr<Aws::S3::S3Client> S3ClientFactory::create(const S3Conf& s3_conf) {
142
3
    if (!is_s3_conf_valid(s3_conf)) {
143
2
        return nullptr;
144
2
    }
145
146
1
    uint64_t hash = s3_conf.get_hash();
147
1
    {
148
1
        std::lock_guard l(_lock);
149
1
        auto it = _cache.find(hash);
150
1
        if (it != _cache.end()) {
151
0
            return it->second;
152
0
        }
153
1
    }
154
155
1
    Aws::Client::ClientConfiguration aws_config = S3ClientFactory::getClientConfiguration();
156
1
    aws_config.endpointOverride = s3_conf.endpoint;
157
1
    aws_config.region = s3_conf.region;
158
1
    std::string ca_cert = get_valid_ca_cert_path();
159
1
    if ("" != _ca_cert_file_path) {
160
1
        aws_config.caFile = _ca_cert_file_path;
161
1
    } else {
162
        // config::ca_cert_file_paths is valmutable,get newest value if file path invaild
163
0
        _ca_cert_file_path = get_valid_ca_cert_path();
164
0
        if ("" != _ca_cert_file_path) {
165
0
            aws_config.caFile = _ca_cert_file_path;
166
0
        }
167
0
    }
168
1
    if (s3_conf.max_connections > 0) {
169
0
        aws_config.maxConnections = s3_conf.max_connections;
170
1
    } else {
171
1
#ifdef BE_TEST
172
        // the S3Client may shared by many threads.
173
        // So need to set the number of connections large enough.
174
1
        aws_config.maxConnections = config::doris_scanner_thread_pool_thread_num;
175
#else
176
        aws_config.maxConnections =
177
                ExecEnv::GetInstance()->scanner_scheduler()->remote_thread_pool_max_size();
178
#endif
179
1
    }
180
181
1
    if (s3_conf.request_timeout_ms > 0) {
182
0
        aws_config.requestTimeoutMs = s3_conf.request_timeout_ms;
183
0
    }
184
1
    if (s3_conf.connect_timeout_ms > 0) {
185
0
        aws_config.connectTimeoutMs = s3_conf.connect_timeout_ms;
186
0
    }
187
1
    std::shared_ptr<Aws::S3::S3Client> new_client;
188
1
    if (!s3_conf.ak.empty() && !s3_conf.sk.empty()) {
189
1
        Aws::Auth::AWSCredentials aws_cred(s3_conf.ak, s3_conf.sk);
190
1
        DCHECK(!aws_cred.IsExpiredOrEmpty());
191
1
        if (!s3_conf.token.empty()) {
192
0
            aws_cred.SetSessionToken(s3_conf.token);
193
0
        }
194
1
        new_client = std::make_shared<Aws::S3::S3Client>(
195
1
                std::move(aws_cred), std::move(aws_config),
196
1
                Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
197
1
                s3_conf.use_virtual_addressing);
198
1
    } else {
199
0
        std::shared_ptr<Aws::Auth::AWSCredentialsProvider> aws_provider_chain =
200
0
                std::make_shared<Aws::Auth::DefaultAWSCredentialsProviderChain>();
201
0
        new_client = std::make_shared<Aws::S3::S3Client>(
202
0
                std::move(aws_provider_chain), std::move(aws_config),
203
0
                Aws::Client::AWSAuthV4Signer::PayloadSigningPolicy::Never,
204
0
                s3_conf.use_virtual_addressing);
205
0
    }
206
207
1
    {
208
1
        std::lock_guard l(_lock);
209
1
        _cache[hash] = new_client;
210
1
    }
211
1
    return new_client;
212
1
}
213
214
// Fills `*s3_conf` from the string properties in `prop` and the bucket of `s3_uri`.
//
// Returns InvalidArgument when `prop` fails is_s3_conf_valid() or when
// `s3_uri` has an empty bucket; otherwise returns OK.
// ak/sk are copied only when both are present. The connection limit and the
// two timeouts are parsed with std::atoi when their keys are present.
// Virtual addressing stays on unless USE_PATH_STYLE is exactly "true".
Status S3ClientFactory::convert_properties_to_s3_conf(
        const std::map<std::string, std::string>& prop, const S3URI& s3_uri, S3Conf* s3_conf) {
    if (!is_s3_conf_valid(prop)) {
        return Status::InvalidArgument("S3 properties are incorrect, please check properties.");
    }

    const StringCaseMap<std::string> properties(prop.begin(), prop.end());
    const auto not_found = properties.end();

    const auto ak_it = properties.find(S3_AK);
    const auto sk_it = properties.find(S3_SK);
    if (ak_it != not_found && sk_it != not_found) {
        s3_conf->ak = ak_it->second;
        s3_conf->sk = sk_it->second;
    }
    if (const auto token_it = properties.find(S3_TOKEN); token_it != not_found) {
        s3_conf->token = token_it->second;
    }

    // Both keys are guaranteed present by the is_s3_conf_valid(prop) check above.
    s3_conf->endpoint = properties.find(S3_ENDPOINT)->second;
    s3_conf->region = properties.find(S3_REGION)->second;

    if (const auto conn_it = properties.find(S3_MAX_CONN_SIZE); conn_it != not_found) {
        s3_conf->max_connections = std::atoi(conn_it->second.c_str());
    }
    if (const auto req_it = properties.find(S3_REQUEST_TIMEOUT_MS); req_it != not_found) {
        s3_conf->request_timeout_ms = std::atoi(req_it->second.c_str());
    }
    if (const auto conn_to_it = properties.find(S3_CONN_TIMEOUT_MS); conn_to_it != not_found) {
        s3_conf->connect_timeout_ms = std::atoi(conn_to_it->second.c_str());
    }

    if (s3_uri.get_bucket() == "") {
        return Status::InvalidArgument("Invalid S3 URI {}, bucket is not specified",
                                       s3_uri.to_string());
    }
    s3_conf->bucket = s3_uri.get_bucket();
    s3_conf->prefix = "";

    // See https://sdk.amazonaws.com/cpp/api/LATEST/class_aws_1_1_s3_1_1_s3_client.html
    s3_conf->use_virtual_addressing = true;
    if (const auto style_it = properties.find(USE_PATH_STYLE); style_it != not_found) {
        s3_conf->use_virtual_addressing = (style_it->second != "true");
    }
    return Status::OK();
}
256
257
} // end namespace doris