diff --git a/common/arg.h b/common/arg.h index a1b6a14e67..55782a158d 100644 --- a/common/arg.h +++ b/common/arg.h @@ -129,11 +129,3 @@ void common_params_add_preset_options(std::vector & args); // initialize argument parser context - used by test-arg-parser and preset common_params_context common_params_parser_init(common_params & params, llama_example ex, void(*print_usage)(int, char **) = nullptr); - -struct common_remote_params { - std::vector headers; - long timeout = 0; // CURLOPT_TIMEOUT, in seconds ; 0 means no timeout - long max_size = 0; // max size of the response ; unlimited if 0 ; max is 2GB -}; -// get remote file content, returns -std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); diff --git a/common/download.cpp b/common/download.cpp index ef87472560..6f56b5518f 100644 --- a/common/download.cpp +++ b/common/download.cpp @@ -308,7 +308,8 @@ static bool common_download_head(CURL * curl, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; for (int i = 0; i < max_attempts; ++i) { @@ -330,6 +331,11 @@ static bool common_download_file_single_online(const std::string & url, common_load_model_from_url_headers headers; curl_easy_setopt(curl.get(), CURLOPT_HEADERDATA, &headers); curl_slist_ptr http_headers; + + for (const auto & h : custom_headers) { + std::string s = h.first + ": " + h.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, s.c_str()); + } const bool was_perform_successful = common_download_head(curl.get(), http_headers, url, bearer_token); if (!was_perform_successful) { head_request_ok = false; @@ -454,8 +460,10 @@ std::pair> common_remote_get_content(const std::string & curl_easy_setopt(curl.get(), CURLOPT_MAXFILESIZE, params.max_size); } http_headers.ptr = curl_slist_append(http_headers.ptr, "User-Agent: llama-cpp"); + for (const auto & header : params.headers) { - http_headers.ptr = curl_slist_append(http_headers.ptr, header.c_str()); + std::string header_ = header.first + ": " + header.second; + http_headers.ptr = curl_slist_append(http_headers.ptr, header_.c_str()); } curl_easy_setopt(curl.get(), CURLOPT_HTTPHEADER, http_headers.ptr); @@ -619,7 +627,8 @@ static bool common_pull_file(httplib::Client & cli, // download one single file from remote URL to local path static bool common_download_file_single_online(const std::string & url, const std::string & path, - const std::string & bearer_token) { + const std::string & bearer_token, + const common_header_list & custom_headers) { static const int max_attempts = 3; static const int retry_delay_seconds = 2; @@ -629,6 +638,9 @@ static bool common_download_file_single_online(const std::string & url, if (!bearer_token.empty()) { default_headers.insert({"Authorization", "Bearer " + bearer_token}); } + for (const auto & h : custom_headers) { + default_headers.emplace(h.first, h.second); + } cli.set_default_headers(default_headers); const bool file_exists = std::filesystem::exists(path); @@ -734,13 +746,9 @@ std::pair> common_remote_get_content(const std::string auto [cli, parts] = common_http_client(url); httplib::Headers headers = {{"User-Agent", "llama-cpp"}}; + for (const auto & header : params.headers) { - size_t pos = header.find(':'); - if (pos != std::string::npos) { - headers.emplace(header.substr(0, pos), header.substr(pos + 1)); - } else { - headers.emplace(header, ""); - } + headers.emplace(header.first, header.second); } if (params.timeout > 0) { @@ -772,9 +780,10 @@ std::pair> common_remote_get_content(const std::string static bool common_download_file_single(const std::string & url, const std::string & path, const std::string & bearer_token, - bool offline) { + bool offline, + const common_header_list & headers) { if (!offline) { - return common_download_file_single_online(url, path, bearer_token); + return common_download_file_single_online(url, path, bearer_token, headers); } if (!std::filesystem::exists(path)) { @@ -788,13 +797,24 @@ static bool common_download_file_single(const std::string & url, // download multiple files from remote URLs to local paths // the input is a vector of pairs -static bool common_download_file_multiple(const std::vector> & urls, const std::string & bearer_token, bool offline) { +static bool common_download_file_multiple(const std::vector> & urls, + const std::string & bearer_token, + bool offline, + const common_header_list & headers) { // Prepare download in parallel std::vector> futures_download; + futures_download.reserve(urls.size()); + for (auto const & item : urls) { - futures_download.push_back(std::async(std::launch::async, [bearer_token, offline](const std::pair & it) -> bool { - return common_download_file_single(it.first, it.second, bearer_token, offline); - }, item)); + futures_download.push_back( + std::async( + std::launch::async, + [&bearer_token, offline, &headers](const std::pair & it) -> bool { + return common_download_file_single(it.first, it.second, bearer_token, offline, headers); + }, + item + ) + ); } // Wait for all downloads to complete @@ -807,17 +827,17 @@ static bool common_download_file_multiple(const std::vector(hf_repo_with_tag, ':'); std::string tag = parts.size() > 1 ? parts.back() : "latest"; std::string hf_repo = parts[0]; @@ -893,10 +916,10 @@ common_hf_file_res common_get_hf_file(const std::string & hf_repo_with_tag, cons std::string url = get_model_endpoint() + "v2/" + hf_repo + "/manifests/" + tag; // headers - std::vector headers; - headers.push_back("Accept: application/json"); + common_header_list headers = custom_headers; + headers.push_back({"Accept", "application/json"}); if (!bearer_token.empty()) { - headers.push_back("Authorization: Bearer " + bearer_token); + headers.push_back({"Authorization", "Bearer " + bearer_token}); } // Important: the User-Agent must be "llama-cpp" to get the "ggufFile" field in the response // User-Agent header is already set in common_remote_get_content, no need to set it here @@ -1031,9 +1054,10 @@ std::string common_docker_resolve_model(const std::string & docker) { const std::string url_prefix = "https://registry-1.docker.io/v2/" + repo; std::string manifest_url = url_prefix + "/manifests/" + tag; common_remote_params manifest_params; - manifest_params.headers.push_back("Authorization: Bearer " + token); - manifest_params.headers.push_back( - "Accept: application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json"); + manifest_params.headers.push_back({"Authorization", "Bearer " + token}); + manifest_params.headers.push_back({"Accept", + "application/vnd.docker.distribution.manifest.v2+json,application/vnd.oci.image.manifest.v1+json" + }); auto manifest_res = common_remote_get_content(manifest_url, manifest_params); if (manifest_res.first != 200) { throw std::runtime_error("Failed to get Docker manifest, HTTP code: " + std::to_string(manifest_res.first)); @@ -1070,7 +1094,7 @@ std::string common_docker_resolve_model(const std::string & docker) { std::string local_path = fs_get_cache_file(model_filename); const std::string blob_url = url_prefix + "/blobs/" + gguf_digest; - if (!common_download_file_single(blob_url, local_path, token, false)) { + if (!common_download_file_single(blob_url, local_path, token, false, {})) { throw std::runtime_error("Failed to download Docker Model"); } @@ -1084,11 +1108,11 @@ std::string common_docker_resolve_model(const std::string & docker) { #else -common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool) { +common_hf_file_res common_get_hf_file(const std::string &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } -bool common_download_model(const common_params_model &, const std::string &, bool) { +bool common_download_model(const common_params_model &, const std::string &, bool, const common_header_list &) { throw std::runtime_error("download functionality is not enabled in this build"); } diff --git a/common/download.h b/common/download.h index d1321e6e90..9ea2093939 100644 --- a/common/download.h +++ b/common/download.h @@ -1,12 +1,21 @@ #pragma once #include +#include struct common_params_model; -// -// download functionalities -// +using common_header = std::pair; +using common_header_list = std::vector; + +struct common_remote_params { + common_header_list headers; + long timeout = 0; // in seconds, 0 means no timeout + long max_size = 0; // unlimited if 0 +}; + +// get remote file content, returns +std::pair> common_remote_get_content(const std::string & url, const common_remote_params & params); struct common_cached_model_info { std::string manifest_path; @@ -41,13 +50,17 @@ struct common_hf_file_res { common_hf_file_res common_get_hf_file( const std::string & hf_repo_with_tag, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns true if download succeeded bool common_download_model( const common_params_model & model, const std::string & bearer_token, - bool offline); + bool offline, + const common_header_list & headers = {} +); // returns list of cached models std::vector common_list_cached_models(); diff --git a/scripts/sync_vendor.py b/scripts/sync_vendor.py index 637f4cdc18..ed6bf1bf4e 100755 --- a/scripts/sync_vendor.py +++ b/scripts/sync_vendor.py @@ -16,7 +16,7 @@ vendor = { # "https://github.com/mackron/miniaudio/raw/refs/tags/0.11.23/miniaudio.h": "vendor/miniaudio/miniaudio.h", "https://github.com/mackron/miniaudio/raw/669ed3e844524fcd883231b13095baee9f6de304/miniaudio.h": "vendor/miniaudio/miniaudio.h", - "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.28.0/httplib.h": "vendor/cpp-httplib/httplib.h", + "https://raw.githubusercontent.com/yhirose/cpp-httplib/refs/tags/v0.30.0/httplib.h": "vendor/cpp-httplib/httplib.h", "https://raw.githubusercontent.com/sheredom/subprocess.h/b49c56e9fe214488493021017bf3954b91c7c1f5/subprocess.h": "vendor/sheredom/subprocess.h", } diff --git a/tests/test-arg-parser.cpp b/tests/test-arg-parser.cpp index e995974a2e..c7be0021be 100644 --- a/tests/test-arg-parser.cpp +++ b/tests/test-arg-parser.cpp @@ -1,5 +1,6 @@ #include "arg.h" #include "common.h" +#include "download.h" #include #include diff --git a/tools/server/server-common.cpp b/tools/server/server-common.cpp index e4a0be44cc..16b0db2983 100644 --- a/tools/server/server-common.cpp +++ b/tools/server/server-common.cpp @@ -1,10 +1,10 @@ #include "common.h" +#include "download.h" #include "log.h" #include "llama.h" #include "mtmd.h" #include "mtmd-helper.h" #include "chat.h" -#include "arg.h" // for common_remote_get_content; TODO: use download.h only #include "base64.hpp" #include "server-common.h" @@ -779,7 +779,7 @@ static void handle_media( // download remote image // TODO @ngxson : maybe make these params configurable common_remote_params params; - params.headers.push_back("User-Agent: llama.cpp/" + build_info); + params.headers.push_back({"User-Agent", "llama.cpp/" + build_info}); params.max_size = 1024 * 1024 * 10; // 10MB params.timeout = 10; // seconds SRV_INF("downloading image from '%s'\n", url.c_str()); diff --git a/vendor/cpp-httplib/httplib.cpp b/vendor/cpp-httplib/httplib.cpp index b86e6a2310..a437a36ed7 100644 --- a/vendor/cpp-httplib/httplib.cpp +++ b/vendor/cpp-httplib/httplib.cpp @@ -9,7 +9,7 @@ namespace httplib { namespace detail { bool is_hex(char c, int &v) { - if (0x20 <= c && isdigit(c)) { + if (isdigit(c)) { v = c - '0'; return true; } else if ('A' <= c && c <= 'F') { @@ -49,6 +49,90 @@ std::string from_i_to_hex(size_t n) { return ret; } +std::string compute_etag(const FileStat &fs) { + if (!fs.is_file()) { return std::string(); } + + // If mtime cannot be determined (negative value indicates an error + // or sentinel), do not generate an ETag. Returning a neutral / fixed + // value like 0 could collide with a real file that legitimately has + // mtime == 0 (epoch) and lead to misleading validators. + auto mtime_raw = fs.mtime(); + if (mtime_raw < 0) { return std::string(); } + + auto mtime = static_cast(mtime_raw); + auto size = fs.size(); + + return std::string("W/\"") + from_i_to_hex(mtime) + "-" + + from_i_to_hex(size) + "\""; +} + +// Format time_t as HTTP-date (RFC 9110 Section 5.6.7): "Sun, 06 Nov 1994 +// 08:49:37 GMT" This implementation is defensive: it validates `mtime`, checks +// return values from `gmtime_r`/`gmtime_s`, and ensures `strftime` succeeds. +std::string file_mtime_to_http_date(time_t mtime) { + if (mtime < 0) { return std::string(); } + + struct tm tm_buf; +#ifdef _WIN32 + if (gmtime_s(&tm_buf, &mtime) != 0) { return std::string(); } +#else + if (gmtime_r(&mtime, &tm_buf) == nullptr) { return std::string(); } +#endif + char buf[64]; + if (strftime(buf, sizeof(buf), "%a, %d %b %Y %H:%M:%S GMT", &tm_buf) == 0) { + return std::string(); + } + + return std::string(buf); +} + +// Parse HTTP-date (RFC 9110 Section 5.6.7) to time_t. Returns -1 on failure. +time_t parse_http_date(const std::string &date_str) { + struct tm tm_buf; + + // Create a classic locale object once for all parsing attempts + const std::locale classic_locale = std::locale::classic(); + + // Try to parse using std::get_time (C++11, cross-platform) + auto try_parse = [&](const char *fmt) -> bool { + std::istringstream ss(date_str); + ss.imbue(classic_locale); + + memset(&tm_buf, 0, sizeof(tm_buf)); + ss >> std::get_time(&tm_buf, fmt); + + return !ss.fail(); + }; + + // RFC 9110 preferred format (HTTP-date): "Sun, 06 Nov 1994 08:49:37 GMT" + if (!try_parse("%a, %d %b %Y %H:%M:%S")) { + // RFC 850 format: "Sunday, 06-Nov-94 08:49:37 GMT" + if (!try_parse("%A, %d-%b-%y %H:%M:%S")) { + // asctime format: "Sun Nov 6 08:49:37 1994" + if (!try_parse("%a %b %d %H:%M:%S %Y")) { + return static_cast(-1); + } + } + } + +#ifdef _WIN32 + return _mkgmtime(&tm_buf); +#else + return timegm(&tm_buf); +#endif +} + +bool is_weak_etag(const std::string &s) { + // Check if the string is a weak ETag (starts with 'W/"') + return s.size() > 3 && s[0] == 'W' && s[1] == '/' && s[2] == '"'; +} + +bool is_strong_etag(const std::string &s) { + // Check if the string is a strong ETag (starts and ends with '"', at least 2 + // chars) + return s.size() >= 2 && s[0] == '"' && s.back() == '"'; +} + size_t to_utf8(int code, char *buff) { if (code < 0x0080) { buff[0] = static_cast(code & 0x7F); @@ -168,6 +252,15 @@ bool FileStat::is_dir() const { return ret_ >= 0 && S_ISDIR(st_.st_mode); } +time_t FileStat::mtime() const { + return ret_ >= 0 ? static_cast(st_.st_mtime) + : static_cast(-1); +} + +size_t FileStat::size() const { + return ret_ >= 0 ? static_cast(st_.st_size) : 0; +} + std::string encode_path(const std::string &s) { std::string result; result.reserve(s.size()); @@ -209,6 +302,149 @@ std::string file_extension(const std::string &path) { bool is_space_or_tab(char c) { return c == ' ' || c == '\t'; } +template +bool parse_header(const char *beg, const char *end, T fn); + +template +bool parse_header(const char *beg, const char *end, T fn) { + // Skip trailing spaces and tabs. + while (beg < end && is_space_or_tab(end[-1])) { + end--; + } + + auto p = beg; + while (p < end && *p != ':') { + p++; + } + + auto name = std::string(beg, p); + if (!detail::fields::is_field_name(name)) { return false; } + + if (p == end) { return false; } + + auto key_end = p; + + if (*p++ != ':') { return false; } + + while (p < end && is_space_or_tab(*p)) { + p++; + } + + if (p <= end) { + auto key_len = key_end - beg; + if (!key_len) { return false; } + + auto key = std::string(beg, key_end); + auto val = std::string(p, end); + + if (!detail::fields::is_field_value(val)) { return false; } + + if (case_ignore::equal(key, "Location") || + case_ignore::equal(key, "Referer")) { + fn(key, val); + } else { + fn(key, decode_path_component(val)); + } + + return true; + } + + return false; +} + +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers) { + // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked + // transfer coding is complete when a chunk with a chunk-size of zero is + // received, possibly followed by a trailer section, and finally terminated by + // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 + // + // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section + // doesn't care for the existence of the final CRLF. In other words, it seems + // to be ok whether the final CRLF exists or not in the chunked data. + // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 + // + // According to the reference code in RFC 9112, cpp-httplib now allows + // chunked transfer coding data without the final CRLF. + + // RFC 7230 Section 4.1.2 - Headers prohibited in trailers + thread_local case_ignore::unordered_set prohibited_trailers = { + "transfer-encoding", + "content-length", + "host", + "authorization", + "www-authenticate", + "proxy-authenticate", + "proxy-authorization", + "cookie", + "set-cookie", + "cache-control", + "expect", + "max-forwards", + "pragma", + "range", + "te", + "age", + "expires", + "date", + "location", + "retry-after", + "vary", + "warning", + "content-encoding", + "content-type", + "content-range", + "trailer"}; + + case_ignore::unordered_set declared_trailers; + auto trailer_header = get_header_value(src_headers, "Trailer", "", 0); + if (trailer_header && std::strlen(trailer_header)) { + auto len = std::strlen(trailer_header); + split(trailer_header, trailer_header + len, ',', + [&](const char *b, const char *e) { + const char *kbeg = b; + const char *kend = e; + while (kbeg < kend && (*kbeg == ' ' || *kbeg == '\t')) { + ++kbeg; + } + while (kend > kbeg && (kend[-1] == ' ' || kend[-1] == '\t')) { + --kend; + } + std::string key(kbeg, static_cast(kend - kbeg)); + if (!key.empty() && + prohibited_trailers.find(key) == prohibited_trailers.end()) { + declared_trailers.insert(key); + } + }); + } + + size_t trailer_header_count = 0; + while (strcmp(line_reader.ptr(), "\r\n") != 0) { + if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { return false; } + if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { return false; } + + constexpr auto line_terminator_len = 2; + auto line_beg = line_reader.ptr(); + auto line_end = + line_reader.ptr() + line_reader.size() - line_terminator_len; + + if (!parse_header(line_beg, line_end, + [&](const std::string &key, const std::string &val) { + if (declared_trailers.find(key) != + declared_trailers.end()) { + dest.emplace(key, val); + trailer_header_count++; + } + })) { + return false; + } + + if (!line_reader.getline()) { return false; } + } + + return true; +} + std::pair trim(const char *b, const char *e, size_t left, size_t right) { while (b + left < e && is_space_or_tab(b[left])) { @@ -280,6 +516,42 @@ void split(const char *b, const char *e, char d, size_t m, } } +bool split_find(const char *b, const char *e, char d, size_t m, + std::function fn) { + size_t i = 0; + size_t beg = 0; + size_t count = 1; + + while (e ? (b + i < e) : (b[i] != '\0')) { + if (b[i] == d && count < m) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + beg = i + 1; + count++; + } + i++; + } + + if (i) { + auto r = trim(b, e, beg, i); + if (r.first < r.second) { + auto found = fn(&b[r.first], &b[r.second]); + if (found) { return true; } + } + } + + return false; +} + +bool split_find(const char *b, const char *e, char d, + std::function fn) { + return split_find(b, e, d, (std::numeric_limits::max)(), + std::move(fn)); +} + stream_line_reader::stream_line_reader(Stream &strm, char *fixed_buffer, size_t fixed_buffer_size) : strm_(strm), fixed_buffer_(fixed_buffer), @@ -1892,6 +2164,27 @@ bool zstd_decompressor::decompress(const char *data, size_t data_length, } #endif +std::unique_ptr +create_decompressor(const std::string &encoding) { + std::unique_ptr decompressor; + + if (encoding == "gzip" || encoding == "deflate") { +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding.find("br") != std::string::npos) { +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + decompressor = detail::make_unique(); +#endif + } else if (encoding == "zstd" || encoding.find("zstd") != std::string::npos) { +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + decompressor = detail::make_unique(); +#endif + } + + return decompressor; +} + bool is_prohibited_header_name(const std::string &name) { using udl::operator""_t; @@ -1928,53 +2221,6 @@ const char *get_header_value(const Headers &headers, return def; } -template -bool parse_header(const char *beg, const char *end, T fn) { - // Skip trailing spaces and tabs. - while (beg < end && is_space_or_tab(end[-1])) { - end--; - } - - auto p = beg; - while (p < end && *p != ':') { - p++; - } - - auto name = std::string(beg, p); - if (!detail::fields::is_field_name(name)) { return false; } - - if (p == end) { return false; } - - auto key_end = p; - - if (*p++ != ':') { return false; } - - while (p < end && is_space_or_tab(*p)) { - p++; - } - - if (p <= end) { - auto key_len = key_end - beg; - if (!key_len) { return false; } - - auto key = std::string(beg, key_end); - auto val = std::string(p, end); - - if (!detail::fields::is_field_value(val)) { return false; } - - if (case_ignore::equal(key, "Location") || - case_ignore::equal(key, "Referer")) { - fn(key, val); - } else { - fn(key, decode_path_component(val)); - } - - return true; - } - - return false; -} - bool read_headers(Stream &strm, Headers &headers) { const auto bufsiz = 2048; char buf[bufsiz]; @@ -2026,10 +2272,18 @@ bool read_content_with_length(Stream &strm, size_t len, ContentReceiverWithProgress out) { char buf[CPPHTTPLIB_RECV_BUFSIZ]; + detail::BodyReader br; + br.stream = &strm; + br.content_length = len; + br.chunked = false; + br.bytes_read = 0; + br.last_error = Error::Success; + size_t r = 0; while (r < len) { auto read_len = static_cast(len - r); - auto n = strm.read(buf, (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ)); + auto to_read = (std::min)(read_len, CPPHTTPLIB_RECV_BUFSIZ); + auto n = detail::read_body_content(&strm, br, buf, to_read); if (n <= 0) { return false; } if (!out(buf, static_cast(n), r, len)) { return false; } @@ -2089,125 +2343,35 @@ template ReadContentResult read_content_chunked(Stream &strm, T &x, size_t payload_max_length, ContentReceiverWithProgress out) { - const auto bufsiz = 16; - char buf[bufsiz]; + detail::ChunkedDecoder dec(strm); - stream_line_reader line_reader(strm, buf, bufsiz); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - - unsigned long chunk_len; + char buf[CPPHTTPLIB_RECV_BUFSIZ]; size_t total_len = 0; - while (true) { - char *end_ptr; - chunk_len = std::strtoul(line_reader.ptr(), &end_ptr, 16); + for (;;) { + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = dec.read_payload(buf, sizeof(buf), chunk_offset, chunk_total); + if (n < 0) { return ReadContentResult::Error; } - if (end_ptr == line_reader.ptr()) { return ReadContentResult::Error; } - if (chunk_len == ULONG_MAX) { return ReadContentResult::Error; } + if (n == 0) { + if (!dec.parse_trailers_into(x.trailers, x.headers)) { + return ReadContentResult::Error; + } + return ReadContentResult::Success; + } - if (chunk_len == 0) { break; } - - // Check if adding this chunk would exceed the payload limit if (total_len > payload_max_length || - payload_max_length - total_len < chunk_len) { + payload_max_length - total_len < static_cast(n)) { return ReadContentResult::PayloadTooLarge; } - total_len += chunk_len; - - if (!read_content_with_length(strm, chunk_len, nullptr, out)) { + if (!out(buf, static_cast(n), chunk_offset, chunk_total)) { return ReadContentResult::Error; } - if (!line_reader.getline()) { return ReadContentResult::Error; } - - if (strcmp(line_reader.ptr(), "\r\n") != 0) { - return ReadContentResult::Error; - } - - if (!line_reader.getline()) { return ReadContentResult::Error; } + total_len += static_cast(n); } - - assert(chunk_len == 0); - - // NOTE: In RFC 9112, '7.1 Chunked Transfer Coding' mentions "The chunked - // transfer coding is complete when a chunk with a chunk-size of zero is - // received, possibly followed by a trailer section, and finally terminated by - // an empty line". https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1 - // - // In '7.1.3. Decoding Chunked', however, the pseudo-code in the section - // does't care for the existence of the final CRLF. In other words, it seems - // to be ok whether the final CRLF exists or not in the chunked data. - // https://www.rfc-editor.org/rfc/rfc9112.html#section-7.1.3 - // - // According to the reference code in RFC 9112, cpp-httplib now allows - // chunked transfer coding data without the final CRLF. - if (!line_reader.getline()) { return ReadContentResult::Success; } - - // RFC 7230 Section 4.1.2 - Headers prohibited in trailers - thread_local case_ignore::unordered_set prohibited_trailers = { - // Message framing - "transfer-encoding", "content-length", - - // Routing - "host", - - // Authentication - "authorization", "www-authenticate", "proxy-authenticate", - "proxy-authorization", "cookie", "set-cookie", - - // Request modifiers - "cache-control", "expect", "max-forwards", "pragma", "range", "te", - - // Response control - "age", "expires", "date", "location", "retry-after", "vary", "warning", - - // Payload processing - "content-encoding", "content-type", "content-range", "trailer"}; - - // Parse declared trailer headers once for performance - case_ignore::unordered_set declared_trailers; - if (has_header(x.headers, "Trailer")) { - auto trailer_header = get_header_value(x.headers, "Trailer", "", 0); - auto len = std::strlen(trailer_header); - - split(trailer_header, trailer_header + len, ',', - [&](const char *b, const char *e) { - std::string key(b, e); - if (prohibited_trailers.find(key) == prohibited_trailers.end()) { - declared_trailers.insert(key); - } - }); - } - - size_t trailer_header_count = 0; - while (strcmp(line_reader.ptr(), "\r\n") != 0) { - if (line_reader.size() > CPPHTTPLIB_HEADER_MAX_LENGTH) { - return ReadContentResult::Error; - } - - // Check trailer header count limit - if (trailer_header_count >= CPPHTTPLIB_HEADER_MAX_COUNT) { - return ReadContentResult::Error; - } - - // Exclude line terminator - constexpr auto line_terminator_len = 2; - auto end = line_reader.ptr() + line_reader.size() - line_terminator_len; - - parse_header(line_reader.ptr(), end, - [&](const std::string &key, const std::string &val) { - if (declared_trailers.find(key) != declared_trailers.end()) { - x.trailers.emplace(key, val); - trailer_header_count++; - } - }); - - if (!line_reader.getline()) { return ReadContentResult::Error; } - } - - return ReadContentResult::Success; } bool is_chunked_transfer_encoding(const Headers &headers) { @@ -2223,27 +2387,13 @@ bool prepare_content_receiver(T &x, int &status, std::string encoding = x.get_header_value("Content-Encoding"); std::unique_ptr decompressor; - if (encoding == "gzip" || encoding == "deflate") { -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding.find("br") != std::string::npos) { -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif - } else if (encoding == "zstd") { -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - decompressor = detail::make_unique(); -#else - status = StatusCode::UnsupportedMediaType_415; - return false; -#endif + if (!encoding.empty()) { + decompressor = detail::create_decompressor(encoding); + if (!decompressor) { + // Unsupported encoding or no support compiled in + status = StatusCode::UnsupportedMediaType_415; + return false; + } } if (decompressor) { @@ -2329,7 +2479,7 @@ bool read_content(Stream &strm, T &x, size_t payload_max_length, int &status, ssize_t write_request_line(Stream &strm, const std::string &method, const std::string &path) { std::string s = method; - s += " "; + s += ' '; s += path; s += " HTTP/1.1\r\n"; return strm.write(s.data(), s.size()); @@ -2338,7 +2488,7 @@ ssize_t write_request_line(Stream &strm, const std::string &method, ssize_t write_response_line(Stream &strm, int status) { std::string s = "HTTP/1.1 "; s += std::to_string(status); - s += " "; + s += ' '; s += httplib::status_message(status); s += "\r\n"; return strm.write(s.data(), s.size()); @@ -2601,8 +2751,8 @@ bool redirect(T &cli, Request &req, Response &res, auto ret = cli.send(new_req, new_res, error); if (ret) { - req = new_req; - res = new_res; + req = std::move(new_req); + res = std::move(new_res); if (res.location.empty()) { res.location = location; } } @@ -2613,9 +2763,9 @@ std::string params_to_query_str(const Params ¶ms) { std::string query; for (auto it = params.begin(); it != params.end(); ++it) { - if (it != params.begin()) { query += "&"; } + if (it != params.begin()) { query += '&'; } query += encode_query_component(it->first); - query += "="; + query += '='; query += encode_query_component(it->second); } return query; @@ -2648,6 +2798,38 @@ void parse_query_text(const std::string &s, Params ¶ms) { parse_query_text(s.data(), s.size(), params); } +// Normalize a query string by decoding and re-encoding each key/value pair +// while preserving the original parameter order. This avoids double-encoding +// and ensures consistent encoding without reordering (unlike Params which +// uses std::multimap and sorts keys). +std::string normalize_query_string(const std::string &query) { + std::string result; + split(query.data(), query.data() + query.size(), '&', + [&](const char *b, const char *e) { + std::string key; + std::string val; + divide(b, static_cast(e - b), '=', + [&](const char *lhs_data, std::size_t lhs_size, + const char *rhs_data, std::size_t rhs_size) { + key.assign(lhs_data, lhs_size); + val.assign(rhs_data, rhs_size); + }); + + if (!key.empty()) { + auto dec_key = decode_query_component(key); + auto dec_val = decode_query_component(val); + + if (!result.empty()) { result += '&'; } + result += encode_query_component(dec_key); + if (!val.empty() || std::find(b, e, '=') != e) { + result += '='; + result += encode_query_component(dec_val); + } + } + }); + return result; +} + bool parse_multipart_boundary(const std::string &content_type, std::string &boundary) { auto boundary_keyword = "boundary="; @@ -2840,7 +3022,7 @@ bool parse_accept_header(const std::string &s, return; } - entries.push_back(accept_entry); + entries.push_back(std::move(accept_entry)); }); // Return false if any invalid entry was found @@ -2857,8 +3039,8 @@ bool parse_accept_header(const std::string &s, // Extract sorted media types content_types.reserve(entries.size()); - for (const auto &entry : entries) { - content_types.push_back(entry.media_type); + for (auto &entry : entries) { + content_types.push_back(std::move(entry.media_type)); } return true; @@ -2869,7 +3051,7 @@ public: FormDataParser() = default; void set_boundary(std::string &&boundary) { - boundary_ = boundary; + boundary_ = std::move(boundary); dash_boundary_crlf_ = dash_ + boundary_ + crlf_; crlf_dash_boundary_ = crlf_ + dash_ + boundary_; } @@ -3342,9 +3524,9 @@ std::string make_content_range_header_field( std::string field = "bytes "; field += std::to_string(st); - field += "-"; + field += '-'; field += std::to_string(ed); - field += "/"; + field += '/'; field += std::to_string(content_length); return field; } @@ -3721,7 +3903,7 @@ bool parse_www_authenticate(const Response &res, static_cast(m.length(2))) : s.substr(static_cast(m.position(3)), static_cast(m.length(3))); - auth[key] = val; + auth[std::move(key)] = std::move(val); } return true; } @@ -3734,7 +3916,7 @@ class ContentProviderAdapter { public: explicit ContentProviderAdapter( ContentProviderWithoutLength &&content_provider) - : content_provider_(content_provider) {} + : content_provider_(std::move(content_provider)) {} bool operator()(size_t offset, size_t, DataSink &sink) { return content_provider_(offset, sink); @@ -3744,8 +3926,189 @@ private: ContentProviderWithoutLength content_provider_; }; +// NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 +namespace fields { + +bool is_token_char(char c) { + return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || + c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || + c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; +} + +bool is_token(const std::string &s) { + if (s.empty()) { return false; } + for (auto c : s) { + if (!is_token_char(c)) { return false; } + } + return true; +} + +bool is_field_name(const std::string &s) { return is_token(s); } + +bool is_vchar(char c) { return c >= 33 && c <= 126; } + +bool is_obs_text(char c) { return 128 <= static_cast(c); } + +bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } + +bool is_field_content(const std::string &s) { + if (s.empty()) { return true; } + + if (s.size() == 1) { + return is_field_vchar(s[0]); + } else if (s.size() == 2) { + return is_field_vchar(s[0]) && is_field_vchar(s[1]); + } else { + size_t i = 0; + + if (!is_field_vchar(s[i])) { return false; } + i++; + + while (i < s.size() - 1) { + auto c = s[i++]; + if (c == ' ' || c == '\t' || is_field_vchar(c)) { + } else { + return false; + } + } + + return is_field_vchar(s[i]); + } +} + +bool is_field_value(const std::string &s) { return is_field_content(s); } + +} // namespace fields + } // namespace detail +const char *status_message(int status) { + switch (status) { + case StatusCode::Continue_100: return "Continue"; + case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; + case StatusCode::Processing_102: return "Processing"; + case StatusCode::EarlyHints_103: return "Early Hints"; + case StatusCode::OK_200: return "OK"; + case StatusCode::Created_201: return "Created"; + case StatusCode::Accepted_202: return "Accepted"; + case StatusCode::NonAuthoritativeInformation_203: + return "Non-Authoritative Information"; + case StatusCode::NoContent_204: return "No Content"; + case StatusCode::ResetContent_205: return "Reset Content"; + case StatusCode::PartialContent_206: return "Partial Content"; + case StatusCode::MultiStatus_207: return "Multi-Status"; + case StatusCode::AlreadyReported_208: return "Already Reported"; + case StatusCode::IMUsed_226: return "IM Used"; + case StatusCode::MultipleChoices_300: return "Multiple Choices"; + case StatusCode::MovedPermanently_301: return "Moved Permanently"; + case StatusCode::Found_302: return "Found"; + case StatusCode::SeeOther_303: return "See Other"; + case StatusCode::NotModified_304: return "Not Modified"; + case StatusCode::UseProxy_305: return "Use Proxy"; + case StatusCode::unused_306: return "unused"; + case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; + case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; + case StatusCode::BadRequest_400: return "Bad Request"; + case StatusCode::Unauthorized_401: return "Unauthorized"; + case StatusCode::PaymentRequired_402: return "Payment Required"; + case StatusCode::Forbidden_403: return "Forbidden"; + case StatusCode::NotFound_404: return "Not Found"; + case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; + case StatusCode::NotAcceptable_406: return "Not Acceptable"; + case StatusCode::ProxyAuthenticationRequired_407: + return "Proxy Authentication Required"; + case StatusCode::RequestTimeout_408: return "Request Timeout"; + case StatusCode::Conflict_409: return "Conflict"; + case StatusCode::Gone_410: return "Gone"; + case StatusCode::LengthRequired_411: return "Length Required"; + case StatusCode::PreconditionFailed_412: return "Precondition Failed"; + case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; + case StatusCode::UriTooLong_414: return "URI Too Long"; + case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; + case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; + case StatusCode::ExpectationFailed_417: return "Expectation Failed"; + case StatusCode::ImATeapot_418: return "I'm a teapot"; + case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; + case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; + case StatusCode::Locked_423: return "Locked"; + case StatusCode::FailedDependency_424: return "Failed Dependency"; + case StatusCode::TooEarly_425: return "Too Early"; + case StatusCode::UpgradeRequired_426: return "Upgrade Required"; + case StatusCode::PreconditionRequired_428: return "Precondition Required"; + case StatusCode::TooManyRequests_429: return "Too Many Requests"; + case StatusCode::RequestHeaderFieldsTooLarge_431: + return "Request Header Fields Too Large"; + case StatusCode::UnavailableForLegalReasons_451: + return "Unavailable For Legal Reasons"; + case StatusCode::NotImplemented_501: return "Not Implemented"; + case StatusCode::BadGateway_502: return "Bad Gateway"; + case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; + case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; + case StatusCode::HttpVersionNotSupported_505: + return "HTTP Version Not Supported"; + case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; + case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; + case StatusCode::LoopDetected_508: return "Loop Detected"; + case StatusCode::NotExtended_510: return "Not Extended"; + case StatusCode::NetworkAuthenticationRequired_511: + return "Network Authentication Required"; + + default: + case StatusCode::InternalServerError_500: return "Internal Server Error"; + } +} + +std::string to_string(const Error error) { + switch (error) { + case Error::Success: return "Success (no error)"; + case Error::Unknown: return "Unknown"; + case Error::Connection: return "Could not establish connection"; + case Error::BindIPAddress: return "Failed to bind IP address"; + case Error::Read: return "Failed to read connection"; + case Error::Write: return "Failed to write connection"; + case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; + case Error::Canceled: return "Connection handling canceled"; + case Error::SSLConnection: return "SSL connection failed"; + case Error::SSLLoadingCerts: return "SSL certificate loading failed"; + case Error::SSLServerVerification: return "SSL server verification failed"; + case Error::SSLServerHostnameVerification: + return "SSL server hostname verification failed"; + case Error::UnsupportedMultipartBoundaryChars: + return "Unsupported HTTP multipart boundary characters"; + case Error::Compression: return "Compression failed"; + case Error::ConnectionTimeout: return "Connection timed out"; + case Error::ProxyConnection: return "Proxy connection failed"; + case Error::ConnectionClosed: return "Connection closed by server"; + case Error::Timeout: return "Read timeout"; + case Error::ResourceExhaustion: return "Resource exhaustion"; + case Error::TooManyFormDataFiles: return "Too many form data files"; + case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; + case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; + case Error::ExceedMaxSocketDescriptorCount: + return "Exceeded maximum socket descriptor count"; + case Error::InvalidRequestLine: return "Invalid request line"; + case Error::InvalidHTTPMethod: return "Invalid HTTP method"; + case Error::InvalidHTTPVersion: return "Invalid HTTP version"; + case Error::InvalidHeaders: return "Invalid headers"; + case Error::MultipartParsing: return "Multipart parsing failed"; + case Error::OpenFile: return "Failed to open file"; + case Error::Listen: return "Failed to listen on socket"; + case Error::GetSockName: return "Failed to get socket name"; + case Error::UnsupportedAddressFamily: return "Unsupported address family"; + case Error::HTTPParsing: return "HTTP parsing failed"; + case Error::InvalidRangeHeader: return "Invalid Range header"; + default: break; + } + + return "Invalid"; +} + +std::ostream &operator<<(std::ostream &os, const Error &obj) { + os << to_string(obj); + os << " (" << static_cast::type>(obj) << ')'; + return os; +} + std::string hosted_at(const std::string &hostname) { std::vector addrs; hosted_at(hostname, addrs); @@ -3779,7 +4142,7 @@ void hosted_at(const std::string &hostname, auto dummy = -1; if (detail::get_ip_and_port(addr, sizeof(struct sockaddr_storage), ip, dummy)) { - addrs.push_back(ip); + addrs.emplace_back(std::move(ip)); } } } @@ -4319,6 +4682,67 @@ ssize_t Stream::write(const std::string &s) { return write(s.data(), s.size()); } +// BodyReader implementation +ssize_t detail::BodyReader::read(char *buf, size_t len) { + if (!stream) { + last_error = Error::Connection; + return -1; + } + if (eof) { return 0; } + + if (!chunked) { + // Content-Length based reading + if (bytes_read >= content_length) { + eof = true; + return 0; + } + + auto remaining = content_length - bytes_read; + auto to_read = (std::min)(len, remaining); + auto n = stream->read(buf, to_read); + + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + if (n == 0) { + // Unexpected EOF before content_length + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return 0; + } + + bytes_read += static_cast(n); + if (bytes_read >= content_length) { eof = true; } + return n; + } + + // Chunked transfer encoding: delegate to shared decoder instance. + if (!chunked_decoder) { chunked_decoder.reset(new ChunkedDecoder(*stream)); } + + size_t chunk_offset = 0; + size_t chunk_total = 0; + auto n = chunked_decoder->read_payload(buf, len, chunk_offset, chunk_total); + if (n < 0) { + last_error = stream->get_error(); + if (last_error == Error::Success) { last_error = Error::Read; } + eof = true; + return n; + } + + if (n == 0) { + // Final chunk observed. Leave trailer parsing to the caller (StreamHandle). + eof = true; + return 0; + } + + bytes_read += static_cast(n); + return n; +} + namespace detail { void calc_actual_timeout(time_t max_timeout_msec, time_t duration_msec, @@ -4395,7 +4819,10 @@ ssize_t SocketStream::read(char *ptr, size_t size) { } } - if (!wait_readable()) { return -1; } + if (!wait_readable()) { + error_ = Error::Timeout; + return -1; + } read_buff_off_ = 0; read_buff_content_size_ = 0; @@ -4404,6 +4831,11 @@ ssize_t SocketStream::read(char *ptr, size_t size) { auto n = read_socket(sock_, read_buff_.data(), read_buff_size_, CPPHTTPLIB_RECV_FLAGS); if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } return n; } else if (n <= static_cast(size)) { memcpy(ptr, read_buff_.data(), static_cast(n)); @@ -4415,7 +4847,15 @@ ssize_t SocketStream::read(char *ptr, size_t size) { return static_cast(size); } } else { - return read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + auto n = read_socket(sock_, ptr, size, CPPHTTPLIB_RECV_FLAGS); + if (n <= 0) { + if (n == 0) { + error_ = Error::ConnectionClosed; + } else { + error_ = Error::Read; + } + } + return n; } } @@ -4579,19 +5019,22 @@ bool RegexMatcher::match(Request &request) const { return std::regex_match(request.path, request.matches, regex_); } -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl) { - std::string result; - +// Enclose IPv6 address in brackets if needed +std::string prepare_host_string(const std::string &host) { // Enclose IPv6 address in brackets (but not if already enclosed) if (host.find(':') == std::string::npos || (!host.empty() && host[0] == '[')) { // IPv4, hostname, or already bracketed IPv6 - result = host; + return host; } else { // IPv6 address without brackets - result = "[" + host + "]"; + return "[" + host + "]"; } +} + +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl) { + auto result = prepare_host_string(host); // Append port if not default if ((!is_ssl && port == 80) || (is_ssl && port == 443)) { @@ -4603,6 +5046,29 @@ std::string make_host_and_port_string(const std::string &host, int port, return result; } +// Create "host:port" string always including port number (for CONNECT method) +std::string +make_host_and_port_string_always_port(const std::string &host, int port) { + return prepare_host_string(host) + ":" + std::to_string(port); +} + +template +bool check_and_write_headers(Stream &strm, Headers &headers, + T header_writer, Error &error) { + for (const auto &h : headers) { + if (!detail::fields::is_field_name(h.first) || + !detail::fields::is_field_value(h.second)) { + error = Error::InvalidHeaders; + return false; + } + } + if (header_writer(strm, headers) <= 0) { + error = Error::Write; + return false; + } + return true; +} + } // namespace detail // HTTP server implementation @@ -4694,7 +5160,7 @@ bool Server::set_mount_point(const std::string &mount_point, if (stat.is_dir()) { std::string mnt = !mount_point.empty() ? mount_point : "/"; if (!mnt.empty() && mnt[0] == '/') { - base_dirs_.push_back({mnt, dir, std::move(headers)}); + base_dirs_.push_back({std::move(mnt), dir, std::move(headers)}); return true; } } @@ -5010,7 +5476,7 @@ bool Server::write_response_core(Stream &strm, bool close_connection, { detail::BufferStream bstrm; if (!detail::write_response_line(bstrm, res.status)) { return false; } - if (!header_writer_(bstrm, res.headers)) { return false; } + if (header_writer_(bstrm, res.headers) <= 0) { return false; } // Flush buffer auto &data = bstrm.get_buffer(); @@ -5103,7 +5569,16 @@ bool Server::read_content(Stream &strm, Request &req, Response &res) { strm, req, res, // Regular [&](const char *buf, size_t n) { - if (req.body.size() + n > req.body.max_size()) { return false; } + // Prevent arithmetic overflow when checking sizes. + // Avoid computing (req.body.size() + n) directly because + // adding two unsigned `size_t` values can wrap around and + // produce a small result instead of indicating overflow. + // Instead, check using subtraction: ensure `n` does not + // exceed the remaining capacity `max_size() - size()`. + if (req.body.size() >= req.body.max_size() || + n > req.body.max_size() - req.body.size()) { + return false; + } req.body.append(buf, n); return true; }, @@ -5186,10 +5661,39 @@ bool Server::read_content_core( // RFC 7230 Section 3.3.3: If this is a request message and none of the above // are true (no Transfer-Encoding and no Content-Length), then the message // body length is zero (no message body is present). + // + // For non-SSL builds, peek into the socket to detect clients that send a + // body without a Content-Length header (raw HTTP over TCP). If there is + // pending data that exceeds the configured payload limit, treat this as an + // oversized request and fail early (causing connection close). For SSL + // builds we cannot reliably peek the decrypted application bytes, so keep + // the original behaviour. +#if !defined(CPPHTTPLIB_OPENSSL_SUPPORT) && !defined(_WIN32) + if (!req.has_header("Content-Length") && + !detail::is_chunked_transfer_encoding(req.headers)) { + socket_t s = strm.socket(); + if (s != INVALID_SOCKET) { + // Peek up to payload_max_length_ + 1 bytes. If more than + // payload_max_length_ bytes are pending, reject the request. + size_t to_peek = + (payload_max_length_ > 0) + ? (std::min)(payload_max_length_ + 1, static_cast(4096)) + : 1; + std::vector peekbuf(to_peek); + ssize_t n = ::recv(s, peekbuf.data(), to_peek, MSG_PEEK); + if (n > 0 && static_cast(n) > payload_max_length_) { + // Indicate failure so connection will be closed. + return false; + } + } + return true; + } +#else if (!req.has_header("Content-Length") && !detail::is_chunked_transfer_encoding(req.headers)) { return true; } +#endif if (!detail::read_content(strm, req, payload_max_length_, res.status, nullptr, out, true)) { @@ -5207,7 +5711,7 @@ bool Server::read_content_core( return true; } -bool Server::handle_file_request(const Request &req, Response &res) { +bool Server::handle_file_request(Request &req, Response &res) { for (const auto &entry : base_dirs_) { // Prefix match if (!req.path.compare(0, entry.mount_point.size(), entry.mount_point)) { @@ -5228,6 +5732,20 @@ bool Server::handle_file_request(const Request &req, Response &res) { res.set_header(kv.first, kv.second); } + auto etag = detail::compute_etag(stat); + if (!etag.empty()) { res.set_header("ETag", etag); } + + auto mtime = stat.mtime(); + + auto last_modified = detail::file_mtime_to_http_date(mtime); + if (!last_modified.empty()) { + res.set_header("Last-Modified", last_modified); + } + + if (check_if_not_modified(req, res, etag, mtime)) { return true; } + + check_if_range(req, etag, mtime); + auto mm = std::make_shared(path.c_str()); if (!mm->is_open()) { output_error_log(Error::OpenFile, &req); @@ -5257,6 +5775,79 @@ bool Server::handle_file_request(const Request &req, Response &res) { return false; } +bool Server::check_if_not_modified(const Request &req, Response &res, + const std::string &etag, + time_t mtime) const { + // Handle conditional GET: + // 1. If-None-Match takes precedence (RFC 9110 Section 13.1.2) + // 2. If-Modified-Since is checked only when If-None-Match is absent + if (req.has_header("If-None-Match")) { + if (!etag.empty()) { + auto val = req.get_header_value("If-None-Match"); + + // NOTE: We use exact string matching here. This works correctly + // because our server always generates weak ETags (W/"..."), and + // clients typically send back the same ETag they received. + // RFC 9110 Section 8.8.3.2 allows weak comparison for + // If-None-Match, where W/"x" and "x" would match, but this + // simplified implementation requires exact matches. + auto ret = detail::split_find(val.data(), val.data() + val.size(), ',', + [&](const char *b, const char *e) { + return std::equal(b, e, "*") || + std::equal(b, e, etag.begin()); + }); + + if (ret) { + res.status = StatusCode::NotModified_304; + return true; + } + } + } else if (req.has_header("If-Modified-Since")) { + auto val = req.get_header_value("If-Modified-Since"); + auto t = detail::parse_http_date(val); + + if (t != static_cast(-1) && mtime <= t) { + res.status = StatusCode::NotModified_304; + return true; + } + } + return false; +} + +bool Server::check_if_range(Request &req, const std::string &etag, + time_t mtime) const { + // Handle If-Range for partial content requests (RFC 9110 + // Section 13.1.5). If-Range is only evaluated when Range header is + // present. If the validator matches, serve partial content; otherwise + // serve full content. + if (!req.ranges.empty() && req.has_header("If-Range")) { + auto val = req.get_header_value("If-Range"); + + auto is_valid_range = [&]() { + if (detail::is_strong_etag(val)) { + // RFC 9110 Section 13.1.5: If-Range requires strong ETag + // comparison. + return (!etag.empty() && val == etag); + } else if (detail::is_weak_etag(val)) { + // Weak ETags are not valid for If-Range (RFC 9110 Section 13.1.5) + return false; + } else { + // HTTP-date comparison + auto t = detail::parse_http_date(val); + return (t != static_cast(-1) && mtime <= t); + } + }; + + if (!is_valid_range()) { + // Validator doesn't match: ignore Range and serve full content + req.ranges.clear(); + return false; + } + } + + return true; +} + socket_t Server::create_server_socket(const std::string &host, int port, int socket_flags, @@ -5524,10 +6115,13 @@ void Server::apply_ranges(const Request &req, Response &res, res.set_header("Transfer-Encoding", "chunked"); if (type == detail::EncodingType::Gzip) { res.set_header("Content-Encoding", "gzip"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Brotli) { res.set_header("Content-Encoding", "br"); + res.set_header("Vary", "Accept-Encoding"); } else if (type == detail::EncodingType::Zstd) { res.set_header("Content-Encoding", "zstd"); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5586,6 +6180,7 @@ void Server::apply_ranges(const Request &req, Response &res, })) { res.body.swap(compressed); res.set_header("Content-Encoding", content_encoding); + res.set_header("Vary", "Accept-Encoding"); } } } @@ -5663,6 +6258,10 @@ Server::process_request(Stream &strm, const std::string &remote_addr, Request req; req.start_time_ = std::chrono::steady_clock::now(); + req.remote_addr = remote_addr; + req.remote_port = remote_port; + req.local_addr = local_addr; + req.local_port = local_port; Response res; res.version = "HTTP/1.1"; @@ -5908,7 +6507,6 @@ ClientImpl::ClientImpl(const std::string &host, int port, const std::string &client_cert_path, const std::string &client_key_path) : host_(detail::escape_abstract_namespace_unix_domain(host)), port_(port), - host_and_port_(detail::make_host_and_port_string(host_, port, is_ssl())), client_cert_path_(client_cert_path), client_key_path_(client_key_path) {} ClientImpl::~ClientImpl() { @@ -6007,6 +6605,26 @@ bool ClientImpl::create_and_connect_socket(Socket &socket, return true; } +bool ClientImpl::ensure_socket_connection(Socket &socket, Error &error) { + return create_and_connect_socket(socket, error); +} + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT +bool SSLClient::ensure_socket_connection(Socket &socket, Error &error) { + if (!ClientImpl::ensure_socket_connection(socket, error)) { return false; } + + if (!proxy_host_.empty() && proxy_port_ != -1) { return true; } + + if (!initialize_ssl(socket, error)) { + shutdown_socket(socket); + close_socket(socket); + return false; + } + + return true; +} +#endif + void ClientImpl::shutdown_ssl(Socket & /*socket*/, bool /*shutdown_gracefully*/) { // If there are any requests in flight from threads other than us, then it's @@ -6119,7 +6737,7 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } if (!is_alive) { - if (!create_and_connect_socket(socket_, error)) { + if (!ensure_socket_connection(socket_, error)) { output_error_log(error, &req); return false; } @@ -6137,9 +6755,11 @@ bool ClientImpl::send_(Request &req, Response &res, Error &error) { } } - if (!scli.initialize_ssl(socket_, error)) { - output_error_log(error, &req); - return false; + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, error)) { + output_error_log(error, &req); + return false; + } } } #endif @@ -6212,6 +6832,343 @@ Result ClientImpl::send_(Request &&req) { #endif } +void ClientImpl::prepare_default_headers(Request &r, bool for_stream, + const std::string &ct) { + (void)for_stream; + for (const auto &header : default_headers_) { + if (!r.has_header(header.first)) { r.headers.insert(header); } + } + + if (!r.has_header("Host")) { + if (address_family_ == AF_UNIX) { + r.headers.emplace("Host", "localhost"); + } else { + r.headers.emplace( + "Host", detail::make_host_and_port_string(host_, port_, is_ssl())); + } + } + + if (!r.has_header("Accept")) { r.headers.emplace("Accept", "*/*"); } + + if (!r.content_receiver) { + if (!r.has_header("Accept-Encoding")) { + std::string accept_encoding; +#ifdef CPPHTTPLIB_BROTLI_SUPPORT + accept_encoding = "br"; +#endif +#ifdef CPPHTTPLIB_ZLIB_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "gzip, deflate"; +#endif +#ifdef CPPHTTPLIB_ZSTD_SUPPORT + if (!accept_encoding.empty()) { accept_encoding += ", "; } + accept_encoding += "zstd"; +#endif + r.set_header("Accept-Encoding", accept_encoding); + } + +#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT + if (!r.has_header("User-Agent")) { + auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; + r.set_header("User-Agent", agent); + } +#endif + } + + if (!r.body.empty()) { + if (!ct.empty() && !r.has_header("Content-Type")) { + r.headers.emplace("Content-Type", ct); + } + if (!r.has_header("Content-Length")) { + r.headers.emplace("Content-Length", std::to_string(r.body.size())); + } + } +} + +ClientImpl::StreamHandle +ClientImpl::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, + const std::string &content_type) { + StreamHandle handle; + handle.response = detail::make_unique(); + handle.error = Error::Success; + + auto query_path = params.empty() ? path : append_query_params(path, params); + handle.connection_ = detail::make_unique(); + + { + std::lock_guard guard(socket_mutex_); + + auto is_alive = false; + if (socket_.is_open()) { + is_alive = detail::is_socket_alive(socket_.sock); +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_alive && is_ssl()) { + if (detail::is_ssl_peer_could_be_closed(socket_.ssl, socket_.sock)) { + is_alive = false; + } + } +#endif + if (!is_alive) { + shutdown_ssl(socket_, false); + shutdown_socket(socket_); + close_socket(socket_); + } + } + + if (!is_alive) { + if (!ensure_socket_connection(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl()) { + auto &scli = static_cast(*this); + if (!proxy_host_.empty() && proxy_port_ != -1) { + if (!scli.initialize_ssl(socket_, handle.error)) { + handle.response.reset(); + return handle; + } + } + } +#endif + } + + transfer_socket_ownership_to_handle(handle); + } + +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (is_ssl() && handle.connection_->ssl) { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, handle.connection_->ssl, read_timeout_sec_, + read_timeout_usec_, write_timeout_sec_, write_timeout_usec_); + } else { + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); + } +#else + handle.socket_stream_ = detail::make_unique( + handle.connection_->sock, read_timeout_sec_, read_timeout_usec_, + write_timeout_sec_, write_timeout_usec_); +#endif + handle.stream_ = handle.socket_stream_.get(); + + Request req; + req.method = method; + req.path = query_path; + req.headers = headers; + req.body = body; + + prepare_default_headers(req, true, content_type); + + auto &strm = *handle.stream_; + if (detail::write_request_line(strm, req.method, req.path) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + + if (!detail::check_and_write_headers(strm, req.headers, header_writer_, + handle.error)) { + handle.response.reset(); + return handle; + } + + if (!body.empty()) { + if (strm.write(body.data(), body.size()) < 0) { + handle.error = Error::Write; + handle.response.reset(); + return handle; + } + } + + if (!read_response_line(strm, req, *handle.response) || + !detail::read_headers(strm, handle.response->headers)) { + handle.error = Error::Read; + handle.response.reset(); + return handle; + } + + handle.body_reader_.stream = handle.stream_; + + auto content_length_str = handle.response->get_header_value("Content-Length"); + if (!content_length_str.empty()) { + handle.body_reader_.content_length = + static_cast(std::stoull(content_length_str)); + } + + auto transfer_encoding = + handle.response->get_header_value("Transfer-Encoding"); + handle.body_reader_.chunked = (transfer_encoding == "chunked"); + + auto content_encoding = handle.response->get_header_value("Content-Encoding"); + if (!content_encoding.empty()) { + handle.decompressor_ = detail::create_decompressor(content_encoding); + } + + return handle; +} + +ssize_t ClientImpl::StreamHandle::read(char *buf, size_t len) { + if (!is_valid() || !response) { return -1; } + + if (decompressor_) { return read_with_decompression(buf, len); } + auto n = detail::read_body_content(stream_, body_reader_, buf, len); + + if (n <= 0 && body_reader_.chunked && !trailers_parsed_ && stream_) { + trailers_parsed_ = true; + if (body_reader_.chunked_decoder) { + if (!body_reader_.chunked_decoder->parse_trailers_into( + response->trailers, response->headers)) { + return n; + } + } else { + detail::ChunkedDecoder dec(*stream_); + if (!dec.parse_trailers_into(response->trailers, response->headers)) { + return n; + } + } + } + + return n; +} + +ssize_t ClientImpl::StreamHandle::read_with_decompression(char *buf, + size_t len) { + if (decompress_offset_ < decompress_buffer_.size()) { + auto available = decompress_buffer_.size() - decompress_offset_; + auto to_copy = (std::min)(len, available); + std::memcpy(buf, decompress_buffer_.data() + decompress_offset_, to_copy); + decompress_offset_ += to_copy; + return static_cast(to_copy); + } + + decompress_buffer_.clear(); + decompress_offset_ = 0; + + constexpr size_t kDecompressionBufferSize = 8192; + char compressed_buf[kDecompressionBufferSize]; + + while (true) { + auto n = detail::read_body_content(stream_, body_reader_, compressed_buf, + sizeof(compressed_buf)); + + if (n <= 0) { return n; } + + bool decompress_ok = + decompressor_->decompress(compressed_buf, static_cast(n), + [this](const char *data, size_t data_len) { + decompress_buffer_.append(data, data_len); + return true; + }); + + if (!decompress_ok) { + body_reader_.last_error = Error::Read; + return -1; + } + + if (!decompress_buffer_.empty()) { break; } + } + + auto to_copy = (std::min)(len, decompress_buffer_.size()); + std::memcpy(buf, decompress_buffer_.data(), to_copy); + decompress_offset_ = to_copy; + return static_cast(to_copy); +} + +void ClientImpl::StreamHandle::parse_trailers_if_needed() { + if (!response || !stream_ || !body_reader_.chunked || trailers_parsed_) { + return; + } + + trailers_parsed_ = true; + + const auto bufsiz = 128; + char line_buf[bufsiz]; + detail::stream_line_reader line_reader(*stream_, line_buf, bufsiz); + + if (!line_reader.getline()) { return; } + + if (!detail::parse_trailers(line_reader, response->trailers, + response->headers)) { + return; + } +} + +// Inline method implementations for `ChunkedDecoder`. +namespace detail { + +ChunkedDecoder::ChunkedDecoder(Stream &s) : strm(s) {} + +ssize_t ChunkedDecoder::read_payload(char *buf, size_t len, + size_t &out_chunk_offset, + size_t &out_chunk_total) { + if (finished) { return 0; } + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + + char *endptr = nullptr; + unsigned long chunk_len = std::strtoul(lr.ptr(), &endptr, 16); + if (endptr == lr.ptr()) { return -1; } + if (chunk_len == ULONG_MAX) { return -1; } + + if (chunk_len == 0) { + chunk_remaining = 0; + finished = true; + out_chunk_offset = 0; + out_chunk_total = 0; + return 0; + } + + chunk_remaining = static_cast(chunk_len); + last_chunk_total = chunk_remaining; + last_chunk_offset = 0; + } + + auto to_read = (std::min)(chunk_remaining, len); + auto n = strm.read(buf, to_read); + if (n <= 0) { return -1; } + + auto offset_before = last_chunk_offset; + last_chunk_offset += static_cast(n); + chunk_remaining -= static_cast(n); + + out_chunk_offset = offset_before; + out_chunk_total = last_chunk_total; + + if (chunk_remaining == 0) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return -1; } + if (std::strcmp(lr.ptr(), "\r\n") != 0) { return -1; } + } + + return n; +} + +bool ChunkedDecoder::parse_trailers_into(Headers &dest, + const Headers &src_headers) { + stream_line_reader lr(strm, line_buf, sizeof(line_buf)); + if (!lr.getline()) { return false; } + return parse_trailers(lr, dest, src_headers); +} + +} // namespace detail + +void +ClientImpl::transfer_socket_ownership_to_handle(StreamHandle &handle) { + handle.connection_->sock = socket_.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + handle.connection_->ssl = socket_.ssl; + socket_.ssl = nullptr; +#endif + socket_.sock = INVALID_SOCKET; +} + bool ClientImpl::handle_request(Stream &strm, Request &req, Response &res, bool close_connection, Error &error) { @@ -6227,9 +7184,11 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, if (!is_ssl() && !proxy_host_.empty() && proxy_port_ != -1) { auto req2 = req; - req2.path = "http://" + host_and_port_ + req.path; + req2.path = "http://" + + detail::make_host_and_port_string(host_, port_, false) + + req.path; ret = process_request(strm, req2, res, close_connection, error); - req = req2; + req = std::move(req2); req.path = req_save.path; } else { ret = process_request(strm, req, res, close_connection, error); @@ -6253,7 +7212,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, } if (300 < res.status && res.status < 400 && follow_location_) { - req = req_save; + req = std::move(req_save); ret = redirect(req, res, error); } @@ -6281,7 +7240,7 @@ bool ClientImpl::handle_request(Stream &strm, Request &req, Response new_res; ret = send(new_req, new_res, error); - if (ret) { res = new_res; } + if (ret) { res = std::move(new_res); } } } } @@ -6514,42 +7473,11 @@ bool ClientImpl::write_request(Stream &strm, Request &req, } } - if (!req.has_header("Host")) { - // For Unix socket connections, use "localhost" as Host header (similar to - // curl behavior) - if (address_family_ == AF_UNIX) { - req.set_header("Host", "localhost"); - } else { - req.set_header("Host", host_and_port_); - } + std::string ct_for_defaults; + if (!req.has_header("Content-Type") && !req.body.empty()) { + ct_for_defaults = "text/plain"; } - - if (!req.has_header("Accept")) { req.set_header("Accept", "*/*"); } - - if (!req.content_receiver) { - if (!req.has_header("Accept-Encoding")) { - std::string accept_encoding; -#ifdef CPPHTTPLIB_BROTLI_SUPPORT - accept_encoding = "br"; -#endif -#ifdef CPPHTTPLIB_ZLIB_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "gzip, deflate"; -#endif -#ifdef CPPHTTPLIB_ZSTD_SUPPORT - if (!accept_encoding.empty()) { accept_encoding += ", "; } - accept_encoding += "zstd"; -#endif - req.set_header("Accept-Encoding", accept_encoding); - } - -#ifndef CPPHTTPLIB_NO_DEFAULT_USER_AGENT - if (!req.has_header("User-Agent")) { - auto agent = std::string("cpp-httplib/") + CPPHTTPLIB_VERSION; - req.set_header("User-Agent", agent); - } -#endif - }; + prepare_default_headers(req, false, ct_for_defaults); if (req.body.empty()) { if (req.content_provider_) { @@ -6565,15 +7493,6 @@ bool ClientImpl::write_request(Stream &strm, Request &req, req.set_header("Content-Length", "0"); } } - } else { - if (!req.has_header("Content-Type")) { - req.set_header("Content-Type", "text/plain"); - } - - if (!req.has_header("Content-Length")) { - auto length = std::to_string(req.body.size()); - req.set_header("Content-Length", length); - } } if (!basic_auth_password_.empty() || !basic_auth_username_.empty()) { @@ -6620,18 +7539,41 @@ bool ClientImpl::write_request(Stream &strm, Request &req, query_part = ""; } - // Encode path and query + // Encode path part. If the original `req.path` already contained a + // query component, preserve its raw query string (including parameter + // order) instead of reparsing and reassembling it which may reorder + // parameters due to container ordering (e.g. `Params` uses + // `std::multimap`). When there is no query in `req.path`, fall back to + // building a query from `req.params` so existing callers that pass + // `Params` continue to work. auto path_with_query = path_encode_ ? detail::encode_path(path_part) : path_part; - detail::parse_query_text(query_part, req.params); - if (!req.params.empty()) { - path_with_query = append_query_params(path_with_query, req.params); + if (!query_part.empty()) { + // Normalize the query string (decode then re-encode) while preserving + // the original parameter order. + auto normalized = detail::normalize_query_string(query_part); + if (!normalized.empty()) { path_with_query += '?' + normalized; } + + // Still populate req.params for handlers/users who read them. + detail::parse_query_text(query_part, req.params); + } else { + // No query in path; parse any query_part (empty) and append params + // from `req.params` when present (preserves prior behavior for + // callers who provide Params separately). + detail::parse_query_text(query_part, req.params); + if (!req.params.empty()) { + path_with_query = append_query_params(path_with_query, req.params); + } } // Write request line and headers detail::write_request_line(bstrm, req.method, path_with_query); - header_writer_(bstrm, req.headers); + if (!detail::check_and_write_headers(bstrm, req.headers, header_writer_, + error)) { + output_error_log(error, &req); + return false; + } // Flush buffer auto &data = bstrm.get_buffer(); @@ -8096,7 +9038,9 @@ bool SSLSocketStream::wait_writable() const { ssize_t SSLSocketStream::read(char *ptr, size_t size) { if (SSL_pending(ssl_) > 0) { - return SSL_read(ssl_, ptr, static_cast(size)); + auto ret = SSL_read(ssl_, ptr, static_cast(size)); + if (ret == 0) { error_ = Error::ConnectionClosed; } + return ret; } else if (wait_readable()) { auto ret = SSL_read(ssl_, ptr, static_cast(size)); if (ret < 0) { @@ -8121,9 +9065,12 @@ ssize_t SSLSocketStream::read(char *ptr, size_t size) { } } assert(ret < 0); + } else if (ret == 0) { + error_ = Error::ConnectionClosed; } return ret; } else { + error_ = Error::Timeout; return -1; } } @@ -8499,7 +9446,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req2; req2.method = "CONNECT"; - req2.path = host_and_port_; + req2.path = + detail::make_host_and_port_string_always_port(host_, port_); if (max_timeout_msec_ > 0) { req2.start_time_ = std::chrono::steady_clock::now(); } @@ -8526,7 +9474,7 @@ bool SSLClient::connect_with_proxy( close_socket(socket); // Create a new socket for the authenticated CONNECT request - if (!create_and_connect_socket(socket, error)) { + if (!ensure_socket_connection(socket, error)) { success = false; output_error_log(error, nullptr); return false; @@ -8539,7 +9487,8 @@ bool SSLClient::connect_with_proxy( start_time, [&](Stream &strm) { Request req3; req3.method = "CONNECT"; - req3.path = host_and_port_; + req3.path = detail::make_host_and_port_string_always_port( + host_, port_); req3.headers.insert(detail::make_digest_authentication_header( req3, auth, 1, detail::random_string(10), proxy_digest_auth_username_, proxy_digest_auth_password_, @@ -9424,6 +10373,13 @@ Result Client::Options(const std::string &path, const Headers &headers) { return cli_->Options(path, headers); } +ClientImpl::StreamHandle +Client::open_stream(const std::string &method, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type) { + return cli_->open_stream(method, path, params, headers, body, content_type); +} + bool Client::send(Request &req, Response &res, Error &error) { return cli_->send(req, res, error); } diff --git a/vendor/cpp-httplib/httplib.h b/vendor/cpp-httplib/httplib.h index c9bd9fd86b..43cdbc5832 100644 --- a/vendor/cpp-httplib/httplib.h +++ b/vendor/cpp-httplib/httplib.h @@ -1,15 +1,15 @@ // // httplib.h // -// Copyright (c) 2025 Yuji Hirose. All rights reserved. +// Copyright (c) 2026 Yuji Hirose. All rights reserved. // MIT License // #ifndef CPPHTTPLIB_HTTPLIB_H #define CPPHTTPLIB_HTTPLIB_H -#define CPPHTTPLIB_VERSION "0.28.0" -#define CPPHTTPLIB_VERSION_NUM "0x001C00" +#define CPPHTTPLIB_VERSION "0.30.0" +#define CPPHTTPLIB_VERSION_NUM "0x001E00" /* * Platform compatibility check @@ -838,6 +838,50 @@ struct Response { std::string file_content_content_type_; }; +enum class Error { + Success = 0, + Unknown, + Connection, + BindIPAddress, + Read, + Write, + ExceedRedirectCount, + Canceled, + SSLConnection, + SSLLoadingCerts, + SSLServerVerification, + SSLServerHostnameVerification, + UnsupportedMultipartBoundaryChars, + Compression, + ConnectionTimeout, + ProxyConnection, + ConnectionClosed, + Timeout, + ResourceExhaustion, + TooManyFormDataFiles, + ExceedMaxPayloadSize, + ExceedUriMaxLength, + ExceedMaxSocketDescriptorCount, + InvalidRequestLine, + InvalidHTTPMethod, + InvalidHTTPVersion, + InvalidHeaders, + MultipartParsing, + OpenFile, + Listen, + GetSockName, + UnsupportedAddressFamily, + HTTPParsing, + InvalidRangeHeader, + + // For internal use only + SSLPeerCouldBeClosed_, +}; + +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + class Stream { public: virtual ~Stream() = default; @@ -856,6 +900,11 @@ public: ssize_t write(const char *ptr); ssize_t write(const std::string &s); + + Error get_error() const { return error_; } + +protected: + Error error_ = Error::Success; }; class TaskQueue { @@ -873,6 +922,7 @@ class ThreadPool final : public TaskQueue { public: explicit ThreadPool(size_t n, size_t mqr = 0) : shutdown_(false), max_queued_requests_(mqr) { + threads_.reserve(n); while (n) { threads_.emplace_back(worker(*this)); n--; @@ -961,27 +1011,21 @@ using ErrorLogger = std::function; using SocketOptions = std::function; -namespace detail { - -bool set_socket_opt_impl(socket_t sock, int level, int optname, - const void *optval, socklen_t optlen); -bool set_socket_opt(socket_t sock, int level, int optname, int opt); -bool set_socket_opt_time(socket_t sock, int level, int optname, time_t sec, - time_t usec); - -} // namespace detail - void default_socket_options(socket_t sock); const char *status_message(int status); +std::string to_string(Error error); + +std::ostream &operator<<(std::ostream &os, const Error &obj); + std::string get_bearer_token_auth(const Request &req); namespace detail { class MatcherBase { public: - MatcherBase(std::string pattern) : pattern_(pattern) {} + MatcherBase(std::string pattern) : pattern_(std::move(pattern)) {} virtual ~MatcherBase() = default; const std::string &pattern() const { return pattern_; } @@ -1051,10 +1095,9 @@ private: std::regex regex_; }; -ssize_t write_headers(Stream &strm, const Headers &headers); +int close_socket(socket_t sock); -std::string make_host_and_port_string(const std::string &host, int port, - bool is_ssl); +ssize_t write_headers(Stream &strm, const Headers &headers); } // namespace detail @@ -1206,7 +1249,11 @@ private: bool listen_internal(); bool routing(Request &req, Response &res, Stream &strm); - bool handle_file_request(const Request &req, Response &res); + bool handle_file_request(Request &req, Response &res); + bool check_if_not_modified(const Request &req, Response &res, + const std::string &etag, time_t mtime) const; + bool check_if_range(Request &req, const std::string &etag, + time_t mtime) const; bool dispatch_request(Request &req, Response &res, const Handlers &handlers) const; bool dispatch_request_for_content_reader( @@ -1290,48 +1337,6 @@ private: detail::write_headers; }; -enum class Error { - Success = 0, - Unknown, - Connection, - BindIPAddress, - Read, - Write, - ExceedRedirectCount, - Canceled, - SSLConnection, - SSLLoadingCerts, - SSLServerVerification, - SSLServerHostnameVerification, - UnsupportedMultipartBoundaryChars, - Compression, - ConnectionTimeout, - ProxyConnection, - ResourceExhaustion, - TooManyFormDataFiles, - ExceedMaxPayloadSize, - ExceedUriMaxLength, - ExceedMaxSocketDescriptorCount, - InvalidRequestLine, - InvalidHTTPMethod, - InvalidHTTPVersion, - InvalidHeaders, - MultipartParsing, - OpenFile, - Listen, - GetSockName, - UnsupportedAddressFamily, - HTTPParsing, - InvalidRangeHeader, - - // For internal use only - SSLPeerCouldBeClosed_, -}; - -std::string to_string(Error error); - -std::ostream &operator<<(std::ostream &os, const Error &obj); - class Result { public: Result() = default; @@ -1390,6 +1395,87 @@ private: #endif }; +struct ClientConnection { + socket_t sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + SSL *ssl = nullptr; +#endif + + bool is_open() const { return sock != INVALID_SOCKET; } + + ClientConnection() = default; + + ~ClientConnection() { +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + if (ssl) { + SSL_free(ssl); + ssl = nullptr; + } +#endif + if (sock != INVALID_SOCKET) { + detail::close_socket(sock); + sock = INVALID_SOCKET; + } + } + + ClientConnection(const ClientConnection &) = delete; + ClientConnection &operator=(const ClientConnection &) = delete; + + ClientConnection(ClientConnection &&other) noexcept + : sock(other.sock) +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + , + ssl(other.ssl) +#endif + { + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + + ClientConnection &operator=(ClientConnection &&other) noexcept { + if (this != &other) { + sock = other.sock; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + ssl = other.ssl; +#endif + other.sock = INVALID_SOCKET; +#ifdef CPPHTTPLIB_OPENSSL_SUPPORT + other.ssl = nullptr; +#endif + } + return *this; + } +}; + +namespace detail { + +struct ChunkedDecoder; + +struct BodyReader { + Stream *stream = nullptr; + size_t content_length = 0; + size_t bytes_read = 0; + bool chunked = false; + bool eof = false; + std::unique_ptr chunked_decoder; + Error last_error = Error::Success; + + ssize_t read(char *buf, size_t len); + bool has_error() const { return last_error != Error::Success; } +}; + +inline ssize_t read_body_content(Stream *stream, BodyReader &br, char *buf, + size_t len) { + (void)stream; + return br.read(buf, len); +} + +class decompressor; + +} // namespace detail + class ClientImpl { public: explicit ClientImpl(const std::string &host); @@ -1404,6 +1490,43 @@ public: virtual bool is_valid() const; + struct StreamHandle { + std::unique_ptr response; + Error error = Error::Success; + + StreamHandle() = default; + StreamHandle(const StreamHandle &) = delete; + StreamHandle &operator=(const StreamHandle &) = delete; + StreamHandle(StreamHandle &&) = default; + StreamHandle &operator=(StreamHandle &&) = default; + ~StreamHandle() = default; + + bool is_valid() const { + return response != nullptr && error == Error::Success; + } + + ssize_t read(char *buf, size_t len); + void parse_trailers_if_needed(); + Error get_read_error() const { return body_reader_.last_error; } + bool has_read_error() const { return body_reader_.has_error(); } + + bool trailers_parsed_ = false; + + private: + friend class ClientImpl; + + ssize_t read_with_decompression(char *buf, size_t len); + + std::unique_ptr connection_; + std::unique_ptr socket_stream_; + Stream *stream_ = nullptr; + detail::BodyReader body_reader_; + + std::unique_ptr decompressor_; + std::string decompress_buffer_; + size_t decompress_offset_ = 0; + }; + // clang-format off Result Get(const std::string &path, DownloadProgress progress = nullptr); Result Get(const std::string &path, ContentReceiver content_receiver, DownloadProgress progress = nullptr); @@ -1497,6 +1620,15 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + StreamHandle open_stream(const std::string &method, const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -1592,6 +1724,7 @@ protected: }; virtual bool create_and_connect_socket(Socket &socket, Error &error); + virtual bool ensure_socket_connection(Socket &socket, Error &error); // All of: // shutdown_ssl @@ -1618,7 +1751,6 @@ protected: // Socket endpoint information const std::string host_; const int port_; - const std::string host_and_port_; // Current open socket Socket socket_; @@ -1717,6 +1849,8 @@ private: Response &res) const; bool write_request(Stream &strm, Request &req, bool close_connection, Error &error); + void prepare_default_headers(Request &r, bool for_stream, + const std::string &ct); bool redirect(Request &req, Response &res, Error &error); bool create_redirect_client(const std::string &scheme, const std::string &host, int port, Request &req, @@ -1747,6 +1881,8 @@ private: std::chrono::time_point start_time, std::function callback); virtual bool is_ssl() const; + + void transfer_socket_ownership_to_handle(StreamHandle &handle); }; class Client { @@ -1865,6 +2001,16 @@ public: Result Options(const std::string &path, const Headers &headers); // clang-format on + // Streaming API: Open a stream for reading response body incrementally + // Socket ownership is transferred to StreamHandle for true streaming + // Supports all HTTP methods (GET, POST, PUT, PATCH, DELETE, etc.) + ClientImpl::StreamHandle open_stream(const std::string &method, + const std::string &path, + const Params ¶ms = {}, + const Headers &headers = {}, + const std::string &body = {}, + const std::string &content_type = {}); + bool send(Request &req, Response &res, Error &error); Result send(const Request &req); @@ -2027,6 +2173,7 @@ public: private: bool create_and_connect_socket(Socket &socket, Error &error) override; + bool ensure_socket_connection(Socket &socket, Error &error) override; void shutdown_ssl(Socket &socket, bool shutdown_gracefully) override; void shutdown_ssl_impl(Socket &socket, bool shutdown_gracefully); @@ -2163,82 +2310,6 @@ inline void default_socket_options(socket_t sock) { 1); } -inline const char *status_message(int status) { - switch (status) { - case StatusCode::Continue_100: return "Continue"; - case StatusCode::SwitchingProtocol_101: return "Switching Protocol"; - case StatusCode::Processing_102: return "Processing"; - case StatusCode::EarlyHints_103: return "Early Hints"; - case StatusCode::OK_200: return "OK"; - case StatusCode::Created_201: return "Created"; - case StatusCode::Accepted_202: return "Accepted"; - case StatusCode::NonAuthoritativeInformation_203: - return "Non-Authoritative Information"; - case StatusCode::NoContent_204: return "No Content"; - case StatusCode::ResetContent_205: return "Reset Content"; - case StatusCode::PartialContent_206: return "Partial Content"; - case StatusCode::MultiStatus_207: return "Multi-Status"; - case StatusCode::AlreadyReported_208: return "Already Reported"; - case StatusCode::IMUsed_226: return "IM Used"; - case StatusCode::MultipleChoices_300: return "Multiple Choices"; - case StatusCode::MovedPermanently_301: return "Moved Permanently"; - case StatusCode::Found_302: return "Found"; - case StatusCode::SeeOther_303: return "See Other"; - case StatusCode::NotModified_304: return "Not Modified"; - case StatusCode::UseProxy_305: return "Use Proxy"; - case StatusCode::unused_306: return "unused"; - case StatusCode::TemporaryRedirect_307: return "Temporary Redirect"; - case StatusCode::PermanentRedirect_308: return "Permanent Redirect"; - case StatusCode::BadRequest_400: return "Bad Request"; - case StatusCode::Unauthorized_401: return "Unauthorized"; - case StatusCode::PaymentRequired_402: return "Payment Required"; - case StatusCode::Forbidden_403: return "Forbidden"; - case StatusCode::NotFound_404: return "Not Found"; - case StatusCode::MethodNotAllowed_405: return "Method Not Allowed"; - case StatusCode::NotAcceptable_406: return "Not Acceptable"; - case StatusCode::ProxyAuthenticationRequired_407: - return "Proxy Authentication Required"; - case StatusCode::RequestTimeout_408: return "Request Timeout"; - case StatusCode::Conflict_409: return "Conflict"; - case StatusCode::Gone_410: return "Gone"; - case StatusCode::LengthRequired_411: return "Length Required"; - case StatusCode::PreconditionFailed_412: return "Precondition Failed"; - case StatusCode::PayloadTooLarge_413: return "Payload Too Large"; - case StatusCode::UriTooLong_414: return "URI Too Long"; - case StatusCode::UnsupportedMediaType_415: return "Unsupported Media Type"; - case StatusCode::RangeNotSatisfiable_416: return "Range Not Satisfiable"; - case StatusCode::ExpectationFailed_417: return "Expectation Failed"; - case StatusCode::ImATeapot_418: return "I'm a teapot"; - case StatusCode::MisdirectedRequest_421: return "Misdirected Request"; - case StatusCode::UnprocessableContent_422: return "Unprocessable Content"; - case StatusCode::Locked_423: return "Locked"; - case StatusCode::FailedDependency_424: return "Failed Dependency"; - case StatusCode::TooEarly_425: return "Too Early"; - case StatusCode::UpgradeRequired_426: return "Upgrade Required"; - case StatusCode::PreconditionRequired_428: return "Precondition Required"; - case StatusCode::TooManyRequests_429: return "Too Many Requests"; - case StatusCode::RequestHeaderFieldsTooLarge_431: - return "Request Header Fields Too Large"; - case StatusCode::UnavailableForLegalReasons_451: - return "Unavailable For Legal Reasons"; - case StatusCode::NotImplemented_501: return "Not Implemented"; - case StatusCode::BadGateway_502: return "Bad Gateway"; - case StatusCode::ServiceUnavailable_503: return "Service Unavailable"; - case StatusCode::GatewayTimeout_504: return "Gateway Timeout"; - case StatusCode::HttpVersionNotSupported_505: - return "HTTP Version Not Supported"; - case StatusCode::VariantAlsoNegotiates_506: return "Variant Also Negotiates"; - case StatusCode::InsufficientStorage_507: return "Insufficient Storage"; - case StatusCode::LoopDetected_508: return "Loop Detected"; - case StatusCode::NotExtended_510: return "Not Extended"; - case StatusCode::NetworkAuthenticationRequired_511: - return "Network Authentication Required"; - - default: - case StatusCode::InternalServerError_500: return "Internal Server Error"; - } -} - inline std::string get_bearer_token_auth(const Request &req) { if (req.has_header("Authorization")) { constexpr auto bearer_header_prefix_len = detail::str_len("Bearer "); @@ -2272,55 +2343,6 @@ Server::set_idle_interval(const std::chrono::duration &duration) { return *this; } -inline std::string to_string(const Error error) { - switch (error) { - case Error::Success: return "Success (no error)"; - case Error::Unknown: return "Unknown"; - case Error::Connection: return "Could not establish connection"; - case Error::BindIPAddress: return "Failed to bind IP address"; - case Error::Read: return "Failed to read connection"; - case Error::Write: return "Failed to write connection"; - case Error::ExceedRedirectCount: return "Maximum redirect count exceeded"; - case Error::Canceled: return "Connection handling canceled"; - case Error::SSLConnection: return "SSL connection failed"; - case Error::SSLLoadingCerts: return "SSL certificate loading failed"; - case Error::SSLServerVerification: return "SSL server verification failed"; - case Error::SSLServerHostnameVerification: - return "SSL server hostname verification failed"; - case Error::UnsupportedMultipartBoundaryChars: - return "Unsupported HTTP multipart boundary characters"; - case Error::Compression: return "Compression failed"; - case Error::ConnectionTimeout: return "Connection timed out"; - case Error::ProxyConnection: return "Proxy connection failed"; - case Error::ResourceExhaustion: return "Resource exhaustion"; - case Error::TooManyFormDataFiles: return "Too many form data files"; - case Error::ExceedMaxPayloadSize: return "Exceeded maximum payload size"; - case Error::ExceedUriMaxLength: return "Exceeded maximum URI length"; - case Error::ExceedMaxSocketDescriptorCount: - return "Exceeded maximum socket descriptor count"; - case Error::InvalidRequestLine: return "Invalid request line"; - case Error::InvalidHTTPMethod: return "Invalid HTTP method"; - case Error::InvalidHTTPVersion: return "Invalid HTTP version"; - case Error::InvalidHeaders: return "Invalid headers"; - case Error::MultipartParsing: return "Multipart parsing failed"; - case Error::OpenFile: return "Failed to open file"; - case Error::Listen: return "Failed to listen on socket"; - case Error::GetSockName: return "Failed to get socket name"; - case Error::UnsupportedAddressFamily: return "Unsupported address family"; - case Error::HTTPParsing: return "HTTP parsing failed"; - case Error::InvalidRangeHeader: return "Invalid Range header"; - default: break; - } - - return "Invalid"; -} - -inline std::ostream &operator<<(std::ostream &os, const Error &obj) { - os << to_string(obj); - os << " (" << static_cast::type>(obj) << ')'; - return os; -} - inline size_t Result::get_request_header_value_u64(const std::string &key, size_t def, size_t id) const { @@ -2439,6 +2461,8 @@ struct FileStat { FileStat(const std::string &path); bool is_file() const; bool is_dir() const; + time_t mtime() const; + size_t size() const; private: #if defined(_WIN32) @@ -2449,6 +2473,9 @@ private: int ret_ = -1; }; +std::string make_host_and_port_string(const std::string &host, int port, + bool is_ssl); + std::string trim_copy(const std::string &s); void divide( @@ -2669,6 +2696,25 @@ private: std::string growable_buffer_; }; +bool parse_trailers(stream_line_reader &line_reader, Headers &dest, + const Headers &src_headers); + +struct ChunkedDecoder { + Stream &strm; + size_t chunk_remaining = 0; + bool finished = false; + char line_buf[64]; + size_t last_chunk_total = 0; + size_t last_chunk_offset = 0; + + explicit ChunkedDecoder(Stream &s); + + ssize_t read_payload(char *buf, size_t len, size_t &out_chunk_offset, + size_t &out_chunk_total); + + bool parse_trailers_into(Headers &dest, const Headers &src_headers); +}; + class mmap { public: mmap(const char *path); @@ -2696,59 +2742,669 @@ private: // NOTE: https://www.rfc-editor.org/rfc/rfc9110#section-5 namespace fields { -inline bool is_token_char(char c) { - return std::isalnum(c) || c == '!' || c == '#' || c == '$' || c == '%' || - c == '&' || c == '\'' || c == '*' || c == '+' || c == '-' || - c == '.' || c == '^' || c == '_' || c == '`' || c == '|' || c == '~'; -} - -inline bool is_token(const std::string &s) { - if (s.empty()) { return false; } - for (auto c : s) { - if (!is_token_char(c)) { return false; } - } - return true; -} - -inline bool is_field_name(const std::string &s) { return is_token(s); } - -inline bool is_vchar(char c) { return c >= 33 && c <= 126; } - -inline bool is_obs_text(char c) { return 128 <= static_cast(c); } - -inline bool is_field_vchar(char c) { return is_vchar(c) || is_obs_text(c); } - -inline bool is_field_content(const std::string &s) { - if (s.empty()) { return true; } - - if (s.size() == 1) { - return is_field_vchar(s[0]); - } else if (s.size() == 2) { - return is_field_vchar(s[0]) && is_field_vchar(s[1]); - } else { - size_t i = 0; - - if (!is_field_vchar(s[i])) { return false; } - i++; - - while (i < s.size() - 1) { - auto c = s[i++]; - if (c == ' ' || c == '\t' || is_field_vchar(c)) { - } else { - return false; - } - } - - return is_field_vchar(s[i]); - } -} - -inline bool is_field_value(const std::string &s) { return is_field_content(s); } +bool is_token_char(char c); +bool is_token(const std::string &s); +bool is_field_name(const std::string &s); +bool is_vchar(char c); +bool is_obs_text(char c); +bool is_field_vchar(char c); +bool is_field_content(const std::string &s); +bool is_field_value(const std::string &s); } // namespace fields } // namespace detail +namespace stream { + +class Result { +public: + Result() : chunk_size_(8192) {} + + explicit Result(ClientImpl::StreamHandle &&handle, size_t chunk_size = 8192) + : handle_(std::move(handle)), chunk_size_(chunk_size) {} + + Result(Result &&other) noexcept + : handle_(std::move(other.handle_)), buffer_(std::move(other.buffer_)), + current_size_(other.current_size_), chunk_size_(other.chunk_size_), + finished_(other.finished_) { + other.current_size_ = 0; + other.finished_ = true; + } + + Result &operator=(Result &&other) noexcept { + if (this != &other) { + handle_ = std::move(other.handle_); + buffer_ = std::move(other.buffer_); + current_size_ = other.current_size_; + chunk_size_ = other.chunk_size_; + finished_ = other.finished_; + other.current_size_ = 0; + other.finished_ = true; + } + return *this; + } + + Result(const Result &) = delete; + Result &operator=(const Result &) = delete; + + // Check if the result is valid (connection succeeded and response received) + bool is_valid() const { return handle_.is_valid(); } + explicit operator bool() const { return is_valid(); } + + // Response status code + int status() const { + return handle_.response ? handle_.response->status : -1; + } + + // Response headers + const Headers &headers() const { + static const Headers empty_headers; + return handle_.response ? handle_.response->headers : empty_headers; + } + + std::string get_header_value(const std::string &key, + const char *def = "") const { + return handle_.response ? handle_.response->get_header_value(key, def) + : def; + } + + bool has_header(const std::string &key) const { + return handle_.response ? handle_.response->has_header(key) : false; + } + + // Error information + Error error() const { return handle_.error; } + Error read_error() const { return handle_.get_read_error(); } + bool has_read_error() const { return handle_.has_read_error(); } + + // Streaming iteration API + // Call next() to read the next chunk, then access data via data()/size() + // Returns true if data was read, false when stream is exhausted + bool next() { + if (!handle_.is_valid() || finished_) { return false; } + + if (buffer_.size() < chunk_size_) { buffer_.resize(chunk_size_); } + + ssize_t n = handle_.read(&buffer_[0], chunk_size_); + if (n > 0) { + current_size_ = static_cast(n); + return true; + } + + current_size_ = 0; + finished_ = true; + return false; + } + + // Pointer to current chunk data (valid after next() returns true) + const char *data() const { return buffer_.data(); } + + // Size of current chunk (valid after next() returns true) + size_t size() const { return current_size_; } + + // Convenience method: read all remaining data into a string + std::string read_all() { + std::string result; + while (next()) { + result.append(data(), size()); + } + return result; + } + +private: + ClientImpl::StreamHandle handle_; + std::string buffer_; + size_t current_size_ = 0; + size_t chunk_size_; + bool finished_ = false; +}; + +// GET +template +inline Result Get(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, {}, headers), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params), chunk_size}; +} + +template +inline Result Get(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("GET", path, params, headers), chunk_size}; +} + +// POST +template +inline Result Post(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("POST", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Post(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("POST", path, params, headers, body, content_type), + chunk_size}; +} + +// PUT +template +inline Result Put(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PUT", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Put(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PUT", path, params, headers, body, content_type), + chunk_size}; +} + +// PATCH +template +inline Result Patch(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, size_t chunk_size = 8192) { + return Result{cli.open_stream("PATCH", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Patch(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("PATCH", path, params, headers, body, content_type), + chunk_size}; +} + +// DELETE +template +inline Result Delete(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, {}, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Headers &headers, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, {}, headers, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, headers), chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const std::string &body, + const std::string &content_type, + size_t chunk_size = 8192) { + return Result{cli.open_stream("DELETE", path, params, {}, body, content_type), + chunk_size}; +} + +template +inline Result Delete(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + const std::string &body, const std::string &content_type, + size_t chunk_size = 8192) { + return Result{ + cli.open_stream("DELETE", path, params, headers, body, content_type), + chunk_size}; +} + +// HEAD +template +inline Result Head(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, {}, headers), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params), chunk_size}; +} + +template +inline Result Head(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("HEAD", path, params, headers), chunk_size}; +} + +// OPTIONS +template +inline Result Options(ClientType &cli, const std::string &path, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Headers &headers, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, {}, headers), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params), chunk_size}; +} + +template +inline Result Options(ClientType &cli, const std::string &path, + const Params ¶ms, const Headers &headers, + size_t chunk_size = 8192) { + return Result{cli.open_stream("OPTIONS", path, params, headers), chunk_size}; +} + +} // namespace stream + +namespace sse { + +struct SSEMessage { + std::string event; // Event type (default: "message") + std::string data; // Event payload + std::string id; // Event ID for Last-Event-ID header + + SSEMessage() : event("message") {} + + void clear() { + event = "message"; + data.clear(); + id.clear(); + } +}; + +class SSEClient { +public: + using MessageHandler = std::function; + using ErrorHandler = std::function; + using OpenHandler = std::function; + + SSEClient(Client &client, const std::string &path) + : client_(client), path_(path) {} + + SSEClient(Client &client, const std::string &path, const Headers &headers) + : client_(client), path_(path), headers_(headers) {} + + ~SSEClient() { stop(); } + + SSEClient(const SSEClient &) = delete; + SSEClient &operator=(const SSEClient &) = delete; + + // Event handlers + SSEClient &on_message(MessageHandler handler) { + on_message_ = std::move(handler); + return *this; + } + + SSEClient &on_event(const std::string &type, MessageHandler handler) { + event_handlers_[type] = std::move(handler); + return *this; + } + + SSEClient &on_open(OpenHandler handler) { + on_open_ = std::move(handler); + return *this; + } + + SSEClient &on_error(ErrorHandler handler) { + on_error_ = std::move(handler); + return *this; + } + + SSEClient &set_reconnect_interval(int ms) { + reconnect_interval_ms_ = ms; + return *this; + } + + SSEClient &set_max_reconnect_attempts(int n) { + max_reconnect_attempts_ = n; + return *this; + } + + // State accessors + bool is_connected() const { return connected_.load(); } + const std::string &last_event_id() const { return last_event_id_; } + + // Blocking start - runs event loop with auto-reconnect + void start() { + running_.store(true); + run_event_loop(); + } + + // Non-blocking start - runs in background thread + void start_async() { + running_.store(true); + async_thread_ = std::thread([this]() { run_event_loop(); }); + } + + // Stop the client (thread-safe) + void stop() { + running_.store(false); + client_.stop(); // Cancel any pending operations + if (async_thread_.joinable()) { async_thread_.join(); } + } + +private: + // Parse a single SSE field line + // Returns true if this line ends an event (blank line) + bool parse_sse_line(const std::string &line, SSEMessage &msg, int &retry_ms) { + // Blank line signals end of event + if (line.empty() || line == "\r") { return true; } + + // Lines starting with ':' are comments (ignored) + if (!line.empty() && line[0] == ':') { return false; } + + // Find the colon separator + auto colon_pos = line.find(':'); + if (colon_pos == std::string::npos) { + // Line with no colon is treated as field name with empty value + return false; + } + + auto field = line.substr(0, colon_pos); + std::string value; + + // Value starts after colon, skip optional single space + if (colon_pos + 1 < line.size()) { + auto value_start = colon_pos + 1; + if (line[value_start] == ' ') { value_start++; } + value = line.substr(value_start); + // Remove trailing \r if present + if (!value.empty() && value.back() == '\r') { value.pop_back(); } + } + + // Handle known fields + if (field == "event") { + msg.event = value; + } else if (field == "data") { + // Multiple data lines are concatenated with newlines + if (!msg.data.empty()) { msg.data += "\n"; } + msg.data += value; + } else if (field == "id") { + // Empty id is valid (clears the last event ID) + msg.id = value; + } else if (field == "retry") { + // Parse retry interval in milliseconds + try { + retry_ms = std::stoi(value); + } catch (...) { + // Invalid retry value, ignore + } + } + // Unknown fields are ignored per SSE spec + + return false; + } + + // Main event loop with auto-reconnect + void run_event_loop() { + auto reconnect_count = 0; + + while (running_.load()) { + // Build headers, including Last-Event-ID if we have one + auto request_headers = headers_; + if (!last_event_id_.empty()) { + request_headers.emplace("Last-Event-ID", last_event_id_); + } + + // Open streaming connection + auto result = stream::Get(client_, path_, request_headers); + + // Connection error handling + if (!result) { + connected_.store(false); + if (on_error_) { on_error_(result.error()); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + if (result.status() != 200) { + connected_.store(false); + // For certain errors, don't reconnect + if (result.status() == 204 || // No Content - server wants us to stop + result.status() == 404 || // Not Found + result.status() == 401 || // Unauthorized + result.status() == 403) { // Forbidden + if (on_error_) { on_error_(Error::Connection); } + break; + } + + if (on_error_) { on_error_(Error::Connection); } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + continue; + } + + // Connection successful + connected_.store(true); + reconnect_count = 0; + if (on_open_) { on_open_(); } + + // Event receiving loop + std::string buffer; + SSEMessage current_msg; + + while (running_.load() && result.next()) { + buffer.append(result.data(), result.size()); + + // Process complete lines in the buffer + size_t line_start = 0; + size_t newline_pos; + + while ((newline_pos = buffer.find('\n', line_start)) != + std::string::npos) { + auto line = buffer.substr(line_start, newline_pos - line_start); + line_start = newline_pos + 1; + + // Parse the line and check if event is complete + auto event_complete = + parse_sse_line(line, current_msg, reconnect_interval_ms_); + + if (event_complete && !current_msg.data.empty()) { + // Update last_event_id for reconnection + if (!current_msg.id.empty()) { last_event_id_ = current_msg.id; } + + // Dispatch event to appropriate handler + dispatch_event(current_msg); + + current_msg.clear(); + } + } + + // Keep unprocessed data in buffer + buffer.erase(0, line_start); + } + + // Connection ended + connected_.store(false); + + if (!running_.load()) { break; } + + // Check for read errors + if (result.has_read_error()) { + if (on_error_) { on_error_(result.read_error()); } + } + + if (!should_reconnect(reconnect_count)) { break; } + wait_for_reconnect(); + reconnect_count++; + } + + connected_.store(false); + } + + // Dispatch event to appropriate handler + void dispatch_event(const SSEMessage &msg) { + // Check for specific event type handler first + auto it = event_handlers_.find(msg.event); + if (it != event_handlers_.end()) { + it->second(msg); + return; + } + + // Fall back to generic message handler + if (on_message_) { on_message_(msg); } + } + + // Check if we should attempt to reconnect + bool should_reconnect(int count) const { + if (!running_.load()) { return false; } + if (max_reconnect_attempts_ == 0) { return true; } // unlimited + return count < max_reconnect_attempts_; + } + + // Wait for reconnect interval + void wait_for_reconnect() { + // Use small increments to check running_ flag frequently + auto waited = 0; + while (running_.load() && waited < reconnect_interval_ms_) { + std::this_thread::sleep_for(std::chrono::milliseconds(100)); + waited += 100; + } + } + + // Client and path + Client &client_; + std::string path_; + Headers headers_; + + // Callbacks + MessageHandler on_message_; + std::map event_handlers_; + OpenHandler on_open_; + ErrorHandler on_error_; + + // Configuration + int reconnect_interval_ms_ = 3000; + int max_reconnect_attempts_ = 0; // 0 = unlimited + + // State + std::atomic running_{false}; + std::atomic connected_{false}; + std::string last_event_id_; + + // Async support + std::thread async_thread_; +}; + +} // namespace sse + } // namespace httplib