diff --git a/components/core/src/clp/CurlDownloadHandler.cpp b/components/core/src/clp/CurlDownloadHandler.cpp index 9a2720083..d1b88758a 100644 --- a/components/core/src/clp/CurlDownloadHandler.cpp +++ b/components/core/src/clp/CurlDownloadHandler.cpp @@ -1,13 +1,22 @@ #include "CurlDownloadHandler.hpp" +#include +#include +#include #include #include #include +#include #include #include +#include +#include #include #include +#include + +#include "ErrorCode.hpp" namespace clp { CurlDownloadHandler::CurlDownloadHandler( @@ -19,7 +28,8 @@ CurlDownloadHandler::CurlDownloadHandler( size_t offset, bool disable_caching, std::chrono::seconds connection_timeout, - std::chrono::seconds overall_timeout + std::chrono::seconds overall_timeout, + std::optional> const& http_header_kv_pairs ) : m_error_msg_buf{std::move(error_msg_buf)} { if (nullptr != m_error_msg_buf) { @@ -48,13 +58,55 @@ CurlDownloadHandler::CurlDownloadHandler( m_easy_handle.set_option(CURLOPT_TIMEOUT, static_cast(overall_timeout.count())); // Set up http headers + constexpr std::string_view cRangeHeaderName{"range"}; + constexpr std::string_view cCacheControlHeaderName{"cache-control"}; + constexpr std::string_view cPragmaHeaderName{"pragma"}; + std::unordered_set const reserved_headers{ + cRangeHeaderName, + cCacheControlHeaderName, + cPragmaHeaderName + }; if (0 != offset) { - std::string const range{"Range: bytes=" + std::to_string(offset) + "-"}; - m_http_headers.append(range); + m_http_headers.append(fmt::format("{}: bytes={}-", cRangeHeaderName, offset)); } if (disable_caching) { - m_http_headers.append("Cache-Control: no-cache"); - m_http_headers.append("Pragma: no-cache"); + m_http_headers.append(fmt::format("{}: no-cache", cCacheControlHeaderName)); + m_http_headers.append(fmt::format("{}: no-cache", cPragmaHeaderName)); + } + if (http_header_kv_pairs.has_value()) { + for (auto const& [key, value] : http_header_kv_pairs.value()) { + // HTTP header field-name (key) is case-insensitive: + // https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + // Therefore, we convert keys to lowercase for comparison with the reserved keys. + // NOTE: We do not check for duplicate keys due to case insensitivity, leaving duplicate + // handling to the server. + auto lower_key{key}; + std::transform( + lower_key.begin(), + lower_key.end(), + lower_key.begin(), + [](unsigned char c) -> char { + // Implicitly cast the input character into `unsigned char` to avoid UB: + // https://en.cppreference.com/w/cpp/string/byte/tolower + return static_cast(std::tolower(c)); + } + ); + if (reserved_headers.contains(lower_key) || value.ends_with("\r\n")) { + throw CurlOperationFailed( + ErrorCode_Failure, + __FILE__, + __LINE__, + CURLE_BAD_FUNCTION_ARGUMENT, + fmt::format( + "`CurlDownloadHandler` failed to construct with the following " + "invalid header: {}:{}", + key, + value + ) + ); + } + m_http_headers.append(fmt::format("{}: {}", key, value)); + } } if (false == m_http_headers.is_empty()) { m_easy_handle.set_option(CURLOPT_HTTPHEADER, m_http_headers.get_raw_list()); diff --git a/components/core/src/clp/CurlDownloadHandler.hpp b/components/core/src/clp/CurlDownloadHandler.hpp index 6421257ba..e7c4b73a8 100644 --- a/components/core/src/clp/CurlDownloadHandler.hpp +++ b/components/core/src/clp/CurlDownloadHandler.hpp @@ -5,7 +5,10 @@ #include #include #include +#include +#include #include +#include #include @@ -53,6 +56,9 @@ class CurlDownloadHandler { * Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html * @param overall_timeout Maximum time that the transfer may take. Note that this includes * `connection_timeout`. Doc: https://curl.se/libcurl/c/CURLOPT_TIMEOUT.html + * @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the server + * in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html + * @throw CurlOperationFailed if an error occurs. */ explicit CurlDownloadHandler( std::shared_ptr error_msg_buf, @@ -63,7 +69,9 @@ class CurlDownloadHandler { size_t offset = 0, bool disable_caching = false, std::chrono::seconds connection_timeout = cDefaultConnectionTimeout, - std::chrono::seconds overall_timeout = cDefaultOverallTimeout + std::chrono::seconds overall_timeout = cDefaultOverallTimeout, + std::optional> const& http_header_kv_pairs + = std::nullopt ); // Disable copy/move constructors/assignment operators diff --git a/components/core/src/clp/NetworkReader.cpp b/components/core/src/clp/NetworkReader.cpp index cdde759c2..086b60681 100644 --- a/components/core/src/clp/NetworkReader.cpp +++ b/components/core/src/clp/NetworkReader.cpp @@ -6,7 +6,10 @@ #include #include #include +#include #include +#include +#include #include @@ -118,7 +121,8 @@ NetworkReader::NetworkReader( std::chrono::seconds overall_timeout, std::chrono::seconds connection_timeout, size_t buffer_pool_size, - size_t buffer_size + size_t buffer_size, + std::optional> http_header_kv_pairs ) : m_src_url{src_url}, m_offset{offset}, @@ -130,7 +134,12 @@ NetworkReader::NetworkReader( for (size_t i = 0; i < m_buffer_pool_size; ++i) { m_buffer_pool.emplace_back(m_buffer_size); } - m_downloader_thread = std::make_unique(*this, offset, disable_caching); + m_downloader_thread = std::make_unique( + *this, + offset, + disable_caching, + std::move(http_header_kv_pairs) + ); m_downloader_thread->start(); } @@ -215,7 +224,8 @@ auto NetworkReader::DownloaderThread::thread_method() -> void { m_offset, m_disable_caching, m_reader.m_connection_timeout, - m_reader.m_overall_timeout + m_reader.m_overall_timeout, + m_http_header_kv_pairs }; auto const ret_code{curl_handler.perform()}; // Enqueue the last filled buffer, if any diff --git a/components/core/src/clp/NetworkReader.hpp b/components/core/src/clp/NetworkReader.hpp index 7c808fd4f..08be975ea 100644 --- a/components/core/src/clp/NetworkReader.hpp +++ b/components/core/src/clp/NetworkReader.hpp @@ -13,6 +13,8 @@ #include #include #include +#include +#include #include #include @@ -94,6 +96,8 @@ class NetworkReader : public ReaderInterface { * Doc: https://curl.se/libcurl/c/CURLOPT_CONNECTTIMEOUT.html * @param buffer_pool_size The required number of buffers in the buffer pool. * @param buffer_size The size of each buffer in the buffer pool. + * @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the server + * in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html */ explicit NetworkReader( std::string_view src_url, @@ -103,7 +107,9 @@ class NetworkReader : public ReaderInterface { std::chrono::seconds connection_timeout = CurlDownloadHandler::cDefaultConnectionTimeout, size_t buffer_pool_size = cDefaultBufferPoolSize, - size_t buffer_size = cDefaultBufferSize + size_t buffer_size = cDefaultBufferSize, + std::optional> http_header_kv_pairs + = std::nullopt ); // Destructor @@ -242,11 +248,19 @@ class NetworkReader : public ReaderInterface { * @param reader * @param offset Index of the byte at which to start the download. * @param disable_caching Whether to disable caching. + * @param http_header_kv_pairs Key-value pairs representing HTTP headers to pass to the + * server in the download request. Doc: https://curl.se/libcurl/c/CURLOPT_HTTPHEADER.html */ - DownloaderThread(NetworkReader& reader, size_t offset, bool disable_caching) + DownloaderThread( + NetworkReader& reader, + size_t offset, + bool disable_caching, + std::optional> http_header_kv_pairs + ) : m_reader{reader}, m_offset{offset}, - m_disable_caching{disable_caching} {} + m_disable_caching{disable_caching}, + m_http_header_kv_pairs{std::move(http_header_kv_pairs)} {} private: // Methods implementing `clp::Thread` @@ -255,6 +269,7 @@ class NetworkReader : public ReaderInterface { NetworkReader& m_reader; size_t m_offset{0}; bool m_disable_caching{false}; + std::optional> m_http_header_kv_pairs; }; /** diff --git a/components/core/tests/test-NetworkReader.cpp b/components/core/tests/test-NetworkReader.cpp index cd4b90cc0..f32daef14 100644 --- a/components/core/tests/test-NetworkReader.cpp +++ b/components/core/tests/test-NetworkReader.cpp @@ -6,10 +6,13 @@ #include #include #include +#include #include #include #include +#include +#include #include "../src/clp/Array.hpp" #include "../src/clp/CurlDownloadHandler.hpp" @@ -188,3 +191,58 @@ TEST_CASE("network_reader_illegal_offset", "[NetworkReader]") { size_t pos{}; REQUIRE((clp::ErrorCode_Failure == reader.try_get_pos(pos))); } + +TEST_CASE("network_reader_with_valid_http_header_kv_pairs", "[NetworkReader]") { + std::unordered_map valid_http_header_kv_pairs; + // We use httpbin (https://httpbin.org/) to test the user-specified headers. On success, it is + // supposed to respond all the user-specified headers as key-value pairs in JSON form. + constexpr int cNumHttpHeaderKeyValuePairs{10}; + for (size_t i{0}; i < cNumHttpHeaderKeyValuePairs; ++i) { + valid_http_header_kv_pairs.emplace( + fmt::format("Unit-Test-Key{}", i), + fmt::format("Unit-Test-Value{}", i) + ); + } + clp::NetworkReader reader{ + "https://httpbin.org/headers", + 0, + false, + clp::CurlDownloadHandler::cDefaultOverallTimeout, + clp::CurlDownloadHandler::cDefaultConnectionTimeout, + clp::NetworkReader::cDefaultBufferPoolSize, + clp::NetworkReader::cDefaultBufferSize, + valid_http_header_kv_pairs + }; + auto const content = nlohmann::json::parse(get_content(reader)); + auto const& headers{content.at("headers")}; + REQUIRE(assert_curl_error_code(CURLE_OK, reader)); + for (auto const& [key, value] : valid_http_header_kv_pairs) { + REQUIRE((value == headers.at(key).get())); + } +} + +TEST_CASE("network_reader_with_illegal_http_header_kv_pairs", "[NetworkReader]") { + auto illegal_header_kv_pairs = GENERATE( + // The following headers are determined by offset and disable_cache, which should not be + // overridden by user-defined headers. + std::unordered_map{{"Range", "bytes=100-"}}, + std::unordered_map{{"RAnGe", "bytes=100-"}}, + std::unordered_map{{"Cache-Control", "no-cache"}}, + std::unordered_map{{"Pragma", "no-cache"}}, + // The CRLF-terminated headers should be rejected. + std::unordered_map{{"Legal-Name", "CRLF\r\n"}} + ); + clp::NetworkReader reader{ + "https://httpbin.org/headers", + 0, + false, + clp::CurlDownloadHandler::cDefaultOverallTimeout, + clp::CurlDownloadHandler::cDefaultConnectionTimeout, + clp::NetworkReader::cDefaultBufferPoolSize, + clp::NetworkReader::cDefaultBufferSize, + illegal_header_kv_pairs + }; + auto const content = get_content(reader); + REQUIRE(content.empty()); + REQUIRE(assert_curl_error_code(CURLE_BAD_FUNCTION_ARGUMENT, reader)); +}