diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 1ff6a9a40f..68db04dea9 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -2,6 +2,8 @@ import urllib.request +HTTPLIB_VERSION = "f80864ca031932351abef49b74097c67f14719c6" + vendor = { "https://github.com/nlohmann/json/releases/latest/download/json.hpp": "vendor/nlohmann/json.hpp", "https://github.com/nlohmann/json/releases/latest/download/json_fwd.hpp": "vendor/nlohmann/json_fwd.hpp", @@ -12,8 +14,8 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/httplib.h": "vendor/cpp-httplib/httplib.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.2/LICENSE": "vendor/cpp-httplib/LICENSE", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/httplib.h": "vendor/cpp-httplib/httplib.h", + f"https://raw.githubusercontent.com/yhirose/cpp-httplib/{HTTPLIB_VERSION}/LICENSE": "vendor/cpp-httplib/LICENSE", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index ba5f9c8ff9..e309a7ad5d 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -6,8 +6,472 @@ namespace httplib { * Implementation that will be part of the .cc file if split into .h + .cc. */ +namespace stream { + +// stream::Result implementations +Result::Result() : chunk_size_(8192) {} + +Result::Result(ClientImpl::StreamHandle &&handle, size_t chunk_size) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + +Result::Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; +} + +Result &Result::operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; +} + +bool Result::is_valid() const { return handle_.is_valid(); } +Result::operator bool() const { return is_valid(); } + +int Result::status() const { + return handle_.response ? handle_.response->status : -1; +} + +const Headers &Result::headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; +} + +std::string Result::get_header_value(const std::string &key, + const char *def) const { + return handle_.response ? handle_.response->get_header_value(key, def) : def; +} + +bool Result::has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; +} + +Error Result::error() const { return handle_.error; } +Error Result::read_error() const { return handle_.get_read_error(); } +bool Result::has_read_error() const { return handle_.has_read_error(); } + +bool Result::next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; +} + +const char *Result::data() const { return buffer_.data(); } +size_t Result::size() const { return current_size_; } + +std::string Result::read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; +} + +} // namespace stream + +namespace sse { + +// SSEMessage implementations +SSEMessage::SSEMessage() : event("message") {} + +void SSEMessage::clear() { + event = "message"; + data.clear(); + id.clear(); +} + +// SSEClient implementations +SSEClient::SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + +SSEClient::SSEClient(Client &client, const std::string &path, + const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + +SSEClient::~SSEClient() { stop(); } + +SSEClient &SSEClient::on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_event(const std::string &type, + MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; +} + +SSEClient &SSEClient::set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; +} + +SSEClient &SSEClient::set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; +} + +bool SSEClient::is_connected() const { return connected_.load(); } + +const std::string &SSEClient::last_event_id() const { + return last_event_id_; +} + +void SSEClient::start() { + running_.store(true); + run_event_loop(); +} + +void SSEClient::start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); +} + +void SSEClient::stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } +} + +bool SSEClient::parse_sse_line(const std::string &line, SSEMessage &msg, + int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + { + int v = 0; + auto res = + detail::from_chars(value.data(), value.data() + value.size(), v); + if (res.ec == std::errc{}) { retry_ms = v; } + } + } + // Unknown fields are ignored per SSE spec + + return false; +} + +void SSEClient::run_event_loop() { + auto reconnect_count = 0; + + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + auto request_headers = headers_; + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } + + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != 200) { + connected_.store(false); + // For certain errors, don't reconnect + if (result.status() == 204 || // No Content - server wants us to stop + result.status() == 404 || // Not Found + result.status() == 401 || // Unauthorized + result.status() == 403) { // Forbidden + if (on_error_) { on_error_(Error::Connection); } + break; + } + + if (on_error_) { on_error_(Error::Connection); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while ((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); +} + +void SSEClient::dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } +} + +bool SSEClient::should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; +} + +void SSEClient::wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } +} + +} // namespace sse + +#ifdef CPPHTTPLIB_SSL_ENABLED +/* + * TLS abstraction layer - internal function declarations + * These are implementation details and not part of the public API. + */ +namespace tls { + +// Client context +ctx_t create_client_context(); +void free_context(ctx_t ctx); +bool set_min_version(ctx_t ctx, Version version); +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len); +bool load_ca_file(ctx_t ctx, const char *file_path); +bool load_ca_dir(ctx_t ctx, const char *dir_path); +bool load_system_certs(ctx_t ctx); +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); + +// Server context +ctx_t create_server_context(); +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password); +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password); +bool set_client_ca_file(ctx_t ctx, const char *ca_file, const char *ca_dir); +void set_verify_client(ctx_t ctx, bool require); + +// Session management +session_t create_session(ctx_t ctx, socket_t sock); +void free_session(session_t session); +bool set_sni(session_t session, const char *hostname); +bool set_hostname(session_t session, const char *hostname); + +// Handshake (non-blocking capable) +TlsError connect(session_t session); +TlsError accept(session_t session); + +// Handshake with timeout (blocking until timeout) +bool connect_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); +bool accept_nonblocking(session_t session, socket_t sock, time_t timeout_sec, + time_t timeout_usec, TlsError *err); + +// I/O (non-blocking capable) +ssize_t read(session_t session, void *buf, size_t len, TlsError &err); +ssize_t write(session_t session, const void *buf, size_t len, TlsError &err); +int pending(const_session_t session); +void shutdown(session_t session, bool graceful); + +// Connection state +bool is_peer_closed(session_t session, socket_t sock); + +// Certificate verification +cert_t get_peer_cert(const_session_t session); +void free_cert(cert_t cert); +bool verify_hostname(cert_t cert, const char *hostname); +uint64_t hostname_mismatch_code(); +long get_verify_result(const_session_t session); + +// Certificate introspection +std::string get_cert_subject_cn(cert_t cert); +std::string get_cert_issuer_name(cert_t cert); +bool get_cert_sans(cert_t cert, std::vector &sans); +bool get_cert_validity(cert_t cert, time_t ¬_before, time_t ¬_after); +std::string get_cert_serial(cert_t cert); +bool get_cert_der(cert_t cert, std::vector &der); +const char *get_sni(const_session_t session); + +// CA store management +ca_store_t create_ca_store(const char *pem, size_t len); +void free_ca_store(ca_store_t store); +bool set_ca_store(ctx_t ctx, ca_store_t store); +size_t get_ca_certs(ctx_t ctx, std::vector &certs); +std::vector get_ca_names(ctx_t ctx); + +// Dynamic certificate update (for servers) +bool update_server_cert(ctx_t ctx, const char *cert_pem, const char *key_pem, + const char *password); +bool update_server_client_ca(ctx_t ctx, const char *ca_pem); + +// Certificate verification callback +bool set_verify_callback(ctx_t ctx, VerifyCallback callback); +long get_verify_error(const_session_t session); +std::string verify_error_string(long error_code); + +// TlsError information +uint64_t peek_error(); +uint64_t get_error(); +std::string error_string(uint64_t code); + +} // namespace tls +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 1: detail namespace - Non-SSL utilities + */ + namespace detail { +bool set_socket_opt_impl(socket_t sock, int level, int optname, + const void *optval, socklen_t optlen) { + return setsockopt(sock, level, optname, +#ifdef _WIN32 + reinterpret_cast(optval), +#else + optval, +#endif + optlen) == 0; +} + +bool set_socket_opt(socket_t sock, int level, int optname, int optval) { + return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); +} + +bool set_socket_opt_time(socket_t sock, int level, int optname, + time_t sec, time_t usec) { +#ifdef _WIN32 + auto timeout = static_cast(sec * 1000 + usec / 1000); +#else + timeval timeout; + timeout.tv_sec = static_cast(sec); + timeout.tv_usec = static_cast(usec); +#endif + return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); +} + bool is_hex(char c, int &v) { if (isdigit(c)) { v = c - '0'; @@ -940,39 +1404,6 @@ private: static const size_t read_buff_size_ = 1024l * 4; }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -class SSLSocketStream final : public Stream { -public: - SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, - time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, time_t max_timeout_msec = 0, - std::chrono::time_point start_time = - (std::chrono::steady_clock::time_point::min)()); - ~SSLSocketStream() override; - - bool is_readable() const override; - bool wait_readable() const override; - bool wait_writable() const override; - ssize_t read(char *ptr, size_t size) override; - ssize_t write(const char *ptr, size_t size) override; - void get_remote_ip_and_port(std::string &ip, int &port) const override; - void get_local_ip_and_port(std::string &ip, int &port) const override; - socket_t socket() const override; - time_t duration() const override; - -private: - socket_t sock_; - SSL *ssl_; - time_t read_timeout_sec_; - time_t read_timeout_usec_; - time_t write_timeout_sec_; - time_t write_timeout_usec_; - time_t max_timeout_msec_; - const std::chrono::time_point start_time_; -}; -#endif - bool keep_alive(const std::atomic &svr_sock, socket_t sock, time_t keep_alive_timeout_sec) { using namespace std::chrono; @@ -2270,14 +2701,23 @@ bool read_headers(Stream &strm, Headers &headers) { return true; } -bool read_content_with_length(Stream &strm, size_t len, - DownloadProgress progress, - ContentReceiverWithProgress out) { +enum class ReadContentResult { + Success, // Successfully read the content + PayloadTooLarge, // The content exceeds the specified payload limit + Error // An error occurred while reading the content +}; + +ReadContentResult read_content_with_length( + Stream &strm, size_t len, DownloadProgress progress, + ContentReceiverWithProgress out, + size_t payload_max_length = (std::numeric_limits::max)()) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; detail::BodyReader br; br.stream = &strm; + br.has_content_length = true; br.content_length = len; + br.payload_max_length = payload_max_length; br.chunked = false; br.bytes_read = 0; br.last_error = Error::Success; @@ -2287,36 +2727,27 @@ bool read_content_with_length(Stream &strm, size_t len, auto read_len = static_cast(len - r); auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); auto n = detail::read_body_content(&strm, br, buf, to_read); - if (n <= 0) { return false; } + if (n <= 0) { + // Check if it was a payload size error + if (br.last_error == Error::ExceedMaxPayloadSize) { + return ReadContentResult::PayloadTooLarge; + } + return ReadContentResult::Error; + } - if (!out(buf, static_cast(n), r, len)) { return false; } + if (!out(buf, static_cast(n), r, len)) { + return ReadContentResult::Error; + } r += static_cast(n); if (progress) { - if (!progress(r, len)) { return false; } + if (!progress(r, len)) { return ReadContentResult::Error; } } } - return true; + return ReadContentResult::Success; } -void skip_content_with_length(Stream &strm, size_t len) { - char buf[CPPHTTPLIB_RECV_BUFSIZ]; - size_t r = 0; - while (r < len) { - auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); - if (n <= 0) { return; } - r += static_cast(n); - } -} - -enum class ReadContentResult { - Success, // Successfully read the content - PayloadTooLarge, // The content exceeds the specified payload limit - Error // An error occurred while reading the content -}; - ReadContentResult read_content_without_length(Stream &strm, size_t payload_max_length, ContentReceiverWithProgress out) { @@ -2462,12 +2893,13 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, if (is_invalid_value) { ret = false; - } else if (len > payload_max_length) { - exceed_payload_max_length = true; - skip_content_with_length(strm, len); - ret = false; } else if (len > 0) { - ret = read_content_with_length(strm, len, std::move(progress), out); + auto result = read_content_with_length( + strm, len, std::move(progress), out, payload_max_length); + ret = (result == ReadContentResult::Success); + if (result == ReadContentResult::PayloadTooLarge) { + exceed_payload_max_length = true; + } } } @@ -3645,226 +4077,6 @@ bool has_crlf(const std::string &s) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -std::string message_digest(const std::string &s, const EVP_MD *algo) { - auto context = std::unique_ptr( - EVP_MD_CTX_new(), EVP_MD_CTX_free); - - unsigned int hash_length = 0; - unsigned char hash[EVP_MAX_MD_SIZE]; - - EVP_DigestInit_ex(context.get(), algo, nullptr); - EVP_DigestUpdate(context.get(), s.c_str(), s.size()); - EVP_DigestFinal_ex(context.get(), hash, &hash_length); - - std::stringstream ss; - for (auto i = 0u; i < hash_length; ++i) { - ss << std::hex << std::setw(2) << std::setfill('0') - << static_cast(hash[i]); - } - - return ss.str(); -} - -std::string MD5(const std::string &s) { - return message_digest(s, EVP_md5()); -} - -std::string SHA_256(const std::string &s) { - return message_digest(s, EVP_sha256()); -} - -std::string SHA_512(const std::string &s) { - return message_digest(s, EVP_sha512()); -} - -std::pair make_digest_authentication_header( - const Request &req, const std::map &auth, - size_t cnonce_count, const std::string &cnonce, const std::string &username, - const std::string &password, bool is_proxy = false) { - std::string nc; - { - std::stringstream ss; - ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; - nc = ss.str(); - } - - std::string qop; - if (auth.find("qop") != auth.end()) { - qop = auth.at("qop"); - if (qop.find("auth-int") != std::string::npos) { - qop = "auth-int"; - } else if (qop.find("auth") != std::string::npos) { - qop = "auth"; - } else { - qop.clear(); - } - } - - std::string algo = "MD5"; - if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } - - std::string response; - { - auto H = algo == "SHA-256" ? detail::SHA_256 - : algo == "SHA-512" ? detail::SHA_512 - : detail::MD5; - - auto A1 = username + ":" + auth.at("realm") + ":" + password; - - auto A2 = req.method + ":" + req.path; - if (qop == "auth-int") { A2 += ":" + H(req.body); } - - if (qop.empty()) { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); - } else { - response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + - ":" + qop + ":" + H(A2)); - } - } - - auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; - - auto field = "Digest username=\"" + username + "\", realm=\"" + - auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + - "\", uri=\"" + req.path + "\", algorithm=" + algo + - (qop.empty() ? ", response=\"" - : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + - cnonce + "\", response=\"") + - response + "\"" + - (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); - - auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; - return std::make_pair(key, field); -} - -bool is_ssl_peer_could_be_closed(SSL *ssl, socket_t sock) { - detail::set_nonblocking(sock, true); - auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); - - char buf[1]; - return !SSL_peek(ssl, buf, 1) && - SSL_get_error(ssl, 0) == SSL_ERROR_ZERO_RETURN; -} - -#ifdef _WIN32 -// NOTE: This code came up with the following stackoverflow post: -// https://stackoverflow.com/questions/9507184/can-openssl-on-windows-use-the-system-certificate-store -bool load_system_certs_on_windows(X509_STORE *store) { - auto hStore = CertOpenSystemStoreW((HCRYPTPROV_LEGACY)NULL, L"ROOT"); - if (!hStore) { return false; } - - auto result = false; - PCCERT_CONTEXT pContext = NULL; - while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != - nullptr) { - auto encoded_cert = - static_cast(pContext->pbCertEncoded); - - auto x509 = d2i_X509(NULL, &encoded_cert, pContext->cbCertEncoded); - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; - } - } - - CertFreeCertificateContext(pContext); - CertCloseStore(hStore, 0); - - return result; -} -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC -template -using CFObjectPtr = - std::unique_ptr::type, void (*)(CFTypeRef)>; - -void cf_object_ptr_deleter(CFTypeRef obj) { - if (obj) { CFRelease(obj); } -} - -bool retrieve_certs_from_keychain(CFObjectPtr &certs) { - CFStringRef keys[] = {kSecClass, kSecMatchLimit, kSecReturnRef}; - CFTypeRef values[] = {kSecClassCertificate, kSecMatchLimitAll, - kCFBooleanTrue}; - - CFObjectPtr query( - CFDictionaryCreate(nullptr, reinterpret_cast(keys), values, - sizeof(keys) / sizeof(keys[0]), - &kCFTypeDictionaryKeyCallBacks, - &kCFTypeDictionaryValueCallBacks), - cf_object_ptr_deleter); - - if (!query) { return false; } - - CFTypeRef security_items = nullptr; - if (SecItemCopyMatching(query.get(), &security_items) != errSecSuccess || - CFArrayGetTypeID() != CFGetTypeID(security_items)) { - return false; - } - - certs.reset(reinterpret_cast(security_items)); - return true; -} - -bool retrieve_root_certs_from_keychain(CFObjectPtr &certs) { - CFArrayRef root_security_items = nullptr; - if (SecTrustCopyAnchorCertificates(&root_security_items) != errSecSuccess) { - return false; - } - - certs.reset(root_security_items); - return true; -} - -bool add_certs_to_x509_store(CFArrayRef certs, X509_STORE *store) { - auto result = false; - for (auto i = 0; i < CFArrayGetCount(certs); ++i) { - const auto cert = reinterpret_cast( - CFArrayGetValueAtIndex(certs, i)); - - if (SecCertificateGetTypeID() != CFGetTypeID(cert)) { continue; } - - CFDataRef cert_data = nullptr; - if (SecItemExport(cert, kSecFormatX509Cert, 0, nullptr, &cert_data) != - errSecSuccess) { - continue; - } - - CFObjectPtr cert_data_ptr(cert_data, cf_object_ptr_deleter); - - auto encoded_cert = static_cast( - CFDataGetBytePtr(cert_data_ptr.get())); - - auto x509 = - d2i_X509(NULL, &encoded_cert, CFDataGetLength(cert_data_ptr.get())); - - if (x509) { - X509_STORE_add_cert(store, x509); - X509_free(x509); - result = true; - } - } - - return result; -} - -bool load_system_certs_on_macos(X509_STORE *store) { - auto result = false; - CFObjectPtr certs(nullptr, cf_object_ptr_deleter); - if (retrieve_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store); - } - - if (retrieve_root_certs_from_keychain(certs) && certs) { - result = add_certs_to_x509_store(certs.get(), store) || result; - } - - return result; -} -#endif // _WIN32 -#endif // CPPHTTPLIB_OPENSSL_SUPPORT - #ifdef _WIN32 class WSInit { public: @@ -3984,8 +4196,393 @@ bool is_field_content(const std::string &s) { bool is_field_value(const std::string &s) { return is_field_content(s); } } // namespace fields +} // namespace detail + +/* + * Group 2: detail namespace - SSL common utilities + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +class SSLSocketStream final : public Stream { +public: + SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec = 0, + std::chrono::time_point start_time = + (std::chrono::steady_clock::time_point::min)()); + ~SSLSocketStream() override; + + bool is_readable() const override; + bool wait_readable() const override; + bool wait_writable() const override; + ssize_t read(char *ptr, size_t size) override; + ssize_t write(const char *ptr, size_t size) override; + void get_remote_ip_and_port(std::string &ip, int &port) const override; + void get_local_ip_and_port(std::string &ip, int &port) const override; + socket_t socket() const override; + time_t duration() const override; + +private: + socket_t sock_; + tls::session_t session_; + time_t read_timeout_sec_; + time_t read_timeout_usec_; + time_t write_timeout_sec_; + time_t write_timeout_usec_; + time_t max_timeout_msec_; + const std::chrono::time_point start_time_; +}; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +std::string message_digest(const std::string &s, const EVP_MD *algo) { + auto context = std::unique_ptr( + EVP_MD_CTX_new(), EVP_MD_CTX_free); + + unsigned int hash_length = 0; + unsigned char hash[EVP_MAX_MD_SIZE]; + + EVP_DigestInit_ex(context.get(), algo, nullptr); + EVP_DigestUpdate(context.get(), s.c_str(), s.size()); + EVP_DigestFinal_ex(context.get(), hash, &hash_length); + + std::stringstream ss; + for (auto i = 0u; i < hash_length; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + + return ss.str(); +} + +std::string MD5(const std::string &s) { + return message_digest(s, EVP_md5()); +} + +std::string SHA_256(const std::string &s) { + return message_digest(s, EVP_sha256()); +} + +std::string SHA_512(const std::string &s) { + return message_digest(s, EVP_sha512()); +} +#elif defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +namespace { +template +std::string hash_to_hex(const unsigned char (&hash)[N]) { + std::stringstream ss; + for (size_t i = 0; i < N; ++i) { + ss << std::hex << std::setw(2) << std::setfill('0') + << static_cast(hash[i]); + } + return ss.str(); +} +} // namespace + +std::string MD5(const std::string &s) { + unsigned char hash[16]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_md5(reinterpret_cast(s.c_str()), s.size(), + hash); +#else + mbedtls_md5_ret(reinterpret_cast(s.c_str()), s.size(), + hash); +#endif + return hash_to_hex(hash); +} + +std::string SHA_256(const std::string &s) { + unsigned char hash[32]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha256(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha256_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} + +std::string SHA_512(const std::string &s) { + unsigned char hash[64]; +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_sha512(reinterpret_cast(s.c_str()), s.size(), + hash, 0); +#else + mbedtls_sha512_ret(reinterpret_cast(s.c_str()), + s.size(), hash, 0); +#endif + return hash_to_hex(hash); +} +#endif + +bool is_ip_address(const std::string &host) { + struct in_addr addr4; + struct in6_addr addr6; + return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || + inet_pton(AF_INET6, host.c_str(), &addr6) == 1; +} + +template +bool process_server_socket_ssl( + const std::atomic &svr_sock, tls::session_t session, + socket_t sock, size_t keep_alive_max_count, time_t keep_alive_timeout_sec, + time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, T callback) { + return process_server_socket_core( + svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, + [&](bool close_connection, bool &connection_closed) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec); + return callback(strm, close_connection, connection_closed); + }); +} + +template +bool process_client_socket_ssl( + tls::session_t session, socket_t sock, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time, T callback) { + SSLSocketStream strm(sock, session, read_timeout_sec, read_timeout_usec, + write_timeout_sec, write_timeout_usec, max_timeout_msec, + start_time); + return callback(strm); +} + +std::pair make_digest_authentication_header( + const Request &req, const std::map &auth, + size_t cnonce_count, const std::string &cnonce, const std::string &username, + const std::string &password, bool is_proxy = false) { + std::string nc; + { + std::stringstream ss; + ss << std::setfill('0') << std::setw(8) << std::hex << cnonce_count; + nc = ss.str(); + } + + std::string qop; + if (auth.find("qop") != auth.end()) { + qop = auth.at("qop"); + if (qop.find("auth-int") != std::string::npos) { + qop = "auth-int"; + } else if (qop.find("auth") != std::string::npos) { + qop = "auth"; + } else { + qop.clear(); + } + } + + std::string algo = "MD5"; + if (auth.find("algorithm") != auth.end()) { algo = auth.at("algorithm"); } + + std::string response; + { + auto H = algo == "SHA-256" ? detail::SHA_256 + : algo == "SHA-512" ? detail::SHA_512 + : detail::MD5; + + auto A1 = username + ":" + auth.at("realm") + ":" + password; + + auto A2 = req.method + ":" + req.path; + if (qop == "auth-int") { A2 += ":" + H(req.body); } + + if (qop.empty()) { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + H(A2)); + } else { + response = H(H(A1) + ":" + auth.at("nonce") + ":" + nc + ":" + cnonce + + ":" + qop + ":" + H(A2)); + } + } + + auto opaque = (auth.find("opaque") != auth.end()) ? auth.at("opaque") : ""; + + auto field = "Digest username=\"" + username + "\", realm=\"" + + auth.at("realm") + "\", nonce=\"" + auth.at("nonce") + + "\", uri=\"" + req.path + "\", algorithm=" + algo + + (qop.empty() ? ", response=\"" + : ", qop=" + qop + ", nc=" + nc + ", cnonce=\"" + + cnonce + "\", response=\"") + + response + "\"" + + (opaque.empty() ? "" : ", opaque=\"" + opaque + "\""); + + auto key = is_proxy ? "Proxy-Authorization" : "Authorization"; + return std::make_pair(key, field); +} + +bool match_hostname(const std::string &pattern, + const std::string &hostname) { + // Exact match (case-insensitive) + if (detail::case_ignore::equal(hostname, pattern)) { return true; } + + // Split both pattern and hostname into components by '.' + std::vector pattern_components; + if (!pattern.empty()) { + split(pattern.data(), pattern.data() + pattern.size(), '.', + [&](const char *b, const char *e) { + pattern_components.emplace_back(b, e); + }); + } + + std::vector host_components; + if (!hostname.empty()) { + split(hostname.data(), hostname.data() + hostname.size(), '.', + [&](const char *b, const char *e) { + host_components.emplace_back(b, e); + }); + } + + // Component count must match + if (host_components.size() != pattern_components.size()) { return false; } + + // Compare each component with wildcard support + // Supports: "*" (full wildcard), "prefix*" (partial wildcard) + // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 + auto itr = pattern_components.begin(); + for (const auto &h : host_components) { + auto &p = *itr; + if (!detail::case_ignore::equal(p, h) && p != "*") { + bool partial_match = false; + if (!p.empty() && p[p.size() - 1] == '*') { + const auto prefix_length = p.size() - 1; + if (prefix_length == 0) { + partial_match = true; + } else if (h.size() >= prefix_length) { + partial_match = + std::equal(p.begin(), + p.begin() + static_cast( + prefix_length), + h.begin(), [](const char ca, const char cb) { + return detail::case_ignore::to_lower(ca) == + detail::case_ignore::to_lower(cb); + }); + } + } + if (!partial_match) { return false; } + } + ++itr; + } + + return true; +} + +#ifdef _WIN32 +// Verify certificate using Windows CertGetCertificateChain API. +// This provides real-time certificate validation with Windows Update +// integration, independent of the TLS backend (OpenSSL or MbedTLS). +bool verify_cert_with_windows_schannel( + const std::vector &der_cert, const std::string &hostname, + bool verify_hostname, unsigned long &out_error) { + if (der_cert.empty()) { return false; } + + out_error = 0; + + // Create Windows certificate context from DER data + auto cert_context = CertCreateCertificateContext( + X509_ASN_ENCODING | PKCS_7_ASN_ENCODING, der_cert.data(), + static_cast(der_cert.size())); + + if (!cert_context) { + out_error = GetLastError(); + return false; + } + + auto cert_guard = + scope_exit([&] { CertFreeCertificateContext(cert_context); }); + + // Setup chain parameters + CERT_CHAIN_PARA chain_para = {}; + chain_para.cbSize = sizeof(chain_para); + + // Build certificate chain with revocation checking + PCCERT_CHAIN_CONTEXT chain_context = nullptr; + auto chain_result = CertGetCertificateChain( + nullptr, cert_context, nullptr, cert_context->hCertStore, &chain_para, + CERT_CHAIN_CACHE_END_CERT | CERT_CHAIN_REVOCATION_CHECK_END_CERT | + CERT_CHAIN_REVOCATION_ACCUMULATIVE_TIMEOUT, + nullptr, &chain_context); + + if (!chain_result || !chain_context) { + out_error = GetLastError(); + return false; + } + + auto chain_guard = + scope_exit([&] { CertFreeCertificateChain(chain_context); }); + + // Check if chain has errors + if (chain_context->TrustStatus.dwErrorStatus != CERT_TRUST_NO_ERROR) { + out_error = chain_context->TrustStatus.dwErrorStatus; + return false; + } + + // Verify SSL policy + SSL_EXTRA_CERT_CHAIN_POLICY_PARA extra_policy_para = {}; + extra_policy_para.cbSize = sizeof(extra_policy_para); +#ifdef AUTHTYPE_SERVER + extra_policy_para.dwAuthType = AUTHTYPE_SERVER; +#endif + + std::wstring whost; + if (verify_hostname) { + whost = u8string_to_wstring(hostname.c_str()); + extra_policy_para.pwszServerName = const_cast(whost.c_str()); + } + + CERT_CHAIN_POLICY_PARA policy_para = {}; + policy_para.cbSize = sizeof(policy_para); +#ifdef CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS + policy_para.dwFlags = CERT_CHAIN_POLICY_IGNORE_ALL_REV_UNKNOWN_FLAGS; +#else + policy_para.dwFlags = 0; +#endif + policy_para.pvExtraPolicyPara = &extra_policy_para; + + CERT_CHAIN_POLICY_STATUS policy_status = {}; + policy_status.cbSize = sizeof(policy_status); + + if (!CertVerifyCertificateChainPolicy(CERT_CHAIN_POLICY_SSL, chain_context, + &policy_para, &policy_status)) { + out_error = GetLastError(); + return false; + } + + if (policy_status.dwError != 0) { + out_error = policy_status.dwError; + return false; + } + + return true; +} +#endif // _WIN32 } // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 3: httplib namespace - Non-SSL public API implementations + */ + +void default_socket_options(socket_t sock) { + detail::set_socket_opt(sock, SOL_SOCKET, +#ifdef SO_REUSEPORT + SO_REUSEPORT, +#else + SO_REUSEADDR, +#endif + 1); +} + +std::string get_bearer_token_auth(const Request &req) { + if (req.has_header("Authorization")) { + constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); + return req.get_header_value("Authorization") + .substr(bearer_header_prefix_len); + } + return ""; +} const char *status_message(int status) { switch (status) { @@ -4426,6 +5023,11 @@ make_bearer_token_authentication_header(const std::string &token, } // Request implementation +size_t Request::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Request::has_header(const std::string &key) const { return detail::has_header(headers, key); } @@ -4547,6 +5149,11 @@ size_t MultipartFormData::get_file_count(const std::string &key) const { } // Response implementation +size_t Response::get_header_value_u64(const std::string &key, size_t def, + size_t id) const { + return detail::get_header_value_u64(headers, key, def, id); +} + bool Response::has_header(const std::string &key) const { return headers.find(key) != headers.end(); } @@ -4662,6 +5269,12 @@ void Response::set_file_content(const std::string &path) { } // Result implementation +size_t Result::get_request_header_value_u64(const std::string &key, + size_t def, + size_t id) const { + return detail::get_header_value_u64(request_headers_, key, def, id); +} + bool Result::has_request_header(const std::string &key) const { return request_headers_.find(key) != request_headers_.end(); } @@ -4697,13 +5310,16 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { if (!chunked) { // Content-Length based reading - if (bytes_read >= content_length) { + if (has_content_length && bytes_read >= content_length) { eof = true; return 0; } - auto remaining = content_length - bytes_read; - auto to_read = (std::min)(len, remaining); + auto to_read = len; + if (has_content_length) { + auto remaining = content_length - bytes_read; + to_read = (std::min)(len, remaining); + } auto n = stream->read(buf, to_read); if (n < 0) { @@ -4721,7 +5337,12 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } bytes_read += static_cast(n); - if (bytes_read >= content_length) { eof = true; } + if (has_content_length && bytes_read >= content_length) { eof = true; } + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } return n; } @@ -4745,9 +5366,83 @@ ssize_t detail::BodyReader::read(char *buf, size_t len) { } bytes_read += static_cast(n); + if (payload_max_length > 0 && bytes_read > payload_max_length) { + last_error = Error::ExceedMaxPayloadSize; + eof = true; + return -1; + } return n; } +// ThreadPool implementation +ThreadPool::ThreadPool(size_t n, size_t mqr) + : shutdown_(false), max_queued_requests_(mqr) { + threads_.reserve(n); + while (n) { + threads_.emplace_back(worker(*this)); + n--; + } +} + +bool ThreadPool::enqueue(std::function fn) { + { + std::unique_lock lock(mutex_); + if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { + return false; + } + jobs_.push_back(std::move(fn)); + } + + cond_.notify_one(); + return true; +} + +void ThreadPool::shutdown() { + // Stop all worker threads... + { + std::unique_lock lock(mutex_); + shutdown_ = true; + } + + cond_.notify_all(); + + // Join... + for (auto &t : threads_) { + t.join(); + } +} + +ThreadPool::worker::worker(ThreadPool &pool) : pool_(pool) {} + +void ThreadPool::worker::operator()() { + for (;;) { + std::function fn; + { + std::unique_lock lock(pool_.mutex_); + + pool_.cond_.wait(lock, + [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); + + if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } + + fn = pool_.jobs_.front(); + pool_.jobs_.pop_front(); + } + + assert(true == static_cast(fn)); + fn(); + } + +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ + !defined(LIBRESSL_VERSION_NUMBER) + OPENSSL_thread_stop(); +#endif +} + +/* + * Group 1 (continued): detail namespace - Stream implementations + */ + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -5076,6 +5771,155 @@ bool check_and_write_headers(Stream &strm, Headers &headers, } // namespace detail +/* + * Group 2 (continued): detail namespace - SSLSocketStream implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +namespace detail { + +// SSL socket stream implementation +SSLSocketStream::SSLSocketStream( + socket_t sock, tls::session_t session, time_t read_timeout_sec, + time_t read_timeout_usec, time_t write_timeout_sec, + time_t write_timeout_usec, time_t max_timeout_msec, + std::chrono::time_point start_time) + : sock_(sock), session_(session), read_timeout_sec_(read_timeout_sec), + read_timeout_usec_(read_timeout_usec), + write_timeout_sec_(write_timeout_sec), + write_timeout_usec_(write_timeout_usec), + max_timeout_msec_(max_timeout_msec), start_time_(start_time) { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Clear AUTO_RETRY for proper non-blocking I/O timeout handling + // Note: create_session() also clears this, but SSLClient currently + // uses ssl_new() which does not. Until full TLS API migration is complete, + // we need to ensure AUTO_RETRY is cleared here regardless of how the + // SSL session was created. + SSL_clear_mode(static_cast(session), SSL_MODE_AUTO_RETRY); +#endif +} + +SSLSocketStream::~SSLSocketStream() = default; + +bool SSLSocketStream::is_readable() const { + return tls::pending(session_) > 0; +} + +bool SSLSocketStream::wait_readable() const { + if (max_timeout_msec_ <= 0) { + return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; + } + + time_t read_timeout_sec; + time_t read_timeout_usec; + calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, + read_timeout_usec_, read_timeout_sec, read_timeout_usec); + + return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; +} + +bool SSLSocketStream::wait_writable() const { + return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && + is_socket_alive(sock_) && !tls::is_peer_closed(session_, sock_); +} + +ssize_t SSLSocketStream::read(char *ptr, size_t size) { + if (tls::pending(session_) > 0) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else if (wait_readable()) { + tls::TlsError err; + auto ret = tls::read(session_, ptr, size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantRead || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantRead) { +#endif + if (tls::pending(session_) > 0) { + return tls::read(session_, ptr, size, err); + } else if (wait_readable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::read(session_, ptr, size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } else if (ret == 0 || err.code == tls::ErrorCode::PeerClosed) { + error_ = Error::ConnectionClosed; + } + return ret; + } else { + error_ = Error::Timeout; + return -1; + } +} + +ssize_t SSLSocketStream::write(const char *ptr, size_t size) { + if (wait_writable()) { + auto handle_size = + std::min(size, (std::numeric_limits::max)()); + + tls::TlsError err; + auto ret = tls::write(session_, ptr, handle_size, err); + if (ret < 0) { + auto n = 1000; +#ifdef _WIN32 + while (--n >= 0 && (err.code == tls::ErrorCode::WantWrite || + (err.code == tls::ErrorCode::SyscallError && + WSAGetLastError() == WSAETIMEDOUT))) { +#else + while (--n >= 0 && err.code == tls::ErrorCode::WantWrite) { +#endif + if (wait_writable()) { + std::this_thread::sleep_for(std::chrono::microseconds{10}); + ret = tls::write(session_, ptr, handle_size, err); + if (ret >= 0) { return ret; } + } else { + break; + } + } + assert(ret < 0); + } + return ret; + } + return -1; +} + +void SSLSocketStream::get_remote_ip_and_port(std::string &ip, + int &port) const { + detail::get_remote_ip_and_port(sock_, ip, port); +} + +void SSLSocketStream::get_local_ip_and_port(std::string &ip, + int &port) const { + detail::get_local_ip_and_port(sock_, ip, port); +} + +socket_t SSLSocketStream::socket() const { return sock_; } + +time_t SSLSocketStream::duration() const { + return std::chrono::duration_cast( + std::chrono::steady_clock::now() - start_time_) + .count(); +} + +} // namespace detail +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 4: Server implementation + */ + // HTTP server implementation Server::Server() : new_task_queue( @@ -5677,36 +6521,40 @@ bool Server::read_content_core( // are true (no Transfer-Encoding and no Content-Length), then the message // body length is zero (no message body is present). // - // For non-SSL builds, peek into the socket to detect clients that send a - // body without a Content-Length header (raw HTTP over TCP). If there is - // pending data that exceeds the configured payload limit, treat this as an - // oversized request and fail early (causing connection close). For SSL - // builds we cannot reliably peek the decrypted application bytes, so keep - // the original behaviour. -#if !defined(CPPHTTPLIB_OPENSSL_SUPPORT) + // For non-SSL builds, detect clients that send a body without a + // Content-Length header (raw HTTP over TCP). Check both the stream's + // internal read buffer (data already read from the socket during header + // parsing) and the socket itself for pending data. If data is found and + // exceeds the configured payload limit, reject with 413. + // For SSL builds we cannot reliably peek the decrypted application bytes, + // so keep the original behaviour. +#if !defined(CPPHTTPLIB_SSL_ENABLED) if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { - // Only peek if payload_max_length is set to a finite value + // Only check if payload_max_length is set to a finite value if (payload_max_length_ > 0 && payload_max_length_ < (std::numeric_limits::max)()) { - socket_t s = strm.socket(); - if (s != INVALID_SOCKET) { - // Peek to check if there is any pending data - char peekbuf[1]; - ssize_t n = ::recv(s, peekbuf, 1, MSG_PEEK); - if (n > 0) { - // There is data, so read it with payload limit enforcement - auto result = detail::read_content_without_length( - strm, payload_max_length_, out); - if (result == detail::ReadContentResult::PayloadTooLarge) { - res.status = StatusCode::PayloadTooLarge_413; - return false; - } else if (result != detail::ReadContentResult::Success) { - return false; - } - return true; + // Check if there is data already buffered in the stream (read during + // header parsing) or pending on the socket. Use a non-blocking socket + // check to avoid deadlock when the client sends no body. + bool has_data = strm.is_readable(); + if (!has_data) { + socket_t s = strm.socket(); + if (s != INVALID_SOCKET) { + has_data = detail::select_read(s, 0, 0) > 0; } } + if (has_data) { + auto result = + detail::read_content_without_length(strm, payload_max_length_, out); + if (result == detail::ReadContentResult::PayloadTooLarge) { + res.status = StatusCode::PayloadTooLarge_413; + return false; + } else if (result != detail::ReadContentResult::Success) { + return false; + } + return true; + } } return true; } @@ -5815,8 +6663,10 @@ bool Server::check_if_not_modified(const Request &req, Response &res, // simplified implementation requires exact matches. auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', [&](const char *b, const char *e) { - return std::equal(b, e, "*") || - std::equal(b, e, etag.begin()); + auto seg_len = static_cast(e - b); + return (seg_len == 1 && *b == '*') || + (seg_len == etag.size() && + std::equal(b, e, etag.begin())); }); if (ret) { @@ -6518,6 +7368,9 @@ void Server::output_error_log(const Error &err, } } +/* + * Group 5: ClientImpl and Client (Universal) implementation + */ // HTTP client implementation ClientImpl::ClientImpl(const std::string &host) : ClientImpl(host, 80, std::string(), std::string()) {} @@ -6561,10 +7414,6 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { basic_auth_username_ = rhs.basic_auth_username_; basic_auth_password_ = rhs.basic_auth_password_; bearer_token_auth_token_ = rhs.bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - digest_auth_username_ = rhs.digest_auth_username_; - digest_auth_password_ = rhs.digest_auth_password_; -#endif keep_alive_ = rhs.keep_alive_; follow_location_ = rhs.follow_location_; path_encode_ = rhs.path_encode_; @@ -6574,28 +7423,27 @@ void ClientImpl::copy_settings(const ClientImpl &rhs) { socket_options_ = rhs.socket_options_; compress_ = rhs.compress_; decompress_ = rhs.decompress_; + payload_max_length_ = rhs.payload_max_length_; + has_payload_max_length_ = rhs.has_payload_max_length_; interface_ = rhs.interface_; proxy_host_ = rhs.proxy_host_; proxy_port_ = rhs.proxy_port_; proxy_basic_auth_username_ = rhs.proxy_basic_auth_username_; proxy_basic_auth_password_ = rhs.proxy_basic_auth_password_; proxy_bearer_token_auth_token_ = rhs.proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; - proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ca_cert_file_path_ = rhs.ca_cert_file_path_; - ca_cert_dir_path_ = rhs.ca_cert_dir_path_; - ca_cert_store_ = rhs.ca_cert_store_; -#endif -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - server_certificate_verification_ = rhs.server_certificate_verification_; - server_hostname_verification_ = rhs.server_hostname_verification_; - server_certificate_verifier_ = rhs.server_certificate_verifier_; -#endif logger_ = rhs.logger_; error_logger_ = rhs.error_logger_; + +#ifdef CPPHTTPLIB_SSL_ENABLED + digest_auth_username_ = rhs.digest_auth_username_; + digest_auth_password_ = rhs.digest_auth_password_; + proxy_digest_auth_username_ = rhs.proxy_digest_auth_username_; + proxy_digest_auth_password_ = rhs.proxy_digest_auth_password_; + ca_cert_file_path_ = rhs.ca_cert_file_path_; + ca_cert_dir_path_ = rhs.ca_cert_dir_path_; + server_certificate_verification_ = rhs.server_certificate_verification_; + server_hostname_verification_ = rhs.server_hostname_verification_; +#endif } socket_t ClientImpl::create_client_socket(Error &error) const { @@ -6631,22 +7479,6 @@ bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { return create_and_connect_socket(socket, error); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { - if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } - - if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } - - if (!initialize_ssl(socket, error)) { - shutdown_socket(socket); - close_socket(socket); - return false; - } - - return true; -} -#endif - void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any requests in flight from threads other than us, then it's @@ -6671,9 +7503,10 @@ void ClientImpl::close_socket(Socket &socket) { socket_requests_are_from_thread_ == std::this_thread::get_id()); // It is also a bug if this happens while SSL is still active -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED assert(socket.ssl == nullptr); #endif + if (socket.sock == INVALID_SOCKET) { return; } detail::close_socket(socket.sock); socket.sock = INVALID_SOCKET; @@ -6722,6 +7555,8 @@ bool ClientImpl::send(Request &req, Response &res, Error &error) { if (error == Error::SSLPeerCouldBeClosed_) { assert(!ret); ret = send_(req, res, error); + // If still failing with SSLPeerCouldBeClosed_, convert to Read error + if (error == Error::SSLPeerCouldBeClosed_) { error = Error::Read; } } return ret; } @@ -6739,9 +7574,9 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { is_alive = false; } } @@ -6765,7 +7600,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { return false; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED // TODO: refactoring if (is_ssl()) { auto &scli = static_cast(*this); @@ -6847,9 +7682,9 @@ Result ClientImpl::send_(Request &&req) { auto res = detail::make_unique(); auto error = Error::Success; auto ret = send(req, *res, error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers), - last_ssl_error_, last_openssl_error_}; + last_ssl_error_, last_backend_error_}; #else return Result{ret ? std::move(res) : nullptr, error, std::move(req.headers)}; #endif @@ -6926,9 +7761,9 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, auto is_alive = false; if (socket_.is_open()) { is_alive = detail::is_socket_alive(socket_.sock); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_alive && is_ssl()) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { is_alive = false; } } @@ -6946,7 +7781,7 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, return handle; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (is_ssl()) { auto &scli = static_cast(*this); if (!proxy_host_.empty() && proxy_port_ != -1) { @@ -6962,11 +7797,12 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, transfer_socket_ownership_to_handle(handle); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl() && handle.connection_->ssl) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && handle.connection_->session) { handle.socket_stream_ = detail::make_unique( - handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, - read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + handle.connection_->sock, handle.connection_->session, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_); } else { handle.socket_stream_ = detail::make_unique( handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, @@ -7016,9 +7852,11 @@ ClientImpl::open_stream(const std::string &method, const std::string &path, } handle.body_reader_.stream = handle.stream_; + handle.body_reader_.payload_max_length = payload_max_length_; auto content_length_str = handle.response->get_header_value("Content-Length"); if (!content_length_str.empty()) { + handle.body_reader_.has_content_length = true; handle.body_reader_.content_length = static_cast(std::stoull(content_length_str)); } @@ -7066,6 +7904,7 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, auto to_copy = (std::min)(len, available); std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); decompress_offset_ += to_copy; + decompressed_bytes_read_ += to_copy; return static_cast(to_copy); } @@ -7081,12 +7920,16 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, if (n <= 0) { return n; } - bool decompress_ok = - decompressor_->decompress(compressed_buf, static_cast(n), - [this](const char *data, size_t data_len) { - decompress_buffer_.append(data, data_len); - return true; - }); + bool decompress_ok = decompressor_->decompress( + compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + auto limit = body_reader_.payload_max_length; + if (decompressed_bytes_read_ + decompress_buffer_.size() > limit) { + return false; + } + return true; + }); if (!decompress_ok) { body_reader_.last_error = Error::Read; @@ -7099,6 +7942,7 @@ ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, auto to_copy = (std::min)(len, decompress_buffer_.size()); std::memcpy(buf, decompress_buffer_.data(), to_copy); decompress_offset_ = to_copy; + decompressed_bytes_read_ += to_copy; return static_cast(to_copy); } @@ -7121,7 +7965,6 @@ void ClientImpl::StreamHandle::parse_trailers_if_needed() { } } -// Inline method implementations for `ChunkedDecoder`. namespace detail { ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} @@ -7185,8 +8028,8 @@ bool ChunkedDecoder::parse_trailers_into(Headers &dest, void ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { handle.connection_->sock = socket_.sock; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - handle.connection_->ssl = socket_.ssl; +#ifdef CPPHTTPLIB_SSL_ENABLED + handle.connection_->session = socket_.ssl; socket_.ssl = nullptr; #endif socket_.sock = INVALID_SOCKET; @@ -7239,7 +8082,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, ret = redirect(req, res, error); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if ((res.status == StatusCode::Unauthorized_401 || res.status == StatusCode::ProxyAuthenticationRequired_407) && req.authorization_count_ < 5) { @@ -7343,7 +8186,7 @@ bool ClientImpl::create_redirect_client( // Create appropriate client type and handle redirect if (need_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED // Create SSL client for HTTPS redirect SSLClient redirect_client(host, port); @@ -7363,9 +8206,10 @@ bool ClientImpl::create_redirect_client( server_hostname_verification_); } - // Handle CA certificate store and paths if available - if (ca_cert_store_ && X509_STORE_up_ref(ca_cert_store_)) { - redirect_client.set_ca_cert_store(ca_cert_store_); + // Transfer CA certificate to redirect client + if (!ca_cert_pem_.empty()) { + redirect_client.load_ca_cert_store(ca_cert_pem_.c_str(), + ca_cert_pem_.size()); } if (!ca_cert_file_path_.empty()) { redirect_client.set_ca_cert_path(ca_cert_file_path_, ca_cert_dir_path_); @@ -7418,7 +8262,7 @@ void ClientImpl::setup_redirect_client(ClientType &client) { if (!bearer_token_auth_token_.empty()) { client.set_bearer_token_auth(bearer_token_auth_token_); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!digest_auth_username_.empty()) { client.set_digest_auth(digest_auth_username_, digest_auth_password_); } @@ -7438,7 +8282,7 @@ void ClientImpl::setup_redirect_client(ClientType &client) { if (!proxy_bearer_token_auth_token_.empty()) { client.set_proxy_bearer_token_auth(proxy_bearer_token_auth_token_); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!proxy_digest_auth_username_.empty()) { client.set_proxy_digest_auth(proxy_digest_auth_username_, proxy_digest_auth_password_); @@ -7809,9 +8653,9 @@ Result ClientImpl::send_with_content_provider_and_receiver( std::move(content_provider_without_length), content_type, std::move(content_receiver), error); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED return Result{std::move(res), error, std::move(req.headers), last_ssl_error_, - last_openssl_error_}; + last_backend_error_}; #else return Result{std::move(res), error, std::move(req.headers)}; #endif @@ -7851,11 +8695,11 @@ bool ClientImpl::process_request(Stream &strm, Request &req, auto write_request_success = write_request(strm, req, close_connection, error, expect_100_continue); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (is_ssl()) { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (is_ssl() && !expect_100_continue) { auto is_proxy_enabled = !proxy_host_.empty() && proxy_port_ != -1; if (!is_proxy_enabled) { - if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + if (tls::is_peer_closed(socket_.ssl, socket_.sock)) { error = Error::SSLPeerCouldBeClosed_; output_error_log(error, &req); return false; @@ -7937,6 +8781,11 @@ bool ClientImpl::process_request(Stream &strm, Request &req, [&](const char *buf, size_t n, size_t /*off*/, size_t /*len*/) { assert(res.body.size() + n <= res.body.max_size()); + if (payload_max_length_ > 0 && + (res.body.size() >= payload_max_length_ || + n > payload_max_length_ - res.body.size())) { + return false; + } res.body.append(buf, n); return true; }); @@ -7965,9 +8814,12 @@ bool ClientImpl::process_request(Stream &strm, Request &req, if (res.status != StatusCode::NotModified_304) { int dummy_status; - if (!detail::read_content(strm, res, (std::numeric_limits::max)(), - dummy_status, std::move(progress), - std::move(out), decompress_)) { + auto max_length = (!has_payload_max_length_ && req.content_receiver) + ? (std::numeric_limits::max)() + : payload_max_length_; + if (!detail::read_content(strm, res, max_length, dummy_status, + std::move(progress), std::move(out), + decompress_)) { if (error != Error::Canceled) { error = Error::Read; } output_error_log(error, &req); return false; @@ -8878,14 +9730,6 @@ void ClientImpl::set_bearer_token_auth(const std::string &token) { bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_digest_auth(const std::string &username, - const std::string &password) { - digest_auth_username_ = username; - digest_auth_password_ = password; -} -#endif - void ClientImpl::set_keep_alive(bool on) { keep_alive_ = on; } void ClientImpl::set_follow_location(bool on) { follow_location_ = on; } @@ -8922,6 +9766,11 @@ void ClientImpl::set_compress(bool on) { compress_ = on; } void ClientImpl::set_decompress(bool on) { decompress_ = on; } +void ClientImpl::set_payload_max_length(size_t length) { + payload_max_length_ = length; + has_payload_max_length_ = true; +} + void ClientImpl::set_interface(const std::string &intf) { interface_ = intf; } @@ -8941,11 +9790,11 @@ void ClientImpl::set_proxy_bearer_token_auth(const std::string &token) { proxy_bearer_token_auth_token_ = token; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void ClientImpl::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - proxy_digest_auth_username_ = username; - proxy_digest_auth_password_ = password; +#ifdef CPPHTTPLIB_SSL_ENABLED +void ClientImpl::set_digest_auth(const std::string &username, + const std::string &password) { + digest_auth_username_ = username; + digest_auth_password_ = password; } void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, @@ -8954,12 +9803,23 @@ void ClientImpl::set_ca_cert_path(const std::string &ca_cert_file_path, ca_cert_dir_path_ = ca_cert_dir_path; } -void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store && ca_cert_store != ca_cert_store_) { - ca_cert_store_ = ca_cert_store; - } +void ClientImpl::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + proxy_digest_auth_username_ = username; + proxy_digest_auth_password_ = password; } +void ClientImpl::enable_server_certificate_verification(bool enabled) { + server_certificate_verification_ = enabled; +} + +void ClientImpl::enable_server_hostname_verification(bool enabled) { + server_hostname_verification_ = enabled; +} +#endif + +// ClientImpl::set_ca_cert_store is defined after TLS namespace (uses helpers) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, std::size_t size) const { auto mem = BIO_new_mem_buf(ca_cert, static_cast(size)); @@ -8984,17 +9844,9 @@ X509_STORE *ClientImpl::create_ca_cert_store(const char *ca_cert, return cts; } -void ClientImpl::enable_server_certificate_verification(bool enabled) { - server_certificate_verification_ = enabled; -} - -void ClientImpl::enable_server_hostname_verification(bool enabled) { - server_hostname_verification_ = enabled; -} - void ClientImpl::set_server_certificate_verifier( - std::function verifier) { - server_certificate_verifier_ = verifier; + std::function /*verifier*/) { + // Base implementation does nothing - SSLClient overrides this } #endif @@ -9007,958 +9859,24 @@ void ClientImpl::set_error_logger(ErrorLogger error_logger) { } /* - * SSL Implementation + * SSL/TLS Common Implementation */ -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -namespace detail { -bool is_ip_address(const std::string &host) { - struct in_addr addr4; - struct in6_addr addr6; - return inet_pton(AF_INET, host.c_str(), &addr4) == 1 || - inet_pton(AF_INET6, host.c_str(), &addr6) == 1; -} - -template -SSL *ssl_new(socket_t sock, SSL_CTX *ctx, std::mutex &ctx_mutex, - U SSL_connect_or_accept, V setup) { - SSL *ssl = nullptr; - { - std::lock_guard guard(ctx_mutex); - ssl = SSL_new(ctx); - } - - if (ssl) { - set_nonblocking(sock, true); - auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); - BIO_set_nbio(bio, 1); - SSL_set_bio(ssl, bio, bio); - - if (!setup(ssl) || SSL_connect_or_accept(ssl) != 1) { - SSL_shutdown(ssl); - { - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); - } - set_nonblocking(sock, false); - return nullptr; - } - BIO_set_nbio(bio, 0); - set_nonblocking(sock, false); - } - - return ssl; -} - -void ssl_delete(std::mutex &ctx_mutex, SSL *ssl, socket_t sock, - bool shutdown_gracefully) { - // sometimes we may want to skip this to try to avoid SIGPIPE if we know - // the remote has closed the network connection - // Note that it is not always possible to avoid SIGPIPE, this is merely a - // best-efforts. - if (shutdown_gracefully) { - (void)(sock); - // SSL_shutdown() returns 0 on first call (indicating close_notify alert - // sent) and 1 on subsequent call (indicating close_notify alert received) - if (SSL_shutdown(ssl) == 0) { - // Expected to return 1, but even if it doesn't, we free ssl - SSL_shutdown(ssl); - } - } - - std::lock_guard guard(ctx_mutex); - SSL_free(ssl); -} - -template -bool ssl_connect_or_accept_nonblocking(socket_t sock, SSL *ssl, - U ssl_connect_or_accept, - time_t timeout_sec, time_t timeout_usec, - int *ssl_error) { - auto res = 0; - while ((res = ssl_connect_or_accept(ssl)) != 1) { - auto err = SSL_get_error(ssl, res); - switch (err) { - case SSL_ERROR_WANT_READ: - if (select_read(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - case SSL_ERROR_WANT_WRITE: - if (select_write(sock, timeout_sec, timeout_usec) > 0) { continue; } - break; - default: break; - } - if (ssl_error) { *ssl_error = err; } - return false; - } - return true; -} - -template -bool process_server_socket_ssl( - const std::atomic &svr_sock, SSL *ssl, socket_t sock, - size_t keep_alive_max_count, time_t keep_alive_timeout_sec, - time_t read_timeout_sec, time_t read_timeout_usec, time_t write_timeout_sec, - time_t write_timeout_usec, T callback) { - return process_server_socket_core( - svr_sock, sock, keep_alive_max_count, keep_alive_timeout_sec, - [&](bool close_connection, bool &connection_closed) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec); - return callback(strm, close_connection, connection_closed); - }); -} - -template -bool process_client_socket_ssl( - SSL *ssl, socket_t sock, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time, T callback) { - SSLSocketStream strm(sock, ssl, read_timeout_sec, read_timeout_usec, - write_timeout_sec, write_timeout_usec, max_timeout_msec, - start_time); - return callback(strm); -} - -// SSL socket stream implementation -SSLSocketStream::SSLSocketStream( - socket_t sock, SSL *ssl, time_t read_timeout_sec, time_t read_timeout_usec, - time_t write_timeout_sec, time_t write_timeout_usec, - time_t max_timeout_msec, - std::chrono::time_point start_time) - : sock_(sock), ssl_(ssl), read_timeout_sec_(read_timeout_sec), - read_timeout_usec_(read_timeout_usec), - write_timeout_sec_(write_timeout_sec), - write_timeout_usec_(write_timeout_usec), - max_timeout_msec_(max_timeout_msec), start_time_(start_time) { - SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); -} - -SSLSocketStream::~SSLSocketStream() = default; - -bool SSLSocketStream::is_readable() const { - return SSL_pending(ssl_) > 0; -} - -bool SSLSocketStream::wait_readable() const { - if (max_timeout_msec_ <= 0) { - return select_read(sock_, read_timeout_sec_, read_timeout_usec_) > 0; - } - - time_t read_timeout_sec; - time_t read_timeout_usec; - calc_actual_timeout(max_timeout_msec_, duration(), read_timeout_sec_, - read_timeout_usec_, read_timeout_sec, read_timeout_usec); - - return select_read(sock_, read_timeout_sec, read_timeout_usec) > 0; -} - -bool SSLSocketStream::wait_writable() const { - return select_write(sock_, write_timeout_sec_, write_timeout_usec_) > 0 && - is_socket_alive(sock_) && !is_ssl_peer_could_be_closed(ssl_, sock_); -} - -ssize_t SSLSocketStream::read(char *ptr, size_t size) { - if (SSL_pending(ssl_) > 0) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret == 0) { error_ = Error::ConnectionClosed; } - return ret; - } else if (wait_readable()) { - auto ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_READ || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_READ) { -#endif - if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); - } else if (wait_readable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_read(ssl_, ptr, static_cast(size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } else if (ret == 0) { - error_ = Error::ConnectionClosed; - } - return ret; - } else { - error_ = Error::Timeout; - return -1; - } -} - -ssize_t SSLSocketStream::write(const char *ptr, size_t size) { - if (wait_writable()) { - auto handle_size = static_cast( - std::min(size, (std::numeric_limits::max)())); - - auto ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret < 0) { - auto err = SSL_get_error(ssl_, ret); - auto n = 1000; -#ifdef _WIN32 - while (--n >= 0 && (err == SSL_ERROR_WANT_WRITE || - (err == SSL_ERROR_SYSCALL && - WSAGetLastError() == WSAETIMEDOUT))) { -#else - while (--n >= 0 && err == SSL_ERROR_WANT_WRITE) { -#endif - if (wait_writable()) { - std::this_thread::sleep_for(std::chrono::microseconds{10}); - ret = SSL_write(ssl_, ptr, static_cast(handle_size)); - if (ret >= 0) { return ret; } - err = SSL_get_error(ssl_, ret); - } else { - break; - } - } - assert(ret < 0); - } - return ret; - } - return -1; -} - -void SSLSocketStream::get_remote_ip_and_port(std::string &ip, - int &port) const { - detail::get_remote_ip_and_port(sock_, ip, port); -} - -void SSLSocketStream::get_local_ip_and_port(std::string &ip, - int &port) const { - detail::get_local_ip_and_port(sock_, ip, port); -} - -socket_t SSLSocketStream::socket() const { return sock_; } - -time_t SSLSocketStream::duration() const { - return std::chrono::duration_cast( - std::chrono::steady_clock::now() - start_time_) - .count(); -} - -} // namespace detail - -// SSL HTTP server implementation -SSLServer::SSLServer(const char *cert_path, const char *private_key_path, - const char *client_ca_cert_file_path, - const char *client_ca_cert_dir_path, - const char *private_key_password) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (private_key_password != nullptr && (private_key_password[0] != '\0')) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, - reinterpret_cast(const_cast(private_key_password))); - } - - if (SSL_CTX_use_certificate_chain_file(ctx_, cert_path) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, private_key_path, SSL_FILETYPE_PEM) != - 1 || - SSL_CTX_check_private_key(ctx_) != 1) { - last_ssl_error_ = static_cast(ERR_get_error()); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_file_path || client_ca_cert_dir_path) { - SSL_CTX_load_verify_locations(ctx_, client_ca_cert_file_path, - client_ca_cert_dir_path); - - // Set client CA list to be sent to clients during TLS handshake - if (client_ca_cert_file_path) { - auto ca_list = SSL_load_client_CA_file(client_ca_cert_file_path); - if (ca_list != nullptr) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to load client CA list, but we continue since - // SSL_CTX_load_verify_locations already succeeded and - // certificate verification will still work - last_ssl_error_ = static_cast(ERR_get_error()); - } - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - ctx_ = SSL_CTX_new(TLS_server_method()); - - if (ctx_) { - SSL_CTX_set_options(ctx_, - SSL_OP_NO_COMPRESSION | - SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - if (SSL_CTX_use_certificate(ctx_, cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, private_key) != 1) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } else if (client_ca_cert_store) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - - // Extract CA names from the store and set them as the client CA list - auto ca_list = extract_ca_names_from_x509_store(client_ca_cert_store); - if (ca_list) { - SSL_CTX_set_client_CA_list(ctx_, ca_list); - } else { - // Failed to extract CA names, record the error - last_ssl_error_ = static_cast(ERR_get_error()); - } - - SSL_CTX_set_verify( - ctx_, SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT, nullptr); - } - } -} - -SSLServer::SSLServer( - const std::function &setup_ssl_ctx_callback) { - ctx_ = SSL_CTX_new(TLS_method()); - if (ctx_) { - if (!setup_ssl_ctx_callback(*ctx_)) { - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLServer::~SSLServer() { - if (ctx_) { SSL_CTX_free(ctx_); } -} - -bool SSLServer::is_valid() const { return ctx_; } - -SSL_CTX *SSLServer::ssl_context() const { return ctx_; } - -void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store) { - - std::lock_guard guard(ctx_mutex_); - - SSL_CTX_use_certificate(ctx_, cert); - SSL_CTX_use_PrivateKey(ctx_, private_key); - - if (client_ca_cert_store != nullptr) { - SSL_CTX_set_cert_store(ctx_, client_ca_cert_store); - } -} - -bool SSLServer::process_and_close_socket(socket_t sock) { - auto ssl = detail::ssl_new( - sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - return detail::ssl_connect_or_accept_nonblocking( - sock, ssl2, SSL_accept, read_timeout_sec_, read_timeout_usec_, - &last_ssl_error_); - }, - [](SSL * /*ssl2*/) { return true; }); - - auto ret = false; - if (ssl) { - std::string remote_addr; - int remote_port = 0; - detail::get_remote_ip_and_port(sock, remote_addr, remote_port); - - std::string local_addr; - int local_port = 0; - detail::get_local_ip_and_port(sock, local_addr, local_port); - - ret = detail::process_server_socket_ssl( - svr_sock_, ssl, sock, keep_alive_max_count_, keep_alive_timeout_sec_, - read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, - write_timeout_usec_, - [&](Stream &strm, bool close_connection, bool &connection_closed) { - return process_request(strm, remote_addr, remote_port, local_addr, - local_port, close_connection, - connection_closed, - [&](Request &req) { req.ssl = ssl; }); - }); - - // Shutdown gracefully if the result seemed successful, non-gracefully if - // the connection appeared to be closed. - const bool shutdown_gracefully = ret; - detail::ssl_delete(ctx_mutex_, ssl, sock, shutdown_gracefully); - } - - detail::shutdown_socket(sock); - detail::close_socket(sock); - return ret; -} - -STACK_OF(X509_NAME) * SSLServer::extract_ca_names_from_x509_store( - X509_STORE *store) { - if (!store) { return nullptr; } - - auto ca_list = sk_X509_NAME_new_null(); - if (!ca_list) { return nullptr; } - - // Get all objects from the store - auto objs = X509_STORE_get0_objects(store); - if (!objs) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - // Iterate through objects and extract certificate subject names - for (int i = 0; i < sk_X509_OBJECT_num(objs); i++) { - auto obj = sk_X509_OBJECT_value(objs, i); - if (X509_OBJECT_get_type(obj) == X509_LU_X509) { - auto cert = X509_OBJECT_get0_X509(obj); - if (cert) { - auto subject = X509_get_subject_name(cert); - if (subject) { - auto name_dup = X509_NAME_dup(subject); - if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } - } - } - } - } - - // If no names were extracted, free the list and return nullptr - if (sk_X509_NAME_num(ca_list) == 0) { - sk_X509_NAME_free(ca_list); - return nullptr; - } - - return ca_list; -} - -// SSL HTTP client implementation -SSLClient::SSLClient(const std::string &host) - : SSLClient(host, 443, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port) - : SSLClient(host, port, std::string(), std::string()) {} - -SSLClient::SSLClient(const std::string &host, int port, - const std::string &client_cert_path, - const std::string &client_key_path, - const std::string &private_key_password) - : ClientImpl(host, port, client_cert_path, client_key_path) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - SSL_CTX_set_min_proto_version(ctx_, TLS1_2_VERSION); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (!client_cert_path.empty() && !client_key_path.empty()) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate_file(ctx_, client_cert_path.c_str(), - SSL_FILETYPE_PEM) != 1 || - SSL_CTX_use_PrivateKey_file(ctx_, client_key_path.c_str(), - SSL_FILETYPE_PEM) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::SSLClient(const std::string &host, int port, - X509 *client_cert, EVP_PKEY *client_key, - const std::string &private_key_password) - : ClientImpl(host, port) { - ctx_ = SSL_CTX_new(TLS_client_method()); - - detail::split(&host_[0], &host_[host_.size()], '.', - [&](const char *b, const char *e) { - host_components_.emplace_back(b, e); - }); - - if (client_cert != nullptr && client_key != nullptr) { - if (!private_key_password.empty()) { - SSL_CTX_set_default_passwd_cb_userdata( - ctx_, reinterpret_cast( - const_cast(private_key_password.c_str()))); - } - - if (SSL_CTX_use_certificate(ctx_, client_cert) != 1 || - SSL_CTX_use_PrivateKey(ctx_, client_key) != 1) { - last_openssl_error_ = ERR_get_error(); - SSL_CTX_free(ctx_); - ctx_ = nullptr; - } - } -} - -SSLClient::~SSLClient() { - if (ctx_) { SSL_CTX_free(ctx_); } - // Make sure to shut down SSL since shutdown_ssl will resolve to the - // base function rather than the derived function once we get to the - // base class destructor, and won't free the SSL (causing a leak). - shutdown_ssl_impl(socket_, true); -} - -bool SSLClient::is_valid() const { return ctx_; } - -void SSLClient::set_ca_cert_store(X509_STORE *ca_cert_store) { - if (ca_cert_store) { - if (ctx_) { - if (SSL_CTX_get_cert_store(ctx_) != ca_cert_store) { - // Free memory allocated for old cert and use new store - // `ca_cert_store` - SSL_CTX_set_cert_store(ctx_, ca_cert_store); - ca_cert_store_ = ca_cert_store; - } - } else { - X509_STORE_free(ca_cert_store); - } - } -} - -void SSLClient::load_ca_cert_store(const char *ca_cert, - std::size_t size) { - set_ca_cert_store(ClientImpl::create_ca_cert_store(ca_cert, size)); -} - -long SSLClient::get_openssl_verify_result() const { - return verify_result_; -} - -SSL_CTX *SSLClient::ssl_context() const { return ctx_; } - -bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { - if (!is_valid()) { - error = Error::SSLConnection; - return false; - } - return ClientImpl::create_and_connect_socket(socket, error); -} - -// Assumes that socket_mutex_ is locked and that there are no requests in -// flight -bool SSLClient::connect_with_proxy( - Socket &socket, - std::chrono::time_point start_time, - Response &res, bool &success, Error &error) { - success = true; - Response proxy_res; - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req2; - req2.method = "CONNECT"; - req2.path = - detail::make_host_and_port_string_always_port(host_, port_); - if (max_timeout_msec_ > 0) { - req2.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req2, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are no - // requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - - if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { - if (!proxy_digest_auth_username_.empty() && - !proxy_digest_auth_password_.empty()) { - std::map auth; - if (detail::parse_www_authenticate(proxy_res, auth, true)) { - // Close the current socket and create a new one for the authenticated - // request - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - - // Create a new socket for the authenticated CONNECT request - if (!ensure_socket_connection(socket, error)) { - success = false; - output_error_log(error, nullptr); - return false; - } - - proxy_res = Response(); - if (!detail::process_client_socket( - socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, - start_time, [&](Stream &strm) { - Request req3; - req3.method = "CONNECT"; - req3.path = detail::make_host_and_port_string_always_port( - host_, port_); - req3.headers.insert(detail::make_digest_authentication_header( - req3, auth, 1, detail::random_string(10), - proxy_digest_auth_username_, proxy_digest_auth_password_, - true)); - if (max_timeout_msec_ > 0) { - req3.start_time_ = std::chrono::steady_clock::now(); - } - return process_request(strm, req3, proxy_res, false, error); - })) { - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - success = false; - return false; - } - } - } - } - - // If status code is not 200, proxy request is failed. - // Set error to ProxyConnection and return proxy response - // as the response of the request - if (proxy_res.status != StatusCode::OK_200) { - error = Error::ProxyConnection; - output_error_log(error, nullptr); - res = std::move(proxy_res); - // Thread-safe to close everything because we are assuming there are - // no requests in flight - shutdown_ssl(socket, true); - shutdown_socket(socket); - close_socket(socket); - return false; - } - - return true; -} - -bool SSLClient::load_certs() { - auto ret = true; - - std::call_once(initialize_cert_, [&]() { - std::lock_guard guard(ctx_mutex_); - if (!ca_cert_file_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, ca_cert_file_path_.c_str(), - nullptr)) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else if (!ca_cert_dir_path_.empty()) { - if (!SSL_CTX_load_verify_locations(ctx_, nullptr, - ca_cert_dir_path_.c_str())) { - last_openssl_error_ = ERR_get_error(); - ret = false; - } - } else if (!ca_cert_store_) { - auto loaded = false; -#ifdef _WIN32 - loaded = - detail::load_system_certs_on_windows(SSL_CTX_get_cert_store(ctx_)); -#elif defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) && TARGET_OS_MAC - loaded = detail::load_system_certs_on_macos(SSL_CTX_get_cert_store(ctx_)); -#endif // _WIN32 - if (!loaded) { SSL_CTX_set_default_verify_paths(ctx_); } - } - }); - - return ret; -} - -bool SSLClient::initialize_ssl(Socket &socket, Error &error) { - auto ssl = detail::ssl_new( - socket.sock, ctx_, ctx_mutex_, - [&](SSL *ssl2) { - if (server_certificate_verification_) { - if (!load_certs()) { - error = Error::SSLLoadingCerts; - output_error_log(error, nullptr); - return false; - } - SSL_set_verify(ssl2, SSL_VERIFY_NONE, nullptr); - } - - if (!detail::ssl_connect_or_accept_nonblocking( - socket.sock, ssl2, SSL_connect, connection_timeout_sec_, - connection_timeout_usec_, &last_ssl_error_)) { - error = Error::SSLConnection; - output_error_log(error, nullptr); - return false; - } - - if (server_certificate_verification_) { - auto verification_status = SSLVerifierResponse::NoDecisionMade; - - if (server_certificate_verifier_) { - verification_status = server_certificate_verifier_(ssl2); - } - - if (verification_status == SSLVerifierResponse::CertificateRejected) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (verification_status == SSLVerifierResponse::NoDecisionMade) { - verify_result_ = SSL_get_verify_result(ssl2); - - if (verify_result_ != X509_V_OK) { - last_openssl_error_ = static_cast(verify_result_); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - auto server_cert = SSL_get1_peer_certificate(ssl2); - auto se = detail::scope_exit([&] { X509_free(server_cert); }); - - if (server_cert == nullptr) { - last_openssl_error_ = ERR_get_error(); - error = Error::SSLServerVerification; - output_error_log(error, nullptr); - return false; - } - - if (server_hostname_verification_) { - if (!verify_host(server_cert)) { - last_openssl_error_ = X509_V_ERR_HOSTNAME_MISMATCH; - error = Error::SSLServerHostnameVerification; - output_error_log(error, nullptr); - return false; - } - } - } - } - - return true; - }, - [&](SSL *ssl2) { - // Set SNI only if host is not IP address - if (!detail::is_ip_address(host_)) { -#if defined(OPENSSL_IS_BORINGSSL) - SSL_set_tlsext_host_name(ssl2, host_.c_str()); -#else - // NOTE: Direct call instead of using the OpenSSL macro to suppress - // -Wold-style-cast warning - SSL_ctrl(ssl2, SSL_CTRL_SET_TLSEXT_HOSTNAME, - TLSEXT_NAMETYPE_host_name, - static_cast(const_cast(host_.c_str()))); -#endif - } - return true; - }); - - if (ssl) { - socket.ssl = ssl; - return true; - } - - if (ctx_ == nullptr) { - error = Error::SSLConnection; - last_openssl_error_ = ERR_get_error(); - } - - shutdown_socket(socket); - close_socket(socket); - return false; -} - -void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { - shutdown_ssl_impl(socket, shutdown_gracefully); -} - -void SSLClient::shutdown_ssl_impl(Socket &socket, - bool shutdown_gracefully) { - if (socket.sock == INVALID_SOCKET) { - assert(socket.ssl == nullptr); - return; - } - if (socket.ssl) { - detail::ssl_delete(ctx_mutex_, socket.ssl, socket.sock, - shutdown_gracefully); - socket.ssl = nullptr; - } - assert(socket.ssl == nullptr); -} - -bool SSLClient::process_socket( - const Socket &socket, - std::chrono::time_point start_time, - std::function callback) { - assert(socket.ssl); - return detail::process_client_socket_ssl( - socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, - write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, - std::move(callback)); -} - -bool SSLClient::is_ssl() const { return true; } - -bool SSLClient::verify_host(X509 *server_cert) const { - /* Quote from RFC2818 section 3.1 "Server Identity" - - If a subjectAltName extension of type dNSName is present, that MUST - be used as the identity. Otherwise, the (most specific) Common Name - field in the Subject field of the certificate MUST be used. Although - the use of the Common Name is existing practice, it is deprecated and - Certification Authorities are encouraged to use the dNSName instead. - - Matching is performed using the matching rules specified by - [RFC2459]. If more than one identity of a given type is present in - the certificate (e.g., more than one dNSName name, a match in any one - of the set is considered acceptable.) Names may contain the wildcard - character * which is considered to match any single domain name - component or component fragment. E.g., *.a.com matches foo.a.com but - not bar.foo.a.com. f*.com matches foo.com but not bar.com. - - In some cases, the URI is specified as an IP address rather than a - hostname. In this case, the iPAddress subjectAltName must be present - in the certificate and must exactly match the IP in the URI. - - */ - return verify_host_with_subject_alt_name(server_cert) || - verify_host_with_common_name(server_cert); -} - -bool -SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { - auto ret = false; - - auto type = GEN_DNS; - - struct in6_addr addr6 = {}; - struct in_addr addr = {}; - size_t addr_len = 0; - -#ifndef __MINGW32__ - if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { - type = GEN_IPADD; - addr_len = sizeof(struct in6_addr); - } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { - type = GEN_IPADD; - addr_len = sizeof(struct in_addr); +ClientConnection::~ClientConnection() { +#ifdef CPPHTTPLIB_SSL_ENABLED + if (session) { + tls::shutdown(session, true); + tls::free_session(session); + session = nullptr; } #endif - auto alt_names = static_cast( - X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); - - if (alt_names) { - auto dsn_matched = false; - auto ip_matched = false; - - auto count = sk_GENERAL_NAME_num(alt_names); - - for (decltype(count) i = 0; i < count && !dsn_matched; i++) { - auto val = sk_GENERAL_NAME_value(alt_names, i); - if (!val || val->type != type) { continue; } - - auto name = - reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); - if (name == nullptr) { continue; } - - auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); - - switch (type) { - case GEN_DNS: dsn_matched = check_host_name(name, name_len); break; - - case GEN_IPADD: - if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { - ip_matched = true; - } - break; - } - } - - if (dsn_matched || ip_matched) { ret = true; } + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; } - - GENERAL_NAMES_free(const_cast( - reinterpret_cast(alt_names))); - return ret; } -bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { - const auto subject_name = X509_get_subject_name(server_cert); - - if (subject_name != nullptr) { - char name[BUFSIZ]; - auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, - name, sizeof(name)); - - if (name_len != -1) { - return check_host_name(name, static_cast(name_len)); - } - } - - return false; -} - -bool SSLClient::check_host_name(const char *pattern, - size_t pattern_len) const { - // Exact match (case-insensitive) - if (host_.size() == pattern_len && - detail::case_ignore::equal(host_, std::string(pattern, pattern_len))) { - return true; - } - - // Wildcard match - // https://bugs.launchpad.net/ubuntu/+source/firefox-3.0/+bug/376484 - std::vector pattern_components; - detail::split(&pattern[0], &pattern[pattern_len], '.', - [&](const char *b, const char *e) { - pattern_components.emplace_back(b, e); - }); - - if (host_components_.size() != pattern_components.size()) { return false; } - - auto itr = pattern_components.begin(); - for (const auto &h : host_components_) { - auto &p = *itr; - if (!httplib::detail::case_ignore::equal(p, h) && p != "*") { - bool partial_match = false; - if (!p.empty() && p[p.size() - 1] == '*') { - const auto prefix_length = p.size() - 1; - if (prefix_length == 0) { - partial_match = true; - } else if (h.size() >= prefix_length) { - partial_match = - std::equal(p.begin(), - p.begin() + static_cast( - prefix_length), - h.begin(), [](const char ca, const char cb) { - return httplib::detail::case_ignore::to_lower(ca) == - httplib::detail::case_ignore::to_lower(cb); - }); - } - } - if (!partial_match) { return false; } - } - ++itr; - } - - return true; -} -#endif - // Universal client implementation Client::Client(const std::string &scheme_host_port) : Client(scheme_host_port, std::string(), std::string()) {} @@ -9973,7 +9891,7 @@ Client::Client(const std::string &scheme_host_port, if (std::regex_match(scheme_host_port, m, re)) { auto scheme = m[1].str(); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED if (!scheme.empty() && (scheme != "http" && scheme != "https")) { #else if (!scheme.empty() && scheme != "http") { @@ -9994,7 +9912,7 @@ Client::Client(const std::string &scheme_host_port, auto port = !port_str.empty() ? std::stoi(port_str) : (is_ssl ? 443 : 80); if (is_ssl) { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED cli_ = detail::make_unique(host, port, client_cert_path, client_key_path); is_ssl_ = is_ssl; @@ -10579,12 +10497,6 @@ void Client::set_basic_auth(const std::string &username, void Client::set_bearer_token_auth(const std::string &token) { cli_->set_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_digest_auth(username, password); -} -#endif void Client::set_keep_alive(bool on) { cli_->set_keep_alive(on); } void Client::set_follow_location(bool on) { @@ -10602,6 +10514,10 @@ void Client::set_compress(bool on) { cli_->set_compress(on); } void Client::set_decompress(bool on) { cli_->set_decompress(on); } +void Client::set_payload_max_length(size_t length) { + cli_->set_payload_max_length(length); +} + void Client::set_interface(const std::string &intf) { cli_->set_interface(intf); } @@ -10616,27 +10532,6 @@ void Client::set_proxy_basic_auth(const std::string &username, void Client::set_proxy_bearer_token_auth(const std::string &token) { cli_->set_proxy_bearer_token_auth(token); } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::set_proxy_digest_auth(const std::string &username, - const std::string &password) { - cli_->set_proxy_digest_auth(username, password); -} -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT -void Client::enable_server_certificate_verification(bool enabled) { - cli_->enable_server_certificate_verification(enabled); -} - -void Client::enable_server_hostname_verification(bool enabled) { - cli_->enable_server_hostname_verification(enabled); -} - -void Client::set_server_certificate_verifier( - std::function verifier) { - cli_->set_server_certificate_verifier(verifier); -} -#endif void Client::set_logger(Logger logger) { cli_->set_logger(std::move(logger)); @@ -10646,35 +10541,3399 @@ void Client::set_error_logger(ErrorLogger error_logger) { cli_->set_error_logger(std::move(error_logger)); } +/* + * Group 6: SSL Server and Client implementation + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +// SSL HTTP server implementation +SSLServer::SSLServer(const char *cert_path, const char *private_key_path, + const char *client_ca_cert_file_path, + const char *client_ca_cert_dir_path, + const char *private_key_password) { + using namespace tls; + + ctx_ = create_server_context(); + if (!ctx_) { return; } + + // Load server certificate and private key + if (!set_server_cert_file(ctx_, cert_path, private_key_path, + private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + + // Load client CA certificates for client authentication + if (client_ca_cert_file_path || client_ca_cert_dir_path) { + if (!set_client_ca_file(ctx_, client_ca_cert_file_path, + client_ca_cert_dir_path)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + return; + } + // Enable client certificate verification + set_verify_client(ctx_, true); + } +} + +SSLServer::SSLServer(const PemMemory &pem) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!set_server_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else if (pem.client_ca_pem && pem.client_ca_pem_len > 0) { + if (!load_ca_pem(ctx_, pem.client_ca_pem, pem.client_ca_pem_len)) { + last_ssl_error_ = static_cast(get_error()); + free_context(ctx_); + ctx_ = nullptr; + } else { + set_verify_client(ctx_, true); + } + } + } +} + +SSLServer::SSLServer(const tls::ContextSetupCallback &setup_callback) { + using namespace tls; + ctx_ = create_server_context(); + if (ctx_) { + if (!setup_callback(ctx_)) { + free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSLServer::~SSLServer() { + if (ctx_) { tls::free_context(ctx_); } +} + +bool SSLServer::is_valid() const { return ctx_ != nullptr; } + +bool SSLServer::process_and_close_socket(socket_t sock) { + using namespace tls; + + // Create TLS session with mutex protection + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(static_cast(ctx_), sock); + } + + if (!session) { + last_ssl_error_ = static_cast(get_error()); + detail::shutdown_socket(sock); + detail::close_socket(sock); + return false; + } + + // Use scope_exit to ensure cleanup on all paths (including exceptions) + bool handshake_done = false; + bool ret = false; + auto cleanup = detail::scope_exit([&] { + // Shutdown gracefully if handshake succeeded and processing was successful + if (handshake_done) { shutdown(session, ret); } + free_session(session); + detail::shutdown_socket(sock); + detail::close_socket(sock); + }); + + // Perform TLS accept handshake with timeout + TlsError tls_err; + if (!accept_nonblocking(session, sock, read_timeout_sec_, read_timeout_usec_, + &tls_err)) { #ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Map TlsError to legacy ssl_error for backward compatibility + if (tls_err.code == ErrorCode::WantRead) { + last_ssl_error_ = SSL_ERROR_WANT_READ; + } else if (tls_err.code == ErrorCode::WantWrite) { + last_ssl_error_ = SSL_ERROR_WANT_WRITE; + } else { + last_ssl_error_ = SSL_ERROR_SSL; + } +#else + last_ssl_error_ = static_cast(get_error()); +#endif + return false; + } + + handshake_done = true; + + std::string remote_addr; + int remote_port = 0; + detail::get_remote_ip_and_port(sock, remote_addr, remote_port); + + std::string local_addr; + int local_port = 0; + detail::get_local_ip_and_port(sock, local_addr, local_port); + + ret = detail::process_server_socket_ssl( + svr_sock_, session, sock, keep_alive_max_count_, keep_alive_timeout_sec_, + read_timeout_sec_, read_timeout_usec_, write_timeout_sec_, + write_timeout_usec_, + [&](Stream &strm, bool close_connection, bool &connection_closed) { + return process_request(strm, remote_addr, remote_port, local_addr, + local_port, close_connection, connection_closed, + [&](Request &req) { req.ssl = session; }); + }); + + return ret; +} + +bool SSLServer::update_certs_pem(const char *cert_pem, + const char *key_pem, + const char *client_ca_pem, + const char *password) { + if (!ctx_) { return false; } + std::lock_guard guard(ctx_mutex_); + if (!tls::update_server_cert(ctx_, cert_pem, key_pem, password)) { + return false; + } + if (client_ca_pem) { + return tls::update_server_client_ca(ctx_, client_ca_pem); + } + return true; +} + +// SSL HTTP client implementation +SSLClient::~SSLClient() { + if (ctx_) { tls::free_context(ctx_); } + // Make sure to shut down SSL since shutdown_ssl will resolve to the + // base function rather than the derived function once we get to the + // base class destructor, and won't free the SSL (causing a leak). + shutdown_ssl_impl(socket_, true); +} + +bool SSLClient::is_valid() const { return ctx_ != nullptr; } + +void SSLClient::shutdown_ssl(Socket &socket, bool shutdown_gracefully) { + shutdown_ssl_impl(socket, shutdown_gracefully); +} + +void SSLClient::shutdown_ssl_impl(Socket &socket, + bool shutdown_gracefully) { + if (socket.sock == INVALID_SOCKET) { + assert(socket.ssl == nullptr); + return; + } + if (socket.ssl) { + tls::shutdown(socket.ssl, shutdown_gracefully); + { + std::lock_guard guard(ctx_mutex_); + tls::free_session(socket.ssl); + } + socket.ssl = nullptr; + } + assert(socket.ssl == nullptr); +} + +bool SSLClient::process_socket( + const Socket &socket, + std::chrono::time_point start_time, + std::function callback) { + assert(socket.ssl); + return detail::process_client_socket_ssl( + socket.ssl, socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, start_time, + std::move(callback)); +} + +bool SSLClient::is_ssl() const { return true; } + +bool SSLClient::create_and_connect_socket(Socket &socket, Error &error) { + if (!is_valid()) { + error = Error::SSLConnection; + return false; + } + return ClientImpl::create_and_connect_socket(socket, error); +} + +// Assumes that socket_mutex_ is locked and that there are no requests in +// flight +bool SSLClient::connect_with_proxy( + Socket &socket, + std::chrono::time_point start_time, + Response &res, bool &success, Error &error) { + success = true; + Response proxy_res; + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req2; + req2.method = "CONNECT"; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); + if (max_timeout_msec_ > 0) { + req2.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req2, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are no + // requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + + if (proxy_res.status == StatusCode::ProxyAuthenticationRequired_407) { + if (!proxy_digest_auth_username_.empty() && + !proxy_digest_auth_password_.empty()) { + std::map auth; + if (detail::parse_www_authenticate(proxy_res, auth, true)) { + // Close the current socket and create a new one for the authenticated + // request + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + + // Create a new socket for the authenticated CONNECT request + if (!ensure_socket_connection(socket, error)) { + success = false; + output_error_log(error, nullptr); + return false; + } + + proxy_res = Response(); + if (!detail::process_client_socket( + socket.sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_, max_timeout_msec_, + start_time, [&](Stream &strm) { + Request req3; + req3.method = "CONNECT"; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); + req3.headers.insert(detail::make_digest_authentication_header( + req3, auth, 1, detail::random_string(10), + proxy_digest_auth_username_, proxy_digest_auth_password_, + true)); + if (max_timeout_msec_ > 0) { + req3.start_time_ = std::chrono::steady_clock::now(); + } + return process_request(strm, req3, proxy_res, false, error); + })) { + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + success = false; + return false; + } + } + } + } + + // If status code is not 200, proxy request is failed. + // Set error to ProxyConnection and return proxy response + // as the response of the request + if (proxy_res.status != StatusCode::OK_200) { + error = Error::ProxyConnection; + output_error_log(error, nullptr); + res = std::move(proxy_res); + // Thread-safe to close everything because we are assuming there are + // no requests in flight + shutdown_ssl(socket, true); + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} + +// SSL HTTP client implementation +SSLClient::SSLClient(const std::string &host) + : SSLClient(host, 443, std::string(), std::string()) {} + +SSLClient::SSLClient(const std::string &host, int port) + : SSLClient(host, port, std::string(), std::string()) {} + +SSLClient::SSLClient(const std::string &host, int port, + const std::string &client_cert_path, + const std::string &client_key_path, + const std::string &private_key_password) + : ClientImpl(host, port, client_cert_path, client_key_path) { + ctx_ = tls::create_client_context(); + if (!ctx_) { return; } + + tls::set_min_version(ctx_, tls::Version::TLS1_2); + + if (!client_cert_path.empty() && !client_key_path.empty()) { + const char *password = + private_key_password.empty() ? nullptr : private_key_password.c_str(); + if (!tls::set_client_cert_file(ctx_, client_cert_path.c_str(), + client_key_path.c_str(), password)) { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSLClient::SSLClient(const std::string &host, int port, + const PemMemory &pem) + : ClientImpl(host, port) { + ctx_ = tls::create_client_context(); + if (!ctx_) { return; } + + tls::set_min_version(ctx_, tls::Version::TLS1_2); + + if (pem.cert_pem && pem.key_pem) { + if (!tls::set_client_cert_pem(ctx_, pem.cert_pem, pem.key_pem, + pem.private_key_password)) { + last_backend_error_ = tls::get_error(); + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +void SSLClient::set_ca_cert_store(tls::ca_store_t ca_cert_store) { + if (ca_cert_store && ctx_) { + // set_ca_store takes ownership of ca_cert_store + tls::set_ca_store(ctx_, ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); + } +} + +void +SSLClient::set_server_certificate_verifier(tls::VerifyCallback verifier) { + if (!ctx_) { return; } + tls::set_verify_callback(ctx_, verifier); +} + +void SSLClient::set_session_verifier( + std::function verifier) { + session_verifier_ = std::move(verifier); +} + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +void SSLClient::enable_windows_certificate_verification(bool enabled) { + enable_windows_cert_verification_ = enabled; +} +#endif + +void SSLClient::load_ca_cert_store(const char *ca_cert, + std::size_t size) { + if (ctx_ && ca_cert && size > 0) { + ca_cert_pem_.assign(ca_cert, size); // Store for redirect transfer + tls::load_ca_pem(ctx_, ca_cert, size); + } +} + +bool SSLClient::load_certs() { + auto ret = true; + + std::call_once(initialize_cert_, [&]() { + std::lock_guard guard(ctx_mutex_); + + if (!ca_cert_file_path_.empty()) { + if (!tls::load_ca_file(ctx_, ca_cert_file_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (!ca_cert_dir_path_.empty()) { + if (!tls::load_ca_dir(ctx_, ca_cert_dir_path_.c_str())) { + last_backend_error_ = tls::get_error(); + ret = false; + } + } else if (ca_cert_pem_.empty()) { + if (!tls::load_system_certs(ctx_)) { + last_backend_error_ = tls::get_error(); + } + } + }); + + return ret; +} + +bool SSLClient::initialize_ssl(Socket &socket, Error &error) { + using namespace tls; + + // Load CA certificates if server verification is enabled + if (server_certificate_verification_) { + if (!load_certs()) { + error = Error::SSLLoadingCerts; + output_error_log(error, nullptr); + return false; + } + } + + bool is_ip = detail::is_ip_address(host_); + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT + // MbedTLS needs explicit verification mode (OpenSSL uses SSL_VERIFY_NONE + // by default and performs all verification post-handshake). + // For IP addresses with verification enabled, use OPTIONAL mode since + // MbedTLS requires hostname for VERIFY_REQUIRED. + if (is_ip && server_certificate_verification_) { + set_verify_client(ctx_, false); + } else { + set_verify_client(ctx_, server_certificate_verification_); + } +#endif + + // Create TLS session + session_t session = nullptr; + { + std::lock_guard guard(ctx_mutex_); + session = create_session(ctx_, socket.sock); + } + + if (!session) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + + // Use scope_exit to ensure session is freed on error paths + bool success = false; + auto session_guard = detail::scope_exit([&] { + if (!success) { free_session(session); } + }); + + // Set SNI extension (skip for IP addresses per RFC 6066). + // On MbedTLS, set_sni also enables hostname verification internally. + // On OpenSSL, set_sni only sets SNI; verification is done post-handshake. + if (!is_ip) { + if (!set_sni(session, host_.c_str())) { + error = Error::SSLConnection; + last_backend_error_ = get_error(); + return false; + } + } + + // Perform non-blocking TLS handshake with timeout + TlsError tls_err; + if (!connect_nonblocking(session, socket.sock, connection_timeout_sec_, + connection_timeout_usec_, &tls_err)) { + last_ssl_error_ = static_cast(tls_err.code); + last_backend_error_ = tls_err.backend_code; + if (tls_err.code == ErrorCode::CertVerifyFailed) { + error = Error::SSLServerVerification; + } else if (tls_err.code == ErrorCode::HostnameMismatch) { + error = Error::SSLServerHostnameVerification; + } else { + error = Error::SSLConnection; + } + output_error_log(error, nullptr); + return false; + } + + // Post-handshake session verifier callback + auto verification_status = SSLVerifierResponse::NoDecisionMade; + if (session_verifier_) { verification_status = session_verifier_(session); } + + if (verification_status == SSLVerifierResponse::CertificateRejected) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + // Default server certificate verification + if (verification_status == SSLVerifierResponse::NoDecisionMade && + server_certificate_verification_) { + verify_result_ = tls::get_verify_result(session); + if (verify_result_ != 0) { + last_backend_error_ = static_cast(verify_result_); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + + auto server_cert = get_peer_cert(session); + if (!server_cert) { + last_backend_error_ = get_error(); + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + auto cert_guard = detail::scope_exit([&] { free_cert(server_cert); }); + + // Hostname verification (post-handshake for all cases). + // On OpenSSL, verification is always post-handshake (SSL_VERIFY_NONE). + // On MbedTLS, set_sni already enabled hostname verification during + // handshake for non-IP hosts, but this check is still needed for IP + // addresses where SNI is not set. + if (server_hostname_verification_) { + if (!verify_hostname(server_cert, host_.c_str())) { + last_backend_error_ = hostname_mismatch_code(); + error = Error::SSLServerHostnameVerification; + output_error_log(error, nullptr); + return false; + } + } + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + // Additional Windows Schannel verification. + // This provides real-time certificate validation with Windows Update + // integration, working with both OpenSSL and MbedTLS backends. + // Skip when a custom CA cert is specified, as the Windows certificate + // store would not know about user-provided CA certificates. + if (enable_windows_cert_verification_ && ca_cert_file_path_.empty() && + ca_cert_dir_path_.empty() && ca_cert_pem_.empty()) { + std::vector der; + if (get_cert_der(server_cert, der)) { + unsigned long wincrypt_error = 0; + if (!detail::verify_cert_with_windows_schannel( + der, host_, server_hostname_verification_, wincrypt_error)) { + last_backend_error_ = wincrypt_error; + error = Error::SSLServerVerification; + output_error_log(error, nullptr); + return false; + } + } + } +#endif + } + + success = true; + socket.ssl = session; + return true; +} + +void Client::set_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_digest_auth(username, password); +} + +void Client::set_proxy_digest_auth(const std::string &username, + const std::string &password) { + cli_->set_proxy_digest_auth(username, password); +} + +void Client::enable_server_certificate_verification(bool enabled) { + cli_->enable_server_certificate_verification(enabled); +} + +void Client::enable_server_hostname_verification(bool enabled) { + cli_->enable_server_hostname_verification(enabled); +} + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) +void Client::enable_windows_certificate_verification(bool enabled) { + if (is_ssl_) { + static_cast(*cli_).enable_windows_certificate_verification( + enabled); + } +} +#endif + void Client::set_ca_cert_path(const std::string &ca_cert_file_path, const std::string &ca_cert_dir_path) { cli_->set_ca_cert_path(ca_cert_file_path, ca_cert_dir_path); } -void Client::set_ca_cert_store(X509_STORE *ca_cert_store) { +void Client::set_ca_cert_store(tls::ca_store_t ca_cert_store) { if (is_ssl_) { static_cast(*cli_).set_ca_cert_store(ca_cert_store); - } else { - cli_->set_ca_cert_store(ca_cert_store); + } else if (ca_cert_store) { + tls::free_ca_store(ca_cert_store); } } void Client::load_ca_cert_store(const char *ca_cert, std::size_t size) { - set_ca_cert_store(cli_->create_ca_cert_store(ca_cert, size)); + set_ca_cert_store(tls::create_ca_store(ca_cert, size)); } -long Client::get_openssl_verify_result() const { +void +Client::set_server_certificate_verifier(tls::VerifyCallback verifier) { if (is_ssl_) { - return static_cast(*cli_).get_openssl_verify_result(); + static_cast(*cli_).set_server_certificate_verifier( + std::move(verifier)); } - return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? } +void Client::set_session_verifier( + std::function verifier) { + if (is_ssl_) { + static_cast(*cli_).set_session_verifier(std::move(verifier)); + } +} + +tls::ctx_t Client::tls_context() const { + if (is_ssl_) { return static_cast(*cli_).tls_context(); } + return nullptr; +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 7: TLS abstraction layer - Common API + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED + +namespace tls { + +// Helper for PeerCert construction +PeerCert get_peer_cert_from_session(const_session_t session) { + return PeerCert(get_peer_cert(session)); +} + +namespace impl { + +VerifyCallback &get_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +VerifyCallback &get_mbedtls_verify_callback() { + static thread_local VerifyCallback callback; + return callback; +} + +} // namespace impl + +bool set_client_ca_file(ctx_t ctx, const char *ca_file, + const char *ca_dir) { + if (!ctx) { return false; } + + bool success = true; + if (ca_file && *ca_file) { + if (!load_ca_file(ctx, ca_file)) { success = false; } + } + if (ca_dir && *ca_dir) { + if (!load_ca_dir(ctx, ca_dir)) { success = false; } + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + // Set CA list for client certificate request (CertificateRequest message) + if (ca_file && *ca_file) { + auto list = SSL_load_client_CA_file(ca_file); + if (list) { SSL_CTX_set_client_CA_list(static_cast(ctx), list); } + } +#endif + + return success; +} + +bool set_server_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + return set_client_cert_pem(ctx, cert, key, password); +} + +bool set_server_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + return set_client_cert_file(ctx, cert_path, key_path, password); +} + +// PeerCert implementation +PeerCert::PeerCert() = default; + +PeerCert::PeerCert(cert_t cert) : cert_(cert) {} + +PeerCert::PeerCert(PeerCert &&other) noexcept : cert_(other.cert_) { + other.cert_ = nullptr; +} + +PeerCert &PeerCert::operator=(PeerCert &&other) noexcept { + if (this != &other) { + if (cert_) { free_cert(cert_); } + cert_ = other.cert_; + other.cert_ = nullptr; + } + return *this; +} + +PeerCert::~PeerCert() { + if (cert_) { free_cert(cert_); } +} + +PeerCert::operator bool() const { return cert_ != nullptr; } + +std::string PeerCert::subject_cn() const { + return cert_ ? get_cert_subject_cn(cert_) : std::string(); +} + +std::string PeerCert::issuer_name() const { + return cert_ ? get_cert_issuer_name(cert_) : std::string(); +} + +bool PeerCert::check_hostname(const char *hostname) const { + return cert_ ? verify_hostname(cert_, hostname) : false; +} + +std::vector PeerCert::sans() const { + std::vector result; + if (cert_) { get_cert_sans(cert_, result); } + return result; +} + +bool PeerCert::validity(time_t ¬_before, time_t ¬_after) const { + return cert_ ? get_cert_validity(cert_, not_before, not_after) : false; +} + +std::string PeerCert::serial() const { + return cert_ ? get_cert_serial(cert_) : std::string(); +} + +// VerifyContext method implementations +std::string VerifyContext::subject_cn() const { + return cert ? get_cert_subject_cn(cert) : std::string(); +} + +std::string VerifyContext::issuer_name() const { + return cert ? get_cert_issuer_name(cert) : std::string(); +} + +bool VerifyContext::check_hostname(const char *hostname) const { + return cert ? verify_hostname(cert, hostname) : false; +} + +std::vector VerifyContext::sans() const { + std::vector result; + if (cert) { get_cert_sans(cert, result); } + return result; +} + +bool VerifyContext::validity(time_t ¬_before, + time_t ¬_after) const { + return cert ? get_cert_validity(cert, not_before, not_after) : false; +} + +std::string VerifyContext::serial() const { + return cert ? get_cert_serial(cert) : std::string(); +} + +// TlsError static method implementation +std::string TlsError::verify_error_to_string(long error_code) { + return verify_error_string(error_code); +} + +} // namespace tls + +// Request::peer_cert() implementation +tls::PeerCert Request::peer_cert() const { + return tls::get_peer_cert_from_session(ssl); +} + +// Request::sni() implementation +std::string Request::sni() const { + if (!ssl) { return std::string(); } + const char *s = tls::get_sni(ssl); + return s ? std::string(s) : std::string(); +} + +#endif // CPPHTTPLIB_SSL_ENABLED + +/* + * Group 8: TLS abstraction layer - OpenSSL backend + */ + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT SSL_CTX *Client::ssl_context() const { if (is_ssl_) { return static_cast(*cli_).ssl_context(); } return nullptr; } + +void Client::set_server_certificate_verifier( + std::function verifier) { + cli_->set_server_certificate_verifier(verifier); +} + +long Client::get_verify_result() const { + if (is_ssl_) { return static_cast(*cli_).get_verify_result(); } + return -1; // NOTE: -1 doesn't match any of X509_V_ERR_??? +} +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +/* + * OpenSSL Backend Implementation + */ + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +namespace tls { + +namespace impl { + +// OpenSSL-specific helpers for converting native types to PEM +std::string x509_to_pem(X509 *cert) { + if (!cert) return {}; + BIO *bio = BIO_new(BIO_s_mem()); + if (!bio) return {}; + if (PEM_write_bio_X509(bio, cert) != 1) { + BIO_free(bio); + return {}; + } + char *data = nullptr; + long len = BIO_get_mem_data(bio, &data); + std::string pem(data, static_cast(len)); + BIO_free(bio); + return pem; +} + +std::string evp_pkey_to_pem(EVP_PKEY *key) { + if (!key) return {}; + BIO *bio = BIO_new(BIO_s_mem()); + if (!bio) return {}; + if (PEM_write_bio_PrivateKey(bio, key, nullptr, nullptr, 0, nullptr, + nullptr) != 1) { + BIO_free(bio); + return {}; + } + char *data = nullptr; + long len = BIO_get_mem_data(bio, &data); + std::string pem(data, static_cast(len)); + BIO_free(bio); + return pem; +} + +std::string x509_store_to_pem(X509_STORE *store) { + if (!store) return {}; + std::string pem; + auto objs = X509_STORE_get0_objects(store); + if (!objs) return {}; + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto cert = X509_OBJECT_get0_X509(obj); + if (cert) { pem += x509_to_pem(cert); } + } + } + return pem; +} + +// Helper to map OpenSSL SSL_get_error to ErrorCode +ErrorCode map_ssl_error(int ssl_error, int &out_errno) { + switch (ssl_error) { + case SSL_ERROR_NONE: return ErrorCode::Success; + case SSL_ERROR_WANT_READ: return ErrorCode::WantRead; + case SSL_ERROR_WANT_WRITE: return ErrorCode::WantWrite; + case SSL_ERROR_ZERO_RETURN: return ErrorCode::PeerClosed; + case SSL_ERROR_SYSCALL: out_errno = errno; return ErrorCode::SyscallError; + case SSL_ERROR_SSL: + default: return ErrorCode::Fatal; + } +} + +// Helper: Create client CA list from PEM string +// Returns a new STACK_OF(X509_NAME)* or nullptr on failure +// Caller takes ownership of returned list +STACK_OF(X509_NAME) * + create_client_ca_list_from_pem(const char *ca_pem) { + if (!ca_pem) { return nullptr; } + + auto ca_list = sk_X509_NAME_new_null(); + if (!ca_list) { return nullptr; } + + BIO *bio = BIO_new_mem_buf(ca_pem, -1); + if (!bio) { + sk_X509_NAME_pop_free(ca_list, X509_NAME_free); + return nullptr; + } + + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + X509_NAME *name = X509_get_subject_name(cert); + if (name) { sk_X509_NAME_push(ca_list, X509_NAME_dup(name)); } + X509_free(cert); + } + BIO_free(bio); + + return ca_list; +} + +// Helper: Extract CA names from X509_STORE +// Returns a new STACK_OF(X509_NAME)* or nullptr on failure +// Caller takes ownership of returned list +STACK_OF(X509_NAME) * + extract_client_ca_list_from_store(X509_STORE *store) { + if (!store) { return nullptr; } + + auto ca_list = sk_X509_NAME_new_null(); + if (!ca_list) { return nullptr; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { + sk_X509_NAME_free(ca_list); + return nullptr; + } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto cert = X509_OBJECT_get0_X509(obj); + if (cert) { + auto subject = X509_get_subject_name(cert); + if (subject) { + auto name_dup = X509_NAME_dup(subject); + if (name_dup) { sk_X509_NAME_push(ca_list, name_dup); } + } + } + } + } + + if (sk_X509_NAME_num(ca_list) == 0) { + sk_X509_NAME_free(ca_list); + return nullptr; + } + + return ca_list; +} + +// OpenSSL verify callback wrapper +int openssl_verify_callback(int preverify_ok, X509_STORE_CTX *ctx) { + auto &callback = get_verify_callback(); + if (!callback) { return preverify_ok; } + + // Get SSL object from X509_STORE_CTX + auto ssl = static_cast( + X509_STORE_CTX_get_ex_data(ctx, SSL_get_ex_data_X509_STORE_CTX_idx())); + if (!ssl) { return preverify_ok; } + + // Get current certificate and depth + auto cert = X509_STORE_CTX_get_current_cert(ctx); + int depth = X509_STORE_CTX_get_error_depth(ctx); + int error = X509_STORE_CTX_get_error(ctx); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(ssl); + verify_ctx.cert = static_cast(cert); + verify_ctx.depth = depth; + verify_ctx.preverify_ok = (preverify_ok != 0); + verify_ctx.error_code = error; + verify_ctx.error_string = + (error != X509_V_OK) ? X509_verify_cert_error_string(error) : nullptr; + + return callback(verify_ctx) ? 1 : 0; +} + +} // namespace impl + +ctx_t create_client_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_client_method()); + if (ctx) { + // Disable auto-retry to properly handle non-blocking I/O + SSL_CTX_clear_mode(ctx, SSL_MODE_AUTO_RETRY); + // Set minimum TLS version + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { SSL_CTX_free(static_cast(ctx)); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) return false; + return SSL_CTX_set_min_proto_version(static_cast(ctx), + static_cast(version)) == 1; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem || len == 0) return false; + + auto ssl_ctx = static_cast(ctx); + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + auto bio = BIO_new_mem_buf(pem, static_cast(len)); + if (!bio) return false; + + bool ok = true; + X509 *cert = nullptr; + while ((cert = PEM_read_bio_X509(bio, nullptr, nullptr, nullptr)) != + nullptr) { + if (X509_STORE_add_cert(store, cert) != 1) { + // Ignore duplicate errors + auto err = ERR_peek_last_error(); + if (ERR_GET_REASON(err) != X509_R_CERT_ALREADY_IN_HASH_TABLE) { + ok = false; + } + } + X509_free(cert); + if (!ok) break; + } + BIO_free(bio); + + // Clear any "no more certificates" errors + ERR_clear_error(); + return ok; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), file_path, + nullptr) == 1; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) return false; + return SSL_CTX_load_verify_locations(static_cast(ctx), nullptr, + dir_path) == 1; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) return false; + auto ssl_ctx = static_cast(ctx); + +#ifdef _WIN32 + // Windows: Load from system certificate store (ROOT and CA) + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + bool loaded_any = false; + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + auto hStore = CertOpenSystemStoreW(NULL, store_name); + if (!hStore) continue; + + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + const unsigned char *data = pContext->pbCertEncoded; + auto x509 = d2i_X509(nullptr, &data, pContext->cbCertEncoded); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + } + CertCloseStore(hStore, 0); + } + return loaded_any; + +#elif defined(__APPLE__) +#ifdef CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + // macOS: Load from Keychain + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) return false; + + CFArrayRef certs = nullptr; + if (SecTrustCopyAnchorCertificates(&certs) != errSecSuccess || !certs) { + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; + } + + bool loaded_any = false; + auto count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + auto cert = reinterpret_cast( + const_cast(CFArrayGetValueAtIndex(certs, i))); + CFDataRef der = SecCertificateCopyData(cert); + if (der) { + const unsigned char *data = CFDataGetBytePtr(der); + auto x509 = d2i_X509(nullptr, &data, CFDataGetLength(der)); + if (x509) { + if (X509_STORE_add_cert(store, x509) == 1) { loaded_any = true; } + X509_free(x509); + } + CFRelease(der); + } + } + CFRelease(certs); + return loaded_any || SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#else + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; #endif +#else + // Other Unix: use default verify paths + return SSL_CTX_set_default_verify_paths(ssl_ctx) == 1; +#endif +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) return false; + + auto ssl_ctx = static_cast(ctx); + + // Load certificate + auto cert_bio = BIO_new_mem_buf(cert, -1); + if (!cert_bio) return false; + + auto x509 = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!x509) return false; + + auto cert_ok = SSL_CTX_use_certificate(ssl_ctx, x509) == 1; + X509_free(x509); + if (!cert_ok) return false; + + // Load private key + auto key_bio = BIO_new_mem_buf(key, -1); + if (!key_bio) return false; + + auto pkey = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!pkey) return false; + + auto key_ok = SSL_CTX_use_PrivateKey(ssl_ctx, pkey) == 1; + EVP_PKEY_free(pkey); + + return key_ok && SSL_CTX_check_private_key(ssl_ctx) == 1; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) return false; + + auto ssl_ctx = static_cast(ctx); + + if (password && password[0] != '\0') { + SSL_CTX_set_default_passwd_cb_userdata( + ssl_ctx, reinterpret_cast(const_cast(password))); + } + + return SSL_CTX_use_certificate_chain_file(ssl_ctx, cert_path) == 1 && + SSL_CTX_use_PrivateKey_file(ssl_ctx, key_path, SSL_FILETYPE_PEM) == 1; +} + +ctx_t create_server_context() { + SSL_CTX *ctx = SSL_CTX_new(TLS_server_method()); + if (ctx) { + SSL_CTX_set_options(ctx, SSL_OP_NO_COMPRESSION | + SSL_OP_NO_SESSION_RESUMPTION_ON_RENEGOTIATION); + SSL_CTX_set_min_proto_version(ctx, TLS1_2_VERSION); + } + return static_cast(ctx); +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) return; + SSL_CTX_set_verify(static_cast(ctx), + require + ? (SSL_VERIFY_PEER | SSL_VERIFY_FAIL_IF_NO_PEER_CERT) + : SSL_VERIFY_NONE, + nullptr); +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) return nullptr; + + auto ssl_ctx = static_cast(ctx); + SSL *ssl = SSL_new(ssl_ctx); + if (!ssl) return nullptr; + + // Disable auto-retry for proper non-blocking I/O handling + SSL_clear_mode(ssl, SSL_MODE_AUTO_RETRY); + + auto bio = BIO_new_socket(static_cast(sock), BIO_NOCLOSE); + if (!bio) { + SSL_free(ssl); + return nullptr; + } + + SSL_set_bio(ssl, bio, bio); + return static_cast(ssl); +} + +void free_session(session_t session) { + if (session) { SSL_free(static_cast(session)); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) only - does not enable verification +#if defined(OPENSSL_IS_BORINGSSL) + return SSL_set_tlsext_host_name(ssl, hostname) == 1; +#else + // Direct call instead of macro to suppress -Wold-style-cast warning + return SSL_ctrl(ssl, SSL_CTRL_SET_TLSEXT_HOSTNAME, TLSEXT_NAMETYPE_host_name, + static_cast(const_cast(hostname))) == 1; +#endif +} + +bool set_hostname(session_t session, const char *hostname) { + if (!session || !hostname) return false; + + auto ssl = static_cast(session); + + // Set SNI (Server Name Indication) + if (!set_sni(session, hostname)) { return false; } + + // Enable hostname verification + auto param = SSL_get0_param(ssl); + if (!param) return false; + + X509_VERIFY_PARAM_set_hostflags(param, X509_CHECK_FLAG_NO_PARTIAL_WILDCARDS); + if (X509_VERIFY_PARAM_set1_host(param, hostname, 0) != 1) { return false; } + + SSL_set_verify(ssl, SSL_VERIFY_PEER, nullptr); + return true; +} + +TlsError connect(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_connect(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +TlsError accept(session_t session) { + if (!session) { return TlsError(); } + + auto ssl = static_cast(session); + auto ret = SSL_accept(ssl); + + TlsError err; + if (ret == 1) { + err.code = ErrorCode::Success; + } else { + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + err.backend_code = ERR_get_error(); + } + return err; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_connect(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); + err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto ssl = static_cast(session); + auto bio = SSL_get_rbio(ssl); + + // Set non-blocking mode for handshake + detail::set_nonblocking(sock, true); + if (bio) { BIO_set_nbio(bio, 1); } + + auto cleanup = detail::scope_exit([&]() { + // Restore blocking mode after handshake + if (bio) { BIO_set_nbio(bio, 0); } + detail::set_nonblocking(sock, false); + }); + + auto res = 0; + while ((res = SSL_accept(ssl)) != 1) { + auto ssl_err = SSL_get_error(ssl, res); + switch (ssl_err) { + case SSL_ERROR_WANT_READ: + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + case SSL_ERROR_WANT_WRITE: + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + break; + default: break; + } + if (err) { + err->code = impl::map_ssl_error(ssl_err, err->sys_errno); + err->backend_code = ERR_get_error(); + } + return false; + } + if (err) { err->code = ErrorCode::Success; } + return true; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + constexpr auto max_len = + static_cast((std::numeric_limits::max)()); + if (len > max_len) { len = max_len; } + auto ret = SSL_read(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto ssl = static_cast(session); + auto ret = SSL_write(ssl, buf, static_cast(len)); + + if (ret > 0) { + err.code = ErrorCode::Success; + return ret; + } + + auto ssl_err = SSL_get_error(ssl, ret); + err.code = impl::map_ssl_error(ssl_err, err.sys_errno); + if (err.code == ErrorCode::Fatal) { err.backend_code = ERR_get_error(); } + return -1; +} + +int pending(const_session_t session) { + if (!session) return 0; + return SSL_pending(static_cast(const_cast(session))); +} + +void shutdown(session_t session, bool graceful) { + if (!session) return; + + auto ssl = static_cast(session); + if (graceful) { + // First call sends close_notify + if (SSL_shutdown(ssl) == 0) { + // Second call waits for peer's close_notify + SSL_shutdown(ssl); + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session) return true; + + // Temporarily set socket to non-blocking to avoid blocking on SSL_peek + detail::set_nonblocking(sock, true); + auto se = detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + auto ssl = static_cast(session); + char buf; + auto ret = SSL_peek(ssl, &buf, 1); + if (ret > 0) return false; + + auto err = SSL_get_error(ssl, ret); + return err == SSL_ERROR_ZERO_RETURN; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) return nullptr; + return static_cast(SSL_get1_peer_certificate( + static_cast(const_cast(session)))); +} + +void free_cert(cert_t cert) { + if (cert) { X509_free(static_cast(cert)); } +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) return false; + + auto x509 = static_cast(cert); + + // Use X509_check_ip_asc for IP addresses, X509_check_host for DNS names + if (detail::is_ip_address(hostname)) { + return X509_check_ip_asc(x509, hostname, 0) == 1; + } + return X509_check_host(x509, hostname, strlen(hostname), 0, nullptr) == 1; +} + +uint64_t hostname_mismatch_code() { + return static_cast(X509_V_ERR_HOSTNAME_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) return X509_V_ERR_UNSPECIFIED; + return SSL_get_verify_result(static_cast(const_cast(session))); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto subject_name = X509_get_subject_name(x509); + if (!subject_name) return ""; + + char buf[256]; + auto len = + X509_NAME_get_text_by_NID(subject_name, NID_commonName, buf, sizeof(buf)); + if (len < 0) return ""; + return std::string(buf, static_cast(len)); +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + auto issuer_name = X509_get_issuer_name(x509); + if (!issuer_name) return ""; + + char buf[256]; + X509_NAME_oneline(issuer_name, buf, sizeof(buf)); + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + auto names = static_cast( + X509_get_ext_d2i(x509, NID_subject_alt_name, nullptr, nullptr)); + if (!names) return true; // No SANs is valid + + auto count = sk_GENERAL_NAME_num(names); + for (int i = 0; i < count; i++) { + auto gen = sk_GENERAL_NAME_value(names, i); + if (!gen) continue; + + SanEntry entry; + switch (gen->type) { + case GEN_DNS: + entry.type = SanType::DNS; + if (gen->d.dNSName) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.dNSName)), + static_cast(ASN1_STRING_length(gen->d.dNSName))); + } + break; + case GEN_IPADD: + entry.type = SanType::IP; + if (gen->d.iPAddress) { + auto data = ASN1_STRING_get0_data(gen->d.iPAddress); + auto len = ASN1_STRING_length(gen->d.iPAddress); + if (len == 4) { + // IPv4 + char buf[INET_ADDRSTRLEN]; + inet_ntop(AF_INET, data, buf, sizeof(buf)); + entry.value = buf; + } else if (len == 16) { + // IPv6 + char buf[INET6_ADDRSTRLEN]; + inet_ntop(AF_INET6, data, buf, sizeof(buf)); + entry.value = buf; + } + } + break; + case GEN_EMAIL: + entry.type = SanType::EMAIL; + if (gen->d.rfc822Name) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.rfc822Name)), + static_cast(ASN1_STRING_length(gen->d.rfc822Name))); + } + break; + case GEN_URI: + entry.type = SanType::URI; + if (gen->d.uniformResourceIdentifier) { + entry.value = std::string( + reinterpret_cast( + ASN1_STRING_get0_data(gen->d.uniformResourceIdentifier)), + static_cast( + ASN1_STRING_length(gen->d.uniformResourceIdentifier))); + } + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + + GENERAL_NAMES_free(names); + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + auto nb = X509_get0_notBefore(x509); + auto na = X509_get0_notAfter(x509); + if (!nb || !na) return false; + + ASN1_TIME *epoch = ASN1_TIME_new(); + if (!epoch) return false; + auto se = detail::scope_exit([&] { ASN1_TIME_free(epoch); }); + + if (!ASN1_TIME_set(epoch, 0)) return false; + + int pday, psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, nb)) return false; + not_before = 86400 * (time_t)pday + psec; + + if (!ASN1_TIME_diff(&pday, &psec, epoch, na)) return false; + not_after = 86400 * (time_t)pday + psec; + + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + auto serial = X509_get_serialNumber(x509); + if (!serial) return ""; + + auto bn = ASN1_INTEGER_to_BN(serial, nullptr); + if (!bn) return ""; + + auto hex = BN_bn2hex(bn); + BN_free(bn); + if (!hex) return ""; + + std::string result(hex); + OPENSSL_free(hex); + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto x509 = static_cast(cert); + auto len = i2d_X509(x509, nullptr); + if (len < 0) return false; + der.resize(static_cast(len)); + auto p = der.data(); + i2d_X509(x509, &p); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto ssl = static_cast(const_cast(session)); + return SSL_get_servername(ssl, TLSEXT_NAMETYPE_host_name); +} + +uint64_t peek_error() { return ERR_peek_last_error(); } + +uint64_t get_error() { return ERR_get_error(); } + +std::string error_string(uint64_t code) { + char buf[256]; + ERR_error_string_n(static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto mem = BIO_new_mem_buf(pem, static_cast(len)); + if (!mem) { return nullptr; } + auto mem_guard = detail::scope_exit([&] { BIO_free_all(mem); }); + + auto inf = PEM_X509_INFO_read_bio(mem, nullptr, nullptr, nullptr); + if (!inf) { return nullptr; } + + auto store = X509_STORE_new(); + if (store) { + for (auto i = 0; i < static_cast(sk_X509_INFO_num(inf)); i++) { + auto itmp = sk_X509_INFO_value(inf, i); + if (!itmp) { continue; } + if (itmp->x509) { X509_STORE_add_cert(store, itmp->x509); } + if (itmp->crl) { X509_STORE_add_crl(store, itmp->crl); } + } + } + + sk_X509_INFO_pop_free(inf, X509_INFO_free); + return static_cast(store); +} + +void free_ca_store(ca_store_t store) { + if (store) { X509_STORE_free(static_cast(store)); } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto ssl_ctx = static_cast(ctx); + auto x509_store = static_cast(store); + + // Check if same store is already set + if (SSL_CTX_get_cert_store(ssl_ctx) == x509_store) { return true; } + + // SSL_CTX_set_cert_store takes ownership and frees the old store + SSL_CTX_set_cert_store(ssl_ctx, x509_store); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return 0; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { return 0; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + // Increment reference count so caller can free it + X509_up_ref(x509); + certs.push_back(static_cast(x509)); + } + } + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto ssl_ctx = static_cast(ctx); + + auto store = SSL_CTX_get_cert_store(ssl_ctx); + if (!store) { return names; } + + auto objs = X509_STORE_get0_objects(store); + if (!objs) { return names; } + + auto count = sk_X509_OBJECT_num(objs); + for (decltype(count) i = 0; i < count; i++) { + auto obj = sk_X509_OBJECT_value(objs, i); + if (!obj) { continue; } + if (X509_OBJECT_get_type(obj) == X509_LU_X509) { + auto x509 = X509_OBJECT_get0_X509(obj); + if (x509) { + auto subject = X509_get_subject_name(x509); + if (subject) { + char buf[512]; + X509_NAME_oneline(subject, buf, sizeof(buf)); + names.push_back(buf); + } + } + } + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Load certificate from PEM + auto cert_bio = BIO_new_mem_buf(cert_pem, -1); + if (!cert_bio) { return false; } + auto cert = PEM_read_bio_X509(cert_bio, nullptr, nullptr, nullptr); + BIO_free(cert_bio); + if (!cert) { return false; } + + // Load private key from PEM + auto key_bio = BIO_new_mem_buf(key_pem, -1); + if (!key_bio) { + X509_free(cert); + return false; + } + auto key = PEM_read_bio_PrivateKey(key_bio, nullptr, nullptr, + password ? const_cast(password) + : nullptr); + BIO_free(key_bio); + if (!key) { + X509_free(cert); + return false; + } + + // Update certificate and key + auto ret = SSL_CTX_use_certificate(ssl_ctx, cert) == 1 && + SSL_CTX_use_PrivateKey(ssl_ctx, key) == 1; + + X509_free(cert); + EVP_PKEY_free(key); + return ret; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto ssl_ctx = static_cast(ctx); + + // Create new X509_STORE from PEM + auto store = create_ca_store(ca_pem, strlen(ca_pem)); + if (!store) { return false; } + + // SSL_CTX_set_cert_store takes ownership + SSL_CTX_set_cert_store(ssl_ctx, static_cast(store)); + + // Set client CA list for client certificate request + auto ca_list = impl::create_client_ca_list_from_pem(ca_pem); + if (ca_list) { + // SSL_CTX_set_client_CA_list takes ownership of ca_list + SSL_CTX_set_client_CA_list(ssl_ctx, ca_list); + } + + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto ssl_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + + if (impl::get_verify_callback()) { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, impl::openssl_verify_callback); + } else { + SSL_CTX_set_verify(ssl_ctx, SSL_VERIFY_PEER, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto ssl = static_cast(const_cast(session)); + return SSL_get_verify_result(ssl); +} + +std::string verify_error_string(long error_code) { + if (error_code == X509_V_OK) { return ""; } + const char *str = X509_verify_cert_error_string(static_cast(error_code)); + return str ? str : "unknown error"; +} + +namespace impl { + +// OpenSSL-specific helpers for public API wrappers +ctx_t create_server_context_from_x509(X509 *cert, EVP_PKEY *key, + X509_STORE *client_ca_store, + int &out_error) { + out_error = 0; + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + if (cert_pem.empty() || key_pem.empty()) { + out_error = static_cast(ERR_get_error()); + return nullptr; + } + + auto ctx = create_server_context(); + if (!ctx) { + out_error = static_cast(get_error()); + return nullptr; + } + + if (!set_server_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr)) { + out_error = static_cast(get_error()); + free_context(ctx); + return nullptr; + } + + if (client_ca_store) { + // Set cert store for verification (SSL_CTX_set_cert_store takes ownership) + SSL_CTX_set_cert_store(static_cast(ctx), client_ca_store); + + // Extract and set client CA list directly from store (more efficient than + // PEM conversion) + auto ca_list = extract_client_ca_list_from_store(client_ca_store); + if (ca_list) { + SSL_CTX_set_client_CA_list(static_cast(ctx), ca_list); + } + + set_verify_client(ctx, true); + } + + return ctx; +} + +void update_server_certs_from_x509(ctx_t ctx, X509 *cert, EVP_PKEY *key, + X509_STORE *client_ca_store) { + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + + if (!cert_pem.empty() && !key_pem.empty()) { + update_server_cert(ctx, cert_pem.c_str(), key_pem.c_str(), nullptr); + } + + if (client_ca_store) { + auto ca_pem = x509_store_to_pem(client_ca_store); + if (!ca_pem.empty()) { update_server_client_ca(ctx, ca_pem.c_str()); } + X509_STORE_free(client_ca_store); + } +} + +ctx_t create_client_context_from_x509(X509 *cert, EVP_PKEY *key, + const char *password, + unsigned long &out_error) { + out_error = 0; + auto ctx = create_client_context(); + if (!ctx) { + out_error = static_cast(get_error()); + return nullptr; + } + + if (cert && key) { + auto cert_pem = x509_to_pem(cert); + auto key_pem = evp_pkey_to_pem(key); + if (cert_pem.empty() || key_pem.empty()) { + out_error = ERR_get_error(); + free_context(ctx); + return nullptr; + } + if (!set_client_cert_pem(ctx, cert_pem.c_str(), key_pem.c_str(), + password)) { + out_error = static_cast(get_error()); + free_context(ctx); + return nullptr; + } + } + + return ctx; +} + +} // namespace impl + +} // namespace tls + +// ClientImpl::set_ca_cert_store - defined here to use +// tls::impl::x509_store_to_pem Deprecated: converts X509_STORE to PEM and +// stores for redirect transfer +void ClientImpl::set_ca_cert_store(X509_STORE *ca_cert_store) { + if (ca_cert_store) { + ca_cert_pem_ = tls::impl::x509_store_to_pem(ca_cert_store); + } +} + +SSLServer::SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + ctx_ = tls::impl::create_server_context_from_x509( + cert, private_key, client_ca_cert_store, last_ssl_error_); +} + +SSLServer::SSLServer( + const std::function &setup_ssl_ctx_callback) { + // Use abstract API to create context + ctx_ = tls::create_server_context(); + if (ctx_) { + // Pass to OpenSSL-specific callback (ctx_ is SSL_CTX* internally) + auto ssl_ctx = static_cast(ctx_); + if (!setup_ssl_ctx_callback(*ssl_ctx)) { + tls::free_context(ctx_); + ctx_ = nullptr; + } + } +} + +SSL_CTX *SSLServer::ssl_context() const { + return static_cast(ctx_); +} + +void SSLServer::update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store) { + std::lock_guard guard(ctx_mutex_); + tls::impl::update_server_certs_from_x509(ctx_, cert, private_key, + client_ca_cert_store); +} + +SSLClient::SSLClient(const std::string &host, int port, + X509 *client_cert, EVP_PKEY *client_key, + const std::string &private_key_password) + : ClientImpl(host, port) { + const char *password = + private_key_password.empty() ? nullptr : private_key_password.c_str(); + ctx_ = tls::impl::create_client_context_from_x509( + client_cert, client_key, password, last_backend_error_); +} + +long SSLClient::get_verify_result() const { return verify_result_; } + +void SSLClient::set_server_certificate_verifier( + std::function verifier) { + // Wrap SSL* callback into backend-independent session_verifier_ + auto v = std::make_shared>( + std::move(verifier)); + session_verifier_ = [v](tls::session_t session) { + return (*v)(static_cast(session)); + }; +} + +SSL_CTX *SSLClient::ssl_context() const { + return static_cast(ctx_); +} + +bool SSLClient::verify_host(X509 *server_cert) const { + /* Quote from RFC2818 section 3.1 "Server Identity" + + If a subjectAltName extension of type dNSName is present, that MUST + be used as the identity. Otherwise, the (most specific) Common Name + field in the Subject field of the certificate MUST be used. Although + the use of the Common Name is existing practice, it is deprecated and + Certification Authorities are encouraged to use the dNSName instead. + + Matching is performed using the matching rules specified by + [RFC2459]. If more than one identity of a given type is present in + the certificate (e.g., more than one dNSName name, a match in any one + of the set is considered acceptable.) Names may contain the wildcard + character * which is considered to match any single domain name + component or component fragment. E.g., *.a.com matches foo.a.com but + not bar.foo.a.com. f*.com matches foo.com but not bar.com. + + In some cases, the URI is specified as an IP address rather than a + hostname. In this case, the iPAddress subjectAltName must be present + in the certificate and must exactly match the IP in the URI. + + */ + return verify_host_with_subject_alt_name(server_cert) || + verify_host_with_common_name(server_cert); +} + +bool +SSLClient::verify_host_with_subject_alt_name(X509 *server_cert) const { + auto ret = false; + + auto type = GEN_DNS; + + struct in6_addr addr6 = {}; + struct in_addr addr = {}; + size_t addr_len = 0; + +#ifndef __MINGW32__ + if (inet_pton(AF_INET6, host_.c_str(), &addr6)) { + type = GEN_IPADD; + addr_len = sizeof(struct in6_addr); + } else if (inet_pton(AF_INET, host_.c_str(), &addr)) { + type = GEN_IPADD; + addr_len = sizeof(struct in_addr); + } +#endif + + auto alt_names = static_cast( + X509_get_ext_d2i(server_cert, NID_subject_alt_name, nullptr, nullptr)); + + if (alt_names) { + auto dsn_matched = false; + auto ip_matched = false; + + auto count = sk_GENERAL_NAME_num(alt_names); + + for (decltype(count) i = 0; i < count && !dsn_matched; i++) { + auto val = sk_GENERAL_NAME_value(alt_names, i); + if (!val || val->type != type) { continue; } + + auto name = + reinterpret_cast(ASN1_STRING_get0_data(val->d.ia5)); + if (name == nullptr) { continue; } + + auto name_len = static_cast(ASN1_STRING_length(val->d.ia5)); + + switch (type) { + case GEN_DNS: + dsn_matched = + detail::match_hostname(std::string(name, name_len), host_); + break; + + case GEN_IPADD: + if (!memcmp(&addr6, name, addr_len) || !memcmp(&addr, name, addr_len)) { + ip_matched = true; + } + break; + } + } + + if (dsn_matched || ip_matched) { ret = true; } + } + + GENERAL_NAMES_free(const_cast( + reinterpret_cast(alt_names))); + return ret; +} + +bool SSLClient::verify_host_with_common_name(X509 *server_cert) const { + const auto subject_name = X509_get_subject_name(server_cert); + + if (subject_name != nullptr) { + char name[BUFSIZ]; + auto name_len = X509_NAME_get_text_by_NID(subject_name, NID_commonName, + name, sizeof(name)); + + if (name_len != -1) { + return detail::match_hostname( + std::string(name, static_cast(name_len)), host_); + } + } + + return false; +} + +#endif // CPPHTTPLIB_OPENSSL_SUPPORT + +/* + * Group 9: TLS abstraction layer - Mbed TLS backend + */ + +/* + * Mbed TLS Backend Implementation + */ + +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { + +namespace impl { + +// Mbed TLS session wrapper +struct MbedTlsSession { + mbedtls_ssl_context ssl; + socket_t sock = INVALID_SOCKET; + std::string hostname; // For client: set via set_sni + std::string sni_hostname; // For server: received from client via SNI callback + + MbedTlsSession() { mbedtls_ssl_init(&ssl); } + + ~MbedTlsSession() { mbedtls_ssl_free(&ssl); } + + MbedTlsSession(const MbedTlsSession &) = delete; + MbedTlsSession &operator=(const MbedTlsSession &) = delete; +}; + +// Thread-local error code accessor for Mbed TLS (since it doesn't have an error +// queue) +int &mbedtls_last_error() { + static thread_local int err = 0; + return err; +} + +// Helper to map Mbed TLS error to ErrorCode +ErrorCode map_mbedtls_error(int ret, int &out_errno) { + if (ret == 0) { return ErrorCode::Success; } + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { return ErrorCode::WantRead; } + if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { return ErrorCode::WantWrite; } + if (ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY) { + return ErrorCode::PeerClosed; + } + if (ret == MBEDTLS_ERR_NET_CONN_RESET || ret == MBEDTLS_ERR_NET_SEND_FAILED || + ret == MBEDTLS_ERR_NET_RECV_FAILED) { + out_errno = errno; + return ErrorCode::SyscallError; + } + if (ret == MBEDTLS_ERR_X509_CERT_VERIFY_FAILED) { + return ErrorCode::CertVerifyFailed; + } + return ErrorCode::Fatal; +} + +// BIO-like send callback for Mbed TLS +int mbedtls_net_send_cb(void *ctx, const unsigned char *buf, + size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + send(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_WRITE; } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#else + auto ret = send(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_WRITE; + } + return MBEDTLS_ERR_NET_SEND_FAILED; + } +#endif + return static_cast(ret); +} + +// BIO-like recv callback for Mbed TLS +int mbedtls_net_recv_cb(void *ctx, unsigned char *buf, size_t len) { + auto sock = *static_cast(ctx); +#ifdef _WIN32 + auto ret = + recv(sock, reinterpret_cast(buf), static_cast(len), 0); + if (ret == SOCKET_ERROR) { + int err = WSAGetLastError(); + if (err == WSAEWOULDBLOCK) { return MBEDTLS_ERR_SSL_WANT_READ; } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#else + auto ret = recv(sock, buf, len, 0); + if (ret < 0) { + if (errno == EAGAIN || errno == EWOULDBLOCK) { + return MBEDTLS_ERR_SSL_WANT_READ; + } + return MBEDTLS_ERR_NET_RECV_FAILED; + } +#endif + if (ret == 0) { return MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY; } + return static_cast(ret); +} + +// MbedTlsContext constructor/destructor implementations +MbedTlsContext::MbedTlsContext() { + mbedtls_ssl_config_init(&conf); + mbedtls_entropy_init(&entropy); + mbedtls_ctr_drbg_init(&ctr_drbg); + mbedtls_x509_crt_init(&ca_chain); + mbedtls_x509_crt_init(&own_cert); + mbedtls_pk_init(&own_key); +} + +MbedTlsContext::~MbedTlsContext() { + mbedtls_pk_free(&own_key); + mbedtls_x509_crt_free(&own_cert); + mbedtls_x509_crt_free(&ca_chain); + mbedtls_ctr_drbg_free(&ctr_drbg); + mbedtls_entropy_free(&entropy); + mbedtls_ssl_config_free(&conf); +} + +// Thread-local storage for SNI captured during handshake +// This is needed because the SNI callback doesn't have a way to pass +// session-specific data before the session is fully set up +std::string &mbedpending_sni() { + static thread_local std::string sni; + return sni; +} + +// SNI callback for Mbed TLS server to capture client's SNI hostname +int mbedtls_sni_callback(void *p_ctx, mbedtls_ssl_context *ssl, + const unsigned char *name, size_t name_len) { + (void)p_ctx; + (void)ssl; + + // Store SNI name in thread-local storage + // It will be retrieved and stored in the session after handshake + if (name && name_len > 0) { + mbedpending_sni().assign(reinterpret_cast(name), name_len); + } else { + mbedpending_sni().clear(); + } + return 0; // Accept any SNI +} + +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags); + +// Check if a string is an IPv4 address +bool is_ipv4_address(const std::string &str) { + int dots = 0; + for (char c : str) { + if (c == '.') { + dots++; + } else if (!isdigit(static_cast(c))) { + return false; + } + } + return dots == 3; +} + +// Parse IPv4 address string to bytes +bool parse_ipv4(const std::string &str, unsigned char *out) { + int parts[4]; + if (sscanf(str.c_str(), "%d.%d.%d.%d", &parts[0], &parts[1], &parts[2], + &parts[3]) != 4) { + return false; + } + for (int i = 0; i < 4; i++) { + if (parts[i] < 0 || parts[i] > 255) return false; + out[i] = static_cast(parts[i]); + } + return true; +} + +// MbedTLS verify callback wrapper +int mbedtls_verify_callback(void *data, mbedtls_x509_crt *crt, + int cert_depth, uint32_t *flags) { + auto &callback = get_verify_callback(); + if (!callback) { return 0; } // Continue with default verification + + // data points to the MbedTlsSession + auto *session = static_cast(data); + + // Build context + VerifyContext verify_ctx; + verify_ctx.session = static_cast(session); + verify_ctx.cert = static_cast(crt); + verify_ctx.depth = cert_depth; + verify_ctx.preverify_ok = (*flags == 0); + verify_ctx.error_code = static_cast(*flags); + + // Convert Mbed TLS flags to error string + static thread_local char error_buf[256]; + if (*flags != 0) { + mbedtls_x509_crt_verify_info(error_buf, sizeof(error_buf), "", *flags); + verify_ctx.error_string = error_buf; + } else { + verify_ctx.error_string = nullptr; + } + + bool accepted = callback(verify_ctx); + + if (accepted) { + *flags = 0; // Clear all error flags + return 0; + } + return MBEDTLS_ERR_X509_CERT_VERIFY_FAILED; +} + +} // namespace impl + +ctx_t create_client_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = false; + + // Seed the random number generator + const char *pers = "httplib_client"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for client + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_CLIENT, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: verify peer certificate + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + return static_cast(ctx); +} + +ctx_t create_server_context() { + auto ctx = new (std::nothrow) impl::MbedTlsContext(); + if (!ctx) { return nullptr; } + + ctx->is_server = true; + + // Seed the random number generator + const char *pers = "httplib_server"; + int ret = mbedtls_ctr_drbg_seed( + &ctx->ctr_drbg, mbedtls_entropy_func, &ctx->entropy, + reinterpret_cast(pers), strlen(pers)); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set up SSL config for server + ret = mbedtls_ssl_config_defaults(&ctx->conf, MBEDTLS_SSL_IS_SERVER, + MBEDTLS_SSL_TRANSPORT_STREAM, + MBEDTLS_SSL_PRESET_DEFAULT); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete ctx; + return nullptr; + } + + // Set random number generator + mbedtls_ssl_conf_rng(&ctx->conf, mbedtls_ctr_drbg_random, &ctx->ctr_drbg); + + // Default: don't verify client + mbedtls_ssl_conf_authmode(&ctx->conf, MBEDTLS_SSL_VERIFY_NONE); + + // Set minimum TLS version to 1.2 +#ifdef CPPHTTPLIB_MBEDTLS_V3 + mbedtls_ssl_conf_min_tls_version(&ctx->conf, MBEDTLS_SSL_VERSION_TLS1_2); +#else + mbedtls_ssl_conf_min_version(&ctx->conf, MBEDTLS_SSL_MAJOR_VERSION_3, + MBEDTLS_SSL_MINOR_VERSION_3); +#endif + + // Set SNI callback to capture client's SNI hostname + mbedtls_ssl_conf_sni(&ctx->conf, impl::mbedtls_sni_callback, nullptr); + + return static_cast(ctx); +} + +void free_context(ctx_t ctx) { + if (ctx) { delete static_cast(ctx); } +} + +bool set_min_version(ctx_t ctx, Version version) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + // Mbed TLS 3.x uses mbedtls_ssl_protocol_version enum + mbedtls_ssl_protocol_version min_ver = MBEDTLS_SSL_VERSION_TLS1_2; + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + min_ver = MBEDTLS_SSL_VERSION_TLS1_3; +#endif + } + mbedtls_ssl_conf_min_tls_version(&mctx->conf, min_ver); +#else + // Mbed TLS 2.x uses major/minor version numbers + int major = MBEDTLS_SSL_MAJOR_VERSION_3; + int minor = MBEDTLS_SSL_MINOR_VERSION_3; // TLS 1.2 + if (version >= Version::TLS1_3) { +#if defined(MBEDTLS_SSL_PROTO_TLS1_3) + minor = MBEDTLS_SSL_MINOR_VERSION_4; // TLS 1.3 +#else + minor = MBEDTLS_SSL_MINOR_VERSION_3; // Fall back to TLS 1.2 +#endif + } + mbedtls_ssl_conf_min_version(&mctx->conf, major, minor); +#endif + return true; +} + +bool load_ca_pem(ctx_t ctx, const char *pem, size_t len) { + if (!ctx || !pem) { return false; } + auto mctx = static_cast(ctx); + + // mbedtls_x509_crt_parse expects null-terminated string for PEM + // Add null terminator if not present + std::string pem_str(pem, len); + int ret = mbedtls_x509_crt_parse( + &mctx->ca_chain, reinterpret_cast(pem_str.c_str()), + pem_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_file(ctx_t ctx, const char *file_path) { + if (!ctx || !file_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, file_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_ca_dir(ctx_t ctx, const char *dir_path) { + if (!ctx || !dir_path) { return false; } + auto mctx = static_cast(ctx); + + int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, dir_path); + if (ret < 0) { // Returns number of certs on success, negative on error + impl::mbedtls_last_error() = ret; + return false; + } + + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + return true; +} + +bool load_system_certs(ctx_t ctx) { + if (!ctx) { return false; } + auto mctx = static_cast(ctx); + bool loaded = false; + +#ifdef _WIN32 + // Load from Windows certificate store (ROOT and CA) + static const wchar_t *store_names[] = {L"ROOT", L"CA"}; + for (auto store_name : store_names) { + HCERTSTORE hStore = CertOpenSystemStoreW(0, store_name); + if (hStore) { + PCCERT_CONTEXT pContext = nullptr; + while ((pContext = CertEnumCertificatesInStore(hStore, pContext)) != + nullptr) { + int ret = mbedtls_x509_crt_parse_der( + &mctx->ca_chain, pContext->pbCertEncoded, pContext->cbCertEncoded); + if (ret == 0) { loaded = true; } + } + CertCloseStore(hStore, 0); + } + } +#elif defined(__APPLE__) && defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) + // Load from macOS Keychain + CFArrayRef certs = nullptr; + OSStatus status = SecTrustCopyAnchorCertificates(&certs); + if (status == errSecSuccess && certs) { + CFIndex count = CFArrayGetCount(certs); + for (CFIndex i = 0; i < count; i++) { + SecCertificateRef cert = + (SecCertificateRef)CFArrayGetValueAtIndex(certs, i); + CFDataRef data = SecCertificateCopyData(cert); + if (data) { + int ret = mbedtls_x509_crt_parse_der( + &mctx->ca_chain, CFDataGetBytePtr(data), + static_cast(CFDataGetLength(data))); + if (ret == 0) { loaded = true; } + CFRelease(data); + } + } + CFRelease(certs); + } +#else + // Try common CA certificate locations on Linux/Unix + static const char *ca_paths[] = { + "/etc/ssl/certs/ca-certificates.crt", // Debian/Ubuntu + "/etc/pki/tls/certs/ca-bundle.crt", // RHEL/CentOS + "/etc/ssl/ca-bundle.pem", // OpenSUSE + "/etc/pki/tls/cacert.pem", // OpenELEC + "/etc/ssl/cert.pem", // Alpine, FreeBSD + nullptr}; + + for (const char **path = ca_paths; *path; ++path) { + int ret = mbedtls_x509_crt_parse_file(&mctx->ca_chain, *path); + if (ret >= 0) { + loaded = true; + break; + } + } + + // Also try the CA directory + if (!loaded) { + static const char *ca_dirs[] = {"/etc/ssl/certs", // Debian/Ubuntu + "/etc/pki/tls/certs", // RHEL/CentOS + "/usr/share/ca-certificates", nullptr}; + + for (const char **dir = ca_dirs; *dir; ++dir) { + int ret = mbedtls_x509_crt_parse_path(&mctx->ca_chain, *dir); + if (ret >= 0) { + loaded = true; + break; + } + } + } +#endif + + if (loaded) { + mbedtls_ssl_conf_ca_chain(&mctx->conf, &mctx->ca_chain, nullptr); + } + return loaded; +} + +bool set_client_cert_pem(ctx_t ctx, const char *cert, const char *key, + const char *password) { + if (!ctx || !cert || !key) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate + std::string cert_str(cert); + int ret = mbedtls_x509_crt_parse( + &mctx->own_cert, + reinterpret_cast(cert_str.c_str()), + cert_str.size() + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key + std::string key_str(key); + const unsigned char *pwd = + password ? reinterpret_cast(password) : nullptr; + size_t pwd_len = password ? strlen(password) : 0; + +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len, mbedtls_ctr_drbg_random, + &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mctx->own_key, reinterpret_cast(key_str.c_str()), + key_str.size() + 1, pwd, pwd_len); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool set_client_cert_file(ctx_t ctx, const char *cert_path, + const char *key_path, const char *password) { + if (!ctx || !cert_path || !key_path) { return false; } + auto mctx = static_cast(ctx); + + // Parse certificate file + int ret = mbedtls_x509_crt_parse_file(&mctx->own_cert, cert_path); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key file +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password, + mbedtls_ctr_drbg_random, &mctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_keyfile(&mctx->own_key, key_path, password); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + ret = mbedtls_ssl_conf_own_cert(&mctx->conf, &mctx->own_cert, &mctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +void set_verify_client(ctx_t ctx, bool require) { + if (!ctx) { return; } + auto mctx = static_cast(ctx); + mctx->verify_client = require; + if (require) { + mbedtls_ssl_conf_authmode(&mctx->conf, MBEDTLS_SSL_VERIFY_REQUIRED); + } else { + // If a verify callback is set, use OPTIONAL mode to ensure the callback + // is called (matching OpenSSL behavior). Otherwise use NONE. + mbedtls_ssl_conf_authmode(&mctx->conf, mctx->has_verify_callback + ? MBEDTLS_SSL_VERIFY_OPTIONAL + : MBEDTLS_SSL_VERIFY_NONE); + } +} + +session_t create_session(ctx_t ctx, socket_t sock) { + if (!ctx || sock == INVALID_SOCKET) { return nullptr; } + auto mctx = static_cast(ctx); + + auto session = new (std::nothrow) impl::MbedTlsSession(); + if (!session) { return nullptr; } + + session->sock = sock; + + int ret = mbedtls_ssl_setup(&session->ssl, &mctx->conf); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + delete session; + return nullptr; + } + + // Set BIO callbacks + mbedtls_ssl_set_bio(&session->ssl, &session->sock, impl::mbedtls_net_send_cb, + impl::mbedtls_net_recv_cb, nullptr); + + // Set per-session verify callback with session pointer if callback is + // registered + if (mctx->has_verify_callback) { + mbedtls_ssl_set_verify(&session->ssl, impl::mbedtls_verify_callback, + session); + } + + return static_cast(session); +} + +void free_session(session_t session) { + if (session) { delete static_cast(session); } +} + +bool set_sni(session_t session, const char *hostname) { + if (!session || !hostname) { return false; } + auto msession = static_cast(session); + + int ret = mbedtls_ssl_set_hostname(&msession->ssl, hostname); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + msession->hostname = hostname; + return true; +} + +bool set_hostname(session_t session, const char *hostname) { + // In Mbed TLS, set_hostname also sets up hostname verification + return set_sni(session, hostname); +} + +TlsError connect(session_t session) { + TlsError err; + if (!session) { + err.code = ErrorCode::Fatal; + return err; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_handshake(&msession->ssl); + + if (ret == 0) { + err.code = ErrorCode::Success; + } else { + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + } + + return err; +} + +TlsError accept(session_t session) { + // Same as connect for Mbed TLS - handshake works for both client and server + auto result = connect(session); + + // After successful handshake, capture SNI from thread-local storage + if (result.code == ErrorCode::Success && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +bool connect_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + if (!session) { + if (err) { err->code = ErrorCode::Fatal; } + return false; + } + + auto msession = static_cast(session); + + // Set socket to non-blocking mode + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + int ret; + while ((ret = mbedtls_ssl_handshake(&msession->ssl)) != 0) { + if (ret == MBEDTLS_ERR_SSL_WANT_READ) { + if (detail::select_read(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } else if (ret == MBEDTLS_ERR_SSL_WANT_WRITE) { + if (detail::select_write(sock, timeout_sec, timeout_usec) > 0) { + continue; + } + } + + // TlsError or timeout + if (err) { + err->code = impl::map_mbedtls_error(ret, err->sys_errno); + err->backend_code = static_cast(-ret); + } + impl::mbedtls_last_error() = ret; + return false; + } + + if (err) { err->code = ErrorCode::Success; } + return true; +} + +bool accept_nonblocking(session_t session, socket_t sock, + time_t timeout_sec, time_t timeout_usec, + TlsError *err) { + // Same implementation as connect for Mbed TLS + bool result = + connect_nonblocking(session, sock, timeout_sec, timeout_usec, err); + + // After successful handshake, capture SNI from thread-local storage + if (result && session) { + auto msession = static_cast(session); + msession->sni_hostname = std::move(impl::mbedpending_sni()); + impl::mbedpending_sni().clear(); + } + + return result; +} + +ssize_t read(session_t session, void *buf, size_t len, TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = + mbedtls_ssl_read(&msession->ssl, static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + return -1; +} + +ssize_t write(session_t session, const void *buf, size_t len, + TlsError &err) { + if (!session || !buf) { + err.code = ErrorCode::Fatal; + return -1; + } + + auto msession = static_cast(session); + int ret = mbedtls_ssl_write(&msession->ssl, + static_cast(buf), len); + + if (ret > 0) { + err.code = ErrorCode::Success; + return static_cast(ret); + } + + if (ret == 0) { + err.code = ErrorCode::PeerClosed; + return 0; + } + + err.code = impl::map_mbedtls_error(ret, err.sys_errno); + err.backend_code = static_cast(-ret); + impl::mbedtls_last_error() = ret; + return -1; +} + +int pending(const_session_t session) { + if (!session) { return 0; } + auto msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_bytes_avail(&msession->ssl)); +} + +void shutdown(session_t session, bool graceful) { + if (!session) { return; } + auto msession = static_cast(session); + + if (graceful) { + // Try to send close_notify, but don't block forever + int ret; + int attempts = 0; + while ((ret = mbedtls_ssl_close_notify(&msession->ssl)) != 0 && + attempts < 3) { + if (ret != MBEDTLS_ERR_SSL_WANT_READ && + ret != MBEDTLS_ERR_SSL_WANT_WRITE) { + break; + } + attempts++; + } + } +} + +bool is_peer_closed(session_t session, socket_t sock) { + if (!session || sock == INVALID_SOCKET) { return true; } + auto msession = static_cast(session); + + // Check if there's already decrypted data available in the TLS buffer + // If so, the connection is definitely alive + if (mbedtls_ssl_get_bytes_avail(&msession->ssl) > 0) { return false; } + + // Set socket to non-blocking to avoid blocking on read + detail::set_nonblocking(sock, true); + auto cleanup = + detail::scope_exit([&]() { detail::set_nonblocking(sock, false); }); + + // Try a 1-byte read to check connection status + // Note: This will consume the byte if data is available, but for the + // purpose of checking if peer is closed, this should be acceptable + // since we're only called when we expect the connection might be closing + unsigned char buf; + int ret = mbedtls_ssl_read(&msession->ssl, &buf, 1); + + // If we got data or WANT_READ (would block), connection is alive + if (ret > 0 || ret == MBEDTLS_ERR_SSL_WANT_READ) { return false; } + + // If we get a peer close notify or a connection reset, the peer is closed + return ret == MBEDTLS_ERR_SSL_PEER_CLOSE_NOTIFY || + ret == MBEDTLS_ERR_NET_CONN_RESET || ret == 0; +} + +cert_t get_peer_cert(const_session_t session) { + if (!session) { return nullptr; } + auto msession = + static_cast(const_cast(session)); + + // Mbed TLS returns a pointer to the internal peer cert chain. + // WARNING: This pointer is only valid while the session is active. + // Do not use the certificate after calling free_session(). + const mbedtls_x509_crt *cert = mbedtls_ssl_get_peer_cert(&msession->ssl); + return const_cast(cert); +} + +void free_cert(cert_t cert) { + // Mbed TLS: peer certificate is owned by the SSL context. + // No-op here, but callers should still call this for cross-backend + // portability. + (void)cert; +} + +bool verify_hostname(cert_t cert, const char *hostname) { + if (!cert || !hostname) { return false; } + auto mcert = static_cast(cert); + std::string host_str(hostname); + + // Check if hostname is an IP address + bool is_ip = impl::is_ipv4_address(host_str); + unsigned char ip_bytes[4]; + if (is_ip) { impl::parse_ipv4(host_str, ip_bytes); } + + // Check Subject Alternative Names (SAN) + // In Mbed TLS 3.x, subject_alt_names contains raw values without ASN.1 tags + // - DNS names: raw string bytes + // - IP addresses: raw IP bytes (4 for IPv4, 16 for IPv6) + const mbedtls_x509_sequence *san = &mcert->subject_alt_names; + while (san != nullptr && san->buf.p != nullptr && san->buf.len > 0) { + const unsigned char *p = san->buf.p; + size_t len = san->buf.len; + + if (is_ip) { + // Check if this SAN is an IPv4 address (4 bytes) + if (len == 4 && memcmp(p, ip_bytes, 4) == 0) { return true; } + // Check if this SAN is an IPv6 address (16 bytes) - skip for now + } else { + // Check if this SAN is a DNS name (printable ASCII string) + bool is_dns = len > 0; + for (size_t i = 0; i < len && is_dns; i++) { + if (p[i] < 32 || p[i] > 126) { is_dns = false; } + } + if (is_dns) { + std::string san_name(reinterpret_cast(p), len); + if (detail::match_hostname(san_name, host_str)) { return true; } + } + } + san = san->next; + } + + // Fallback: Check Common Name (CN) in subject + char cn[256]; + int ret = mbedtls_x509_dn_gets(cn, sizeof(cn), &mcert->subject); + if (ret > 0) { + std::string cn_str(cn); + + // Look for "CN=" in the DN string + size_t cn_pos = cn_str.find("CN="); + if (cn_pos != std::string::npos) { + size_t start = cn_pos + 3; + size_t end = cn_str.find(',', start); + std::string cn_value = + cn_str.substr(start, end == std::string::npos ? end : end - start); + + if (detail::match_hostname(cn_value, host_str)) { return true; } + } + } + + return false; +} + +uint64_t hostname_mismatch_code() { + return static_cast(MBEDTLS_X509_BADCERT_CN_MISMATCH); +} + +long get_verify_result(const_session_t session) { + if (!session) { return -1; } + auto msession = + static_cast(const_cast(session)); + uint32_t flags = mbedtls_ssl_get_verify_result(&msession->ssl); + // Return 0 (X509_V_OK equivalent) if verification passed + return flags == 0 ? 0 : static_cast(flags); +} + +std::string get_cert_subject_cn(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Find the CN in the subject + const mbedtls_x509_name *name = &x509->subject; + while (name != nullptr) { + if (MBEDTLS_OID_CMP(MBEDTLS_OID_AT_CN, &name->oid) == 0) { + return std::string(reinterpret_cast(name->val.p), + name->val.len); + } + name = name->next; + } + return ""; +} + +std::string get_cert_issuer_name(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Build a human-readable issuer name string + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &x509->issuer); + if (ret < 0) return ""; + return std::string(buf); +} + +bool get_cert_sans(cert_t cert, std::vector &sans) { + sans.clear(); + if (!cert) return false; + auto x509 = static_cast(cert); + + // Parse the Subject Alternative Name extension + const mbedtls_x509_sequence *cur = &x509->subject_alt_names; + while (cur != nullptr) { + if (cur->buf.len > 0) { + // Mbed TLS stores SAN as ASN.1 sequences + // The tag byte indicates the type + const unsigned char *p = cur->buf.p; + size_t len = cur->buf.len; + + // First byte is the tag + unsigned char tag = *p; + p++; + len--; + + // Parse length (simple single-byte length assumed) + if (len > 0 && *p < 0x80) { + size_t value_len = *p; + p++; + len--; + + if (value_len <= len) { + SanEntry entry; + // ASN.1 context tags for GeneralName + switch (tag & 0x1F) { + case 2: // dNSName + entry.type = SanType::DNS; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 7: // iPAddress + entry.type = SanType::IP; + if (value_len == 4) { + // IPv4 + char buf[16]; + snprintf(buf, sizeof(buf), "%d.%d.%d.%d", p[0], p[1], p[2], p[3]); + entry.value = buf; + } else if (value_len == 16) { + // IPv6 + char buf[64]; + snprintf(buf, sizeof(buf), + "%02x%02x:%02x%02x:%02x%02x:%02x%02x:" + "%02x%02x:%02x%02x:%02x%02x:%02x%02x", + p[0], p[1], p[2], p[3], p[4], p[5], p[6], p[7], p[8], + p[9], p[10], p[11], p[12], p[13], p[14], p[15]); + entry.value = buf; + } + break; + case 1: // rfc822Name (email) + entry.type = SanType::EMAIL; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + case 6: // uniformResourceIdentifier + entry.type = SanType::URI; + entry.value = + std::string(reinterpret_cast(p), value_len); + break; + default: entry.type = SanType::OTHER; break; + } + + if (!entry.value.empty()) { sans.push_back(std::move(entry)); } + } + } + } + cur = cur->next; + } + return true; +} + +bool get_cert_validity(cert_t cert, time_t ¬_before, + time_t ¬_after) { + if (!cert) return false; + auto x509 = static_cast(cert); + + // Convert mbedtls_x509_time to time_t + auto to_time_t = [](const mbedtls_x509_time &t) -> time_t { + struct tm tm_time = {}; + tm_time.tm_year = t.year - 1900; + tm_time.tm_mon = t.mon - 1; + tm_time.tm_mday = t.day; + tm_time.tm_hour = t.hour; + tm_time.tm_min = t.min; + tm_time.tm_sec = t.sec; +#ifdef _WIN32 + return _mkgmtime(&tm_time); +#else + return timegm(&tm_time); +#endif + }; + + not_before = to_time_t(x509->valid_from); + not_after = to_time_t(x509->valid_to); + return true; +} + +std::string get_cert_serial(cert_t cert) { + if (!cert) return ""; + auto x509 = static_cast(cert); + + // Convert serial number to hex string + std::string result; + result.reserve(x509->serial.len * 2); + for (size_t i = 0; i < x509->serial.len; i++) { + char hex[3]; + snprintf(hex, sizeof(hex), "%02X", x509->serial.p[i]); + result += hex; + } + return result; +} + +bool get_cert_der(cert_t cert, std::vector &der) { + if (!cert) return false; + auto crt = static_cast(cert); + if (!crt->raw.p || crt->raw.len == 0) return false; + der.assign(crt->raw.p, crt->raw.p + crt->raw.len); + return true; +} + +const char *get_sni(const_session_t session) { + if (!session) return nullptr; + auto msession = static_cast(session); + + // For server: return SNI received from client during handshake + if (!msession->sni_hostname.empty()) { + return msession->sni_hostname.c_str(); + } + + // For client: return the hostname set via set_sni + if (!msession->hostname.empty()) { return msession->hostname.c_str(); } + + return nullptr; +} + +uint64_t peek_error() { + // Mbed TLS doesn't have an error queue, return the last error + return static_cast(-impl::mbedtls_last_error()); +} + +uint64_t get_error() { + // Mbed TLS doesn't have an error queue, return and clear the last error + uint64_t err = static_cast(-impl::mbedtls_last_error()); + impl::mbedtls_last_error() = 0; + return err; +} + +std::string error_string(uint64_t code) { + char buf[256]; + mbedtls_strerror(-static_cast(code), buf, sizeof(buf)); + return std::string(buf); +} + +ca_store_t create_ca_store(const char *pem, size_t len) { + auto *ca_chain = new (std::nothrow) mbedtls_x509_crt; + if (!ca_chain) { return nullptr; } + + mbedtls_x509_crt_init(ca_chain); + + // mbedtls_x509_crt_parse expects null-terminated PEM + int ret = mbedtls_x509_crt_parse(ca_chain, + reinterpret_cast(pem), + len + 1); // +1 for null terminator + if (ret != 0) { + // Try without +1 in case PEM is already null-terminated + ret = mbedtls_x509_crt_parse( + ca_chain, reinterpret_cast(pem), len); + if (ret != 0) { + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + return nullptr; + } + } + + return static_cast(ca_chain); +} + +void free_ca_store(ca_store_t store) { + if (store) { + auto *ca_chain = static_cast(store); + mbedtls_x509_crt_free(ca_chain); + delete ca_chain; + } +} + +bool set_ca_store(ctx_t ctx, ca_store_t store) { + if (!ctx || !store) { return false; } + auto *mbed_ctx = static_cast(ctx); + auto *ca_chain = static_cast(store); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Copy the CA chain (deep copy) + // Parse from the raw data of the source cert + mbedtls_x509_crt *src = ca_chain; + while (src != nullptr) { + int ret = mbedtls_x509_crt_parse_der(&mbed_ctx->ca_chain, src->raw.p, + src->raw.len); + if (ret != 0) { return false; } + src = src->next; + } + + // Update the SSL config to use the new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +size_t get_ca_certs(ctx_t ctx, std::vector &certs) { + certs.clear(); + if (!ctx) { return 0; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while (cert != nullptr && cert->raw.len > 0) { + // Create a copy of the certificate for the caller + auto *copy = new mbedtls_x509_crt; + mbedtls_x509_crt_init(copy); + int ret = mbedtls_x509_crt_parse_der(copy, cert->raw.p, cert->raw.len); + if (ret == 0) { + certs.push_back(static_cast(copy)); + } else { + mbedtls_x509_crt_free(copy); + delete copy; + } + cert = cert->next; + } + return certs.size(); +} + +std::vector get_ca_names(ctx_t ctx) { + std::vector names; + if (!ctx) { return names; } + auto *mbed_ctx = static_cast(ctx); + + // Iterate through the CA chain + mbedtls_x509_crt *cert = &mbed_ctx->ca_chain; + while (cert != nullptr && cert->raw.len > 0) { + char buf[512]; + int ret = mbedtls_x509_dn_gets(buf, sizeof(buf), &cert->subject); + if (ret > 0) { names.push_back(buf); } + cert = cert->next; + } + return names; +} + +bool update_server_cert(ctx_t ctx, const char *cert_pem, + const char *key_pem, const char *password) { + if (!ctx || !cert_pem || !key_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing certificate and key + mbedtls_x509_crt_free(&mbed_ctx->own_cert); + mbedtls_pk_free(&mbed_ctx->own_key); + mbedtls_x509_crt_init(&mbed_ctx->own_cert); + mbedtls_pk_init(&mbed_ctx->own_key); + + // Parse certificate PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->own_cert, reinterpret_cast(cert_pem), + strlen(cert_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Parse private key PEM +#ifdef CPPHTTPLIB_MBEDTLS_V3 + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0, mbedtls_ctr_drbg_random, + &mbed_ctx->ctr_drbg); +#else + ret = mbedtls_pk_parse_key( + &mbed_ctx->own_key, reinterpret_cast(key_pem), + strlen(key_pem) + 1, + password ? reinterpret_cast(password) : nullptr, + password ? strlen(password) : 0); +#endif + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Configure SSL to use the new certificate and key + ret = mbedtls_ssl_conf_own_cert(&mbed_ctx->conf, &mbed_ctx->own_cert, + &mbed_ctx->own_key); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + return true; +} + +bool update_server_client_ca(ctx_t ctx, const char *ca_pem) { + if (!ctx || !ca_pem) { return false; } + auto *mbed_ctx = static_cast(ctx); + + // Free existing CA chain + mbedtls_x509_crt_free(&mbed_ctx->ca_chain); + mbedtls_x509_crt_init(&mbed_ctx->ca_chain); + + // Parse CA PEM + int ret = mbedtls_x509_crt_parse( + &mbed_ctx->ca_chain, reinterpret_cast(ca_pem), + strlen(ca_pem) + 1); + if (ret != 0) { + impl::mbedtls_last_error() = ret; + return false; + } + + // Update SSL config to use new CA chain + mbedtls_ssl_conf_ca_chain(&mbed_ctx->conf, &mbed_ctx->ca_chain, nullptr); + return true; +} + +bool set_verify_callback(ctx_t ctx, VerifyCallback callback) { + if (!ctx) { return false; } + auto *mbed_ctx = static_cast(ctx); + + impl::get_verify_callback() = std::move(callback); + mbed_ctx->has_verify_callback = + static_cast(impl::get_verify_callback()); + + if (mbed_ctx->has_verify_callback) { + // Set OPTIONAL mode to ensure callback is called even when verification + // is disabled (matching OpenSSL behavior where SSL_VERIFY_PEER is set) + mbedtls_ssl_conf_authmode(&mbed_ctx->conf, MBEDTLS_SSL_VERIFY_OPTIONAL); + mbedtls_ssl_conf_verify(&mbed_ctx->conf, impl::mbedtls_verify_callback, + nullptr); + } else { + mbedtls_ssl_conf_verify(&mbed_ctx->conf, nullptr, nullptr); + } + return true; +} + +long get_verify_error(const_session_t session) { + if (!session) { return -1; } + auto *msession = + static_cast(const_cast(session)); + return static_cast(mbedtls_ssl_get_verify_result(&msession->ssl)); +} + +std::string verify_error_string(long error_code) { + if (error_code == 0) { return ""; } + char buf[256]; + mbedtls_x509_crt_verify_info(buf, sizeof(buf), "", + static_cast(error_code)); + // Remove trailing newline if present + std::string result(buf); + while (!result.empty() && (result.back() == '\n' || result.back() == ' ')) { + result.pop_back(); + } + return result; +} + +} // namespace tls + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + } // namespace httplib diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index 7c7790f41f..1fd8b1d1e8 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -8,8 +8,8 @@ #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.30.2" -#define CPPHTTPLIB_VERSION_NUM "0x001E02" +#define CPPHTTPLIB_VERSION "0.31.0" +#define CPPHTTPLIB_VERSION_NUM "0x001F00" /* * Platform compatibility check @@ -147,7 +147,7 @@ #endif #ifndef CPPHTTPLIB_PAYLOAD_MAX_LENGTH -#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH ((std::numeric_limits::max)()) +#define CPPHTTPLIB_PAYLOAD_MAX_LENGTH (100 * 1024 * 1024) // 100MB #endif #ifndef CPPHTTPLIB_FORM_URL_ENCODED_PAYLOAD_MAX_LENGTH @@ -383,6 +383,45 @@ using socket_t = int; #endif // CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#ifdef _WIN32 +#include +#ifdef _MSC_VER +#pragma comment(lib, "crypt32.lib") +#endif +#endif // _WIN32 +#if defined(CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN) +#if TARGET_OS_MAC +#include +#endif +#endif // CPPHTTPLIB_USE_CERTS_FROM_MACOSX_KEYCHAIN + +// Mbed TLS 3.x API compatibility +#if MBEDTLS_VERSION_MAJOR >= 3 +#define CPPHTTPLIB_MBEDTLS_V3 +#endif + +#endif // CPPHTTPLIB_MBEDTLS_SUPPORT + +// Define CPPHTTPLIB_SSL_ENABLED if any SSL backend is available +// This simplifies conditional compilation when adding new backends (e.g., +// wolfSSL) +#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) || defined(CPPHTTPLIB_MBEDTLS_SUPPORT) +#define CPPHTTPLIB_SSL_ENABLED +#endif + #ifdef CPPHTTPLIB_ZLIB_SUPPORT #include #endif @@ -799,6 +838,105 @@ public: using Range = std::pair; using Ranges = std::vector; +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - public type definitions and API +namespace tls { + +// Opaque handles (defined as void* for abstraction) +using ctx_t = void *; +using session_t = void *; +using const_session_t = const void *; // For read-only session access +using cert_t = void *; +using ca_store_t = void *; + +// TLS versions +enum class Version { + TLS1_2 = 0x0303, + TLS1_3 = 0x0304, +}; + +// Subject Alternative Names (SAN) entry types +enum class SanType { DNS, IP, EMAIL, URI, OTHER }; + +// SAN entry structure +struct SanEntry { + SanType type; + std::string value; +}; + +// Verification context for certificate verification callback +struct VerifyContext { + session_t session; // TLS session handle + cert_t cert; // Current certificate being verified + int depth; // Certificate chain depth (0 = leaf) + bool preverify_ok; // OpenSSL/Mbed TLS pre-verification result + long error_code; // Backend-specific error code (0 = no error) + const char *error_string; // Human-readable error description + + // Certificate introspection methods + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; +}; + +using VerifyCallback = std::function; + +// TlsError codes for TLS operations (backend-independent) +enum class ErrorCode : int { + Success = 0, + WantRead, // Non-blocking: need to wait for read + WantWrite, // Non-blocking: need to wait for write + PeerClosed, // Peer closed the connection + Fatal, // Unrecoverable error + SyscallError, // System call error (check sys_errno) + CertVerifyFailed, // Certificate verification failed + HostnameMismatch, // Hostname verification failed +}; + +// TLS error information +struct TlsError { + ErrorCode code = ErrorCode::Fatal; + uint64_t backend_code = 0; // OpenSSL: ERR_get_error(), mbedTLS: return value + int sys_errno = 0; // errno when SyscallError + + // Convert verification error code to human-readable string + static std::string verify_error_to_string(long error_code); +}; + +// RAII wrapper for peer certificate +class PeerCert { +public: + PeerCert(); + PeerCert(PeerCert &&other) noexcept; + PeerCert &operator=(PeerCert &&other) noexcept; + ~PeerCert(); + + PeerCert(const PeerCert &) = delete; + PeerCert &operator=(const PeerCert &) = delete; + + explicit operator bool() const; + std::string subject_cn() const; + std::string issuer_name() const; + bool check_hostname(const char *hostname) const; + std::vector sans() const; + bool validity(time_t ¬_before, time_t ¬_after) const; + std::string serial() const; + +private: + explicit PeerCert(cert_t cert); + cert_t cert_ = nullptr; + friend PeerCert get_peer_cert_from_session(const_session_t session); +}; + +// Callback for TLS context setup (used by SSLServer constructor) +using ContextSetupCallback = std::function; + +} // namespace tls +#endif + struct Request { std::string method; std::string path; @@ -828,9 +966,6 @@ struct Request { ContentReceiverWithProgress content_receiver; DownloadProgress download_progress; UploadProgress upload_progress; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - const SSL *ssl = nullptr; -#endif bool has_header(const std::string &key) const; std::string get_header_value(const std::string &key, const char *def = "", @@ -858,6 +993,12 @@ struct Request { size_t authorization_count_ = 0; std::chrono::time_point start_time_ = (std::chrono::steady_clock::time_point::min)(); + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::const_session_t ssl = nullptr; + tls::PeerCert peer_cert() const; + std::string sni() const; +#endif }; struct Response { @@ -1005,74 +1146,18 @@ public: class ThreadPool final : public TaskQueue { public: - explicit ThreadPool(size_t n, size_t mqr = 0) - : shutdown_(false), max_queued_requests_(mqr) { - threads_.reserve(n); - while (n) { - threads_.emplace_back(worker(*this)); - n--; - } - } - + explicit ThreadPool(size_t n, size_t mqr = 0); ThreadPool(const ThreadPool &) = delete; ~ThreadPool() override = default; - bool enqueue(std::function fn) override { - { - std::unique_lock lock(mutex_); - if (max_queued_requests_ > 0 && jobs_.size() >= max_queued_requests_) { - return false; - } - jobs_.push_back(std::move(fn)); - } - - cond_.notify_one(); - return true; - } - - void shutdown() override { - // Stop all worker threads... - { - std::unique_lock lock(mutex_); - shutdown_ = true; - } - - cond_.notify_all(); - - // Join... - for (auto &t : threads_) { - t.join(); - } - } + bool enqueue(std::function fn) override; + void shutdown() override; private: struct worker { - explicit worker(ThreadPool &pool) : pool_(pool) {} + explicit worker(ThreadPool &pool); - void operator()() { - for (;;) { - std::function fn; - { - std::unique_lock lock(pool_.mutex_); - - pool_.cond_.wait( - lock, [&] { return !pool_.jobs_.empty() || pool_.shutdown_; }); - - if (pool_.shutdown_ && pool_.jobs_.empty()) { break; } - - fn = pool_.jobs_.front(); - pool_.jobs_.pop_front(); - } - - assert(true == static_cast(fn)); - fn(); - } - -#if defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(OPENSSL_IS_BORINGSSL) && \ - !defined(LIBRESSL_VERSION_NUMBER) - OPENSSL_thread_stop(); -#endif - } + void operator()(); ThreadPool &pool_; }; @@ -1184,6 +1269,9 @@ int close_socket(socket_t sock); ssize_t write_headers(Stream &strm, const Headers &headers); +bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, + time_t usec); + } // namespace detail class Server { @@ -1429,17 +1517,6 @@ public: Headers &&request_headers = Headers{}) : res_(std::move(res)), err_(err), request_headers_(std::move(request_headers)) {} -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} - Result(std::unique_ptr &&res, Error err, Headers &&request_headers, - int ssl_error, unsigned long ssl_openssl_error) - : res_(std::move(res)), err_(err), - request_headers_(std::move(request_headers)), ssl_error_(ssl_error), - ssl_openssl_error_(ssl_openssl_error) {} -#endif // Response operator bool() const { return res_ != nullptr; } bool operator==(std::nullptr_t) const { return res_ == nullptr; } @@ -1454,13 +1531,6 @@ public: // Error Error error() const { return err_; } -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - // SSL Error - int ssl_error() const { return ssl_error_; } - // OpenSSL Error - unsigned long ssl_openssl_error() const { return ssl_openssl_error_; } -#endif - // Request Headers bool has_request_header(const std::string &key) const; std::string get_request_header_value(const std::string &key, @@ -1474,64 +1544,76 @@ private: std::unique_ptr res_; Error err_ = Error::Unknown; Headers request_headers_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error) {} + Result(std::unique_ptr &&res, Error err, Headers &&request_headers, + int ssl_error, unsigned long ssl_backend_error) + : res_(std::move(res)), err_(err), + request_headers_(std::move(request_headers)), ssl_error_(ssl_error), + ssl_backend_error_(ssl_backend_error) {} + + int ssl_error() const { return ssl_error_; } + unsigned long ssl_backend_error() const { return ssl_backend_error_; } + +private: int ssl_error_ = 0; - unsigned long ssl_openssl_error_ = 0; + unsigned long ssl_backend_error_ = 0; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use ssl_backend_error() instead")]] + unsigned long ssl_openssl_error() const { + return ssl_backend_error_; + } #endif }; struct ClientConnection { socket_t sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; -#endif bool is_open() const { return sock != INVALID_SOCKET; } ClientConnection() = default; - ~ClientConnection() { -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - if (ssl) { - SSL_free(ssl); - ssl = nullptr; - } -#endif - if (sock != INVALID_SOCKET) { - detail::close_socket(sock); - sock = INVALID_SOCKET; - } - } + ~ClientConnection(); ClientConnection(const ClientConnection &) = delete; ClientConnection &operator=(const ClientConnection &) = delete; ClientConnection(ClientConnection &&other) noexcept : sock(other.sock) -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED , - ssl(other.ssl) + session(other.session) #endif { other.sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; +#ifdef CPPHTTPLIB_SSL_ENABLED + other.session = nullptr; #endif } ClientConnection &operator=(ClientConnection &&other) noexcept { if (this != &other) { sock = other.sock; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - ssl = other.ssl; -#endif other.sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - other.ssl = nullptr; +#ifdef CPPHTTPLIB_SSL_ENABLED + session = other.session; + other.session = nullptr; #endif } return *this; } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t session = nullptr; +#endif }; namespace detail { @@ -1540,7 +1622,9 @@ struct ChunkedDecoder; struct BodyReader { Stream *stream = nullptr; + bool has_content_length = false; size_t content_length = 0; + size_t payload_max_length = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; size_t bytes_read = 0; bool chunked = false; bool eof = false; @@ -1610,6 +1694,7 @@ public: std::unique_ptr decompressor_; std::string decompress_buffer_; size_t decompress_offset_ = 0; + size_t decompressed_bytes_read_ = 0; }; // clang-format off @@ -1756,10 +1841,6 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); @@ -1770,30 +1851,14 @@ public: void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - void set_ca_cert_store(X509_STORE *ca_cert_store); - X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); @@ -1801,11 +1866,15 @@ public: protected: struct Socket { socket_t sock = INVALID_SOCKET; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - SSL *ssl = nullptr; -#endif + + // For Mbed TLS compatibility: start_time for request timeout tracking + std::chrono::time_point start_time_; bool is_open() const { return sock != INVALID_SOCKET; } + +#ifdef CPPHTTPLIB_SSL_ENABLED + tls::session_t ssl = nullptr; +#endif }; virtual bool create_and_connect_socket(Socket &socket, Error &error); @@ -1872,10 +1941,6 @@ protected: std::string basic_auth_username_; std::string basic_auth_password_; std::string bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string digest_auth_username_; - std::string digest_auth_password_; -#endif bool keep_alive_ = false; bool follow_location_ = false; @@ -1890,6 +1955,9 @@ protected: bool compress_ = false; bool decompress_ = true; + size_t payload_max_length_ = CPPHTTPLIB_PAYLOAD_MAX_LENGTH; + bool has_payload_max_length_ = false; + std::string interface_; std::string proxy_host_; @@ -1898,33 +1966,11 @@ protected: std::string proxy_basic_auth_username_; std::string proxy_basic_auth_password_; std::string proxy_bearer_token_auth_token_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string proxy_digest_auth_username_; - std::string proxy_digest_auth_password_; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - std::string ca_cert_file_path_; - std::string ca_cert_dir_path_; - - X509_STORE *ca_cert_store_ = nullptr; -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - bool server_certificate_verification_ = true; - bool server_hostname_verification_ = true; - std::function server_certificate_verifier_; -#endif mutable std::mutex logger_mutex_; Logger logger_; ErrorLogger error_logger_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - int last_ssl_error_ = 0; - unsigned long last_openssl_error_ = 0; -#endif - private: bool send_(Request &req, Response &res, Error &error); Result send_(Request &&req); @@ -1969,6 +2015,44 @@ private: virtual bool is_ssl() const; void transfer_socket_ownership_to_handle(StreamHandle &handle); + +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + +protected: + std::string digest_auth_username_; + std::string digest_auth_password_; + std::string proxy_digest_auth_username_; + std::string proxy_digest_auth_password_; + std::string ca_cert_file_path_; + std::string ca_cert_dir_path_; + bool server_certificate_verification_ = true; + bool server_hostname_verification_ = true; + std::string ca_cert_pem_; // Store CA cert PEM for redirect transfer + int last_ssl_error_ = 0; + unsigned long last_backend_error_ = 0; +#endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use load_ca_cert_store() instead")]] + void set_ca_cert_store(X509_STORE *ca_cert_store); + + [[deprecated("Use tls::create_ca_store() instead")]] + X509_STORE *create_ca_cert_store(const char *ca_cert, std::size_t size) const; + + [[deprecated("Use set_server_certificate_verifier(VerifyCallback) instead")]] + virtual void set_server_certificate_verifier( + std::function verifier); +#endif }; class Client { @@ -2138,10 +2222,6 @@ public: void set_basic_auth(const std::string &username, const std::string &password); void set_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_digest_auth(const std::string &username, - const std::string &password); -#endif void set_keep_alive(bool on); void set_follow_location(bool on); @@ -2153,49 +2233,65 @@ public: void set_decompress(bool on); + void set_payload_max_length(size_t length); + void set_interface(const std::string &intf); void set_proxy(const std::string &host, int port); void set_proxy_basic_auth(const std::string &username, const std::string &password); void set_proxy_bearer_token_auth(const std::string &token); -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_proxy_digest_auth(const std::string &username, - const std::string &password); -#endif - -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void enable_server_certificate_verification(bool enabled); - void enable_server_hostname_verification(bool enabled); - void set_server_certificate_verifier( - std::function verifier); -#endif - void set_logger(Logger logger); void set_error_logger(ErrorLogger error_logger); - // SSL -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT - void set_ca_cert_path(const std::string &ca_cert_file_path, - const std::string &ca_cert_dir_path = std::string()); - - void set_ca_cert_store(X509_STORE *ca_cert_store); - void load_ca_cert_store(const char *ca_cert, std::size_t size); - - long get_openssl_verify_result() const; - - SSL_CTX *ssl_context() const; -#endif - private: std::unique_ptr cli_; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED +public: + void set_digest_auth(const std::string &username, + const std::string &password); + void set_proxy_digest_auth(const std::string &username, + const std::string &password); + void enable_server_certificate_verification(bool enabled); + void enable_server_hostname_verification(bool enabled); + void set_ca_cert_path(const std::string &ca_cert_file_path, + const std::string &ca_cert_dir_path = std::string()); + + void set_ca_cert_store(tls::ca_store_t ca_cert_store); + void load_ca_cert_store(const char *ca_cert, std::size_t size); + + void set_server_certificate_verifier(tls::VerifyCallback verifier); + + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const; + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + void enable_windows_certificate_verification(bool enabled); +#endif + +private: bool is_ssl_ = false; #endif + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use set_session_verifier(session_t) instead")]] + void set_server_certificate_verifier( + std::function verifier); + + [[deprecated("Use Result::ssl_backend_error() instead")]] + long get_verify_result() const; +#endif }; -#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +#ifdef CPPHTTPLIB_SSL_ENABLED class SSLServer : public Server { public: SSLServer(const char *cert_path, const char *private_key_path, @@ -2203,32 +2299,60 @@ public: const char *client_ca_cert_dir_path = nullptr, const char *private_key_password = nullptr); - SSLServer(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *client_ca_pem; + size_t client_ca_pem_len; + const char *private_key_password; + }; + explicit SSLServer(const PemMemory &pem); - SSLServer( - const std::function &setup_ssl_ctx_callback); + // The callback receives the ctx_t handle which can be cast to the + // appropriate backend type (SSL_CTX* for OpenSSL, + // tls::impl::MbedTlsContext* for Mbed TLS) + explicit SSLServer(const tls::ContextSetupCallback &setup_callback); ~SSLServer() override; bool is_valid() const override; - SSL_CTX *ssl_context() const; + bool update_certs_pem(const char *cert_pem, const char *key_pem, + const char *client_ca_pem = nullptr, + const char *password = nullptr); - void update_certs(X509 *cert, EVP_PKEY *private_key, - X509_STORE *client_ca_cert_store = nullptr); + tls::ctx_t tls_context() const { return ctx_; } int ssl_last_error() const { return last_ssl_error_; } private: bool process_and_close_socket(socket_t sock) override; - STACK_OF(X509_NAME) * extract_ca_names_from_x509_store(X509_STORE *store); - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; int last_ssl_error_ = 0; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use SSLServer(PemMemory) or " + "SSLServer(ContextSetupCallback) instead")]] + SSLServer(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); + + [[deprecated("Use SSLServer(ContextSetupCallback) instead")]] + SSLServer( + const std::function &setup_ssl_ctx_callback); + + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use update_certs_pem() instead")]] + void update_certs(X509 *cert, EVP_PKEY *private_key, + X509_STORE *client_ca_cert_store = nullptr); +#endif }; class SSLClient final : public ClientImpl { @@ -2242,20 +2366,34 @@ public: const std::string &client_key_path, const std::string &private_key_password = std::string()); - explicit SSLClient(const std::string &host, int port, X509 *client_cert, - EVP_PKEY *client_key, - const std::string &private_key_password = std::string()); + struct PemMemory { + const char *cert_pem; + size_t cert_pem_len; + const char *key_pem; + size_t key_pem_len; + const char *private_key_password; + }; + explicit SSLClient(const std::string &host, int port, const PemMemory &pem); ~SSLClient() override; bool is_valid() const override; - void set_ca_cert_store(X509_STORE *ca_cert_store); + void set_ca_cert_store(tls::ca_store_t ca_cert_store); void load_ca_cert_store(const char *ca_cert, std::size_t size); - long get_openssl_verify_result() const; + void set_server_certificate_verifier(tls::VerifyCallback verifier); - SSL_CTX *ssl_context() const; + // Post-handshake session verifier (backend-independent) + void set_session_verifier( + std::function verifier); + + tls::ctx_t tls_context() const { return ctx_; } + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + void enable_windows_certificate_verification(bool enabled); +#endif private: bool create_and_connect_socket(Socket &socket, Error &error) override; @@ -2277,26 +2415,45 @@ private: bool load_certs(); - bool verify_host(X509 *server_cert) const; - bool verify_host_with_subject_alt_name(X509 *server_cert) const; - bool verify_host_with_common_name(X509 *server_cert) const; - bool check_host_name(const char *pattern, size_t pattern_len) const; - - SSL_CTX *ctx_; + tls::ctx_t ctx_ = nullptr; std::mutex ctx_mutex_; std::once_flag initialize_cert_; - std::vector host_components_; - long verify_result_ = 0; - friend class ClientImpl; -}; + std::function session_verifier_; + +#if defined(_WIN32) && \ + !defined(CPPHTTPLIB_DISABLE_WINDOWS_AUTOMATIC_ROOT_CERTIFICATES_UPDATE) + bool enable_windows_cert_verification_ = true; #endif -/* - * Implementation of template methods. - */ + friend class ClientImpl; + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +public: + [[deprecated("Use SSLClient(host, port, PemMemory) instead")]] + explicit SSLClient(const std::string &host, int port, X509 *client_cert, + EVP_PKEY *client_key, + const std::string &private_key_password = std::string()); + + [[deprecated("Use Result::ssl_backend_error() instead")]] + long get_verify_result() const; + + [[deprecated("Use tls_context() instead")]] + SSL_CTX *ssl_context() const; + + [[deprecated("Use set_session_verifier(session_t) instead")]] + void set_server_certificate_verifier( + std::function verifier) override; + +private: + bool verify_host(X509 *server_cert) const; + bool verify_host_with_subject_alt_name(X509 *server_cert) const; + bool verify_host_with_common_name(X509 *server_cert) const; +#endif +}; +#endif // CPPHTTPLIB_SSL_ENABLED namespace detail { @@ -2345,66 +2502,6 @@ inline size_t get_header_value_u64(const Headers &headers, } // namespace detail -inline size_t Request::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -inline size_t Response::get_header_value_u64(const std::string &key, size_t def, - size_t id) const { - return detail::get_header_value_u64(headers, key, def, id); -} - -namespace detail { - -inline bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen) { - return setsockopt(sock, level, optname, -#ifdef _WIN32 - reinterpret_cast(optval), -#else - optval, -#endif - optlen) == 0; -} - -inline bool set_socket_opt(socket_t sock, int level, int optname, int optval) { - return set_socket_opt_impl(sock, level, optname, &optval, sizeof(optval)); -} - -inline bool set_socket_opt_time(socket_t sock, int level, int optname, - time_t sec, time_t usec) { -#ifdef _WIN32 - auto timeout = static_cast(sec * 1000 + usec / 1000); -#else - timeval timeout; - timeout.tv_sec = static_cast(sec); - timeout.tv_usec = static_cast(usec); -#endif - return set_socket_opt_impl(sock, level, optname, &timeout, sizeof(timeout)); -} - -} // namespace detail - -inline void default_socket_options(socket_t sock) { - detail::set_socket_opt(sock, SOL_SOCKET, -#ifdef SO_REUSEPORT - SO_REUSEPORT, -#else - SO_REUSEADDR, -#endif - 1); -} - -inline std::string get_bearer_token_auth(const Request &req) { - if (req.has_header("Authorization")) { - constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); - return req.get_header_value("Authorization") - .substr(bearer_header_prefix_len); - } - return ""; -} - template inline Server & Server::set_read_timeout(const std::chrono::duration &duration) { @@ -2429,12 +2526,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline size_t Result::get_request_header_value_u64(const std::string &key, - size_t def, - size_t id) const { - return detail::get_header_value_u64(request_headers_, key, def, id); -} - template inline void ClientImpl::set_connection_timeout( const std::chrono::duration &duration) { @@ -2842,105 +2933,73 @@ bool is_field_content(const std::string &s); bool is_field_value(const std::string &s); } // namespace fields - } // namespace detail +/* + * TLS Abstraction Layer Declarations + */ + +#ifdef CPPHTTPLIB_SSL_ENABLED +// TLS abstraction layer - backend-specific type declarations +#ifdef CPPHTTPLIB_MBEDTLS_SUPPORT +namespace tls { +namespace impl { + +// Mbed TLS context wrapper (holds config, entropy, DRBG, CA chain, own +// cert/key). This struct is accessible via tls::impl for use in SSL context +// setup callbacks (cast ctx_t to tls::impl::MbedTlsContext*). +struct MbedTlsContext { + mbedtls_ssl_config conf; + mbedtls_entropy_context entropy; + mbedtls_ctr_drbg_context ctr_drbg; + mbedtls_x509_crt ca_chain; + mbedtls_x509_crt own_cert; + mbedtls_pk_context own_key; + bool is_server = false; + bool verify_client = false; + bool has_verify_callback = false; + + MbedTlsContext(); + ~MbedTlsContext(); + + MbedTlsContext(const MbedTlsContext &) = delete; + MbedTlsContext &operator=(const MbedTlsContext &) = delete; +}; + +} // namespace impl +} // namespace tls +#endif + +#endif // CPPHTTPLIB_SSL_ENABLED + namespace stream { class Result { public: - Result() : chunk_size_(8192) {} - - explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192) - : handle_(std::move(handle)), chunk_size_(chunk_size) {} - - Result(Result &&other) noexcept - : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), - current_size_(other.current_size_), chunk_size_(other.chunk_size_), - finished_(other.finished_) { - other.current_size_ = 0; - other.finished_ = true; - } - - Result &operator=(Result &&other) noexcept { - if (this != &other) { - handle_ = std::move(other.handle_); - buffer_ = std::move(other.buffer_); - current_size_ = other.current_size_; - chunk_size_ = other.chunk_size_; - finished_ = other.finished_; - other.current_size_ = 0; - other.finished_ = true; - } - return *this; - } - + Result(); + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192); + Result(Result &&other) noexcept; + Result &operator=(Result &&other) noexcept; Result(const Result &) = delete; Result &operator=(const Result &) = delete; - // Check if the result is valid (connection succeeded and response received) - bool is_valid() const { return handle_.is_valid(); } - explicit operator bool() const { return is_valid(); } - - // Response status code - int status() const { - return handle_.response ? handle_.response->status : -1; - } - - // Response headers - const Headers &headers() const { - static const Headers empty_headers; - return handle_.response ? handle_.response->headers : empty_headers; - } - + // Response info + bool is_valid() const; + explicit operator bool() const; + int status() const; + const Headers &headers() const; std::string get_header_value(const std::string &key, - const char *def = "") const { - return handle_.response ? handle_.response->get_header_value(key, def) - : def; - } + const char *def = "") const; + bool has_header(const std::string &key) const; + Error error() const; + Error read_error() const; + bool has_read_error() const; - bool has_header(const std::string &key) const { - return handle_.response ? handle_.response->has_header(key) : false; - } - - // Error information - Error error() const { return handle_.error; } - Error read_error() const { return handle_.get_read_error(); } - bool has_read_error() const { return handle_.has_read_error(); } - - // Streaming iteration API - // Call next() to read the next chunk, then access data via data()/size() - // Returns true if data was read, false when stream is exhausted - bool next() { - if (!handle_.is_valid() || finished_) { return false; } - - if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } - - ssize_t n = handle_.read(&buffer_[0], chunk_size_); - if (n > 0) { - current_size_ = static_cast(n); - return true; - } - - current_size_ = 0; - finished_ = true; - return false; - } - - // Pointer to current chunk data (valid after next() returns true) - const char *data() const { return buffer_.data(); } - - // Size of current chunk (valid after next() returns true) - size_t size() const { return current_size_; } - - // Convenience method: read all remaining data into a string - std::string read_all() { - std::string result; - while (next()) { - result.append(data(), size()); - } - return result; - } + // Stream reading + bool next(); + const char *data() const; + size_t size() const; + std::string read_all(); private: ClientImpl::StreamHandle handle_; @@ -3205,13 +3264,8 @@ struct SSEMessage { std::string data; // Event payload std::string id; // Event ID for Last-Event-ID header - SSEMessage() : event("message") {} - - void clear() { - event = "message"; - data.clear(); - id.clear(); - } + SSEMessage(); + void clear(); }; class SSEClient { @@ -3220,255 +3274,40 @@ public: using ErrorHandler = std::function; using OpenHandler = std::function; - SSEClient(Client &client, const std::string &path) - : client_(client), path_(path) {} - - SSEClient(Client &client, const std::string &path, const Headers &headers) - : client_(client), path_(path), headers_(headers) {} - - ~SSEClient() { stop(); } + SSEClient(Client &client, const std::string &path); + SSEClient(Client &client, const std::string &path, const Headers &headers); + ~SSEClient(); SSEClient(const SSEClient &) = delete; SSEClient &operator=(const SSEClient &) = delete; // Event handlers - SSEClient &on_message(MessageHandler handler) { - on_message_ = std::move(handler); - return *this; - } - - SSEClient &on_event(const std::string &type, MessageHandler handler) { - event_handlers_[type] = std::move(handler); - return *this; - } - - SSEClient &on_open(OpenHandler handler) { - on_open_ = std::move(handler); - return *this; - } - - SSEClient &on_error(ErrorHandler handler) { - on_error_ = std::move(handler); - return *this; - } - - SSEClient &set_reconnect_interval(int ms) { - reconnect_interval_ms_ = ms; - return *this; - } - - SSEClient &set_max_reconnect_attempts(int n) { - max_reconnect_attempts_ = n; - return *this; - } + SSEClient &on_message(MessageHandler handler); + SSEClient &on_event(const std::string &type, MessageHandler handler); + SSEClient &on_open(OpenHandler handler); + SSEClient &on_error(ErrorHandler handler); + SSEClient &set_reconnect_interval(int ms); + SSEClient &set_max_reconnect_attempts(int n); // State accessors - bool is_connected() const { return connected_.load(); } - const std::string &last_event_id() const { return last_event_id_; } + bool is_connected() const; + const std::string &last_event_id() const; // Blocking start - runs event loop with auto-reconnect - void start() { - running_.store(true); - run_event_loop(); - } + void start(); // Non-blocking start - runs in background thread - void start_async() { - running_.store(true); - async_thread_ = std::thread([this]() { run_event_loop(); }); - } + void start_async(); // Stop the client (thread-safe) - void stop() { - running_.store(false); - client_.stop(); // Cancel any pending operations - if (async_thread_.joinable()) { async_thread_.join(); } - } + void stop(); private: - // Parse a single SSE field line - // Returns true if this line ends an event (blank line) - bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms) { - // Blank line signals end of event - if (line.empty() || line == "\r") { return true; } - - // Lines starting with ':' are comments (ignored) - if (!line.empty() && line[0] == ':') { return false; } - - // Find the colon separator - auto colon_pos = line.find(':'); - if (colon_pos == std::string::npos) { - // Line with no colon is treated as field name with empty value - return false; - } - - auto field = line.substr(0, colon_pos); - std::string value; - - // Value starts after colon, skip optional single space - if (colon_pos + 1 < line.size()) { - auto value_start = colon_pos + 1; - if (line[value_start] == ' ') { value_start++; } - value = line.substr(value_start); - // Remove trailing \r if present - if (!value.empty() && value.back() == '\r') { value.pop_back(); } - } - - // Handle known fields - if (field == "event") { - msg.event = value; - } else if (field == "data") { - // Multiple data lines are concatenated with newlines - if (!msg.data.empty()) { msg.data += "\n"; } - msg.data += value; - } else if (field == "id") { - // Empty id is valid (clears the last event ID) - msg.id = value; - } else if (field == "retry") { - // Parse retry interval in milliseconds - { - int v = 0; - auto res = - detail::from_chars(value.data(), value.data() + value.size(), v); - if (res.ec == std::errc{}) { retry_ms = v; } - } - } - // Unknown fields are ignored per SSE spec - - return false; - } - - // Main event loop with auto-reconnect - void run_event_loop() { - auto reconnect_count = 0; - - while (running_.load()) { - // Build headers, including Last-Event-ID if we have one - auto request_headers = headers_; - if (!last_event_id_.empty()) { - request_headers.emplace("Last-Event-ID", last_event_id_); - } - - // Open streaming connection - auto result = stream::Get(client_, path_, request_headers); - - // Connection error handling - if (!result) { - connected_.store(false); - if (on_error_) { on_error_(result.error()); } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - continue; - } - - if (result.status() != 200) { - connected_.store(false); - // For certain errors, don't reconnect - if (result.status() == 204 || // No Content - server wants us to stop - result.status() == 404 || // Not Found - result.status() == 401 || // Unauthorized - result.status() == 403) { // Forbidden - if (on_error_) { on_error_(Error::Connection); } - break; - } - - if (on_error_) { on_error_(Error::Connection); } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - continue; - } - - // Connection successful - connected_.store(true); - reconnect_count = 0; - if (on_open_) { on_open_(); } - - // Event receiving loop - std::string buffer; - SSEMessage current_msg; - - while (running_.load() && result.next()) { - buffer.append(result.data(), result.size()); - - // Process complete lines in the buffer - size_t line_start = 0; - size_t newline_pos; - - while ((newline_pos = buffer.find('\n', line_start)) != - std::string::npos) { - auto line = buffer.substr(line_start, newline_pos - line_start); - line_start = newline_pos + 1; - - // Parse the line and check if event is complete - auto event_complete = - parse_sse_line(line, current_msg, reconnect_interval_ms_); - - if (event_complete && !current_msg.data.empty()) { - // Update last_event_id for reconnection - if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } - - // Dispatch event to appropriate handler - dispatch_event(current_msg); - - current_msg.clear(); - } - } - - // Keep unprocessed data in buffer - buffer.erase(0, line_start); - } - - // Connection ended - connected_.store(false); - - if (!running_.load()) { break; } - - // Check for read errors - if (result.has_read_error()) { - if (on_error_) { on_error_(result.read_error()); } - } - - if (!should_reconnect(reconnect_count)) { break; } - wait_for_reconnect(); - reconnect_count++; - } - - connected_.store(false); - } - - // Dispatch event to appropriate handler - void dispatch_event(const SSEMessage &msg) { - // Check for specific event type handler first - auto it = event_handlers_.find(msg.event); - if (it != event_handlers_.end()) { - it->second(msg); - return; - } - - // Fall back to generic message handler - if (on_message_) { on_message_(msg); } - } - - // Check if we should attempt to reconnect - bool should_reconnect(int count) const { - if (!running_.load()) { return false; } - if (max_reconnect_attempts_ == 0) { return true; } // unlimited - return count < max_reconnect_attempts_; - } - - // Wait for reconnect interval - void wait_for_reconnect() { - // Use small increments to check running_ flag frequently - auto waited = 0; - while (running_.load() && waited < reconnect_interval_ms_) { - std::this_thread::sleep_for(std::chrono::milliseconds(100)); - waited += 100; - } - } + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms); + void run_event_loop(); + void dispatch_event(const SSEMessage &msg); + bool should_reconnect(int count) const; + void wait_for_reconnect(); // Client and path Client &client_;