Parse and use extra HTTP headers when downloading the payload.
Android OTA backend requires to pass an Authorization HTTP header in
order to download some payload. This patch allows to specify such
header when initiating a payload download from Android.
Bug: 27047110
TEST=Added unittests to check the headers sent.
Change-Id: Iece7e0ee252349bbaa9fb8545da2c34d2a76ae69
diff --git a/common/constants.cc b/common/constants.cc
index fc6df37..f138ce3 100644
--- a/common/constants.cc
+++ b/common/constants.cc
@@ -94,5 +94,7 @@
const char kPayloadPropertyFileHash[] = "FILE_HASH";
const char kPayloadPropertyMetadataSize[] = "METADATA_SIZE";
const char kPayloadPropertyMetadataHash[] = "METADATA_HASH";
+const char kPayloadPropertyAuthorization[] = "AUTHORIZATION";
+const char kPayloadPropertyUserAgent[] = "USER_AGENT";
} // namespace chromeos_update_engine
diff --git a/common/constants.h b/common/constants.h
index 25d587b..f0d589d 100644
--- a/common/constants.h
+++ b/common/constants.h
@@ -97,6 +97,8 @@
extern const char kPayloadPropertyFileHash[];
extern const char kPayloadPropertyMetadataSize[];
extern const char kPayloadPropertyMetadataHash[];
+extern const char kPayloadPropertyAuthorization[];
+extern const char kPayloadPropertyUserAgent[];
// A download source is any combination of protocol and server (that's of
// interest to us when looking at UMA metrics) using which we may download
diff --git a/common/http_fetcher.h b/common/http_fetcher.h
index 11e8e9f..d2499eb 100644
--- a/common/http_fetcher.h
+++ b/common/http_fetcher.h
@@ -44,7 +44,7 @@
// |proxy_resolver| is the resolver that will be consulted for proxy
// settings. It may be null, in which case direct connections will
// be used. Does not take ownership of the resolver.
- HttpFetcher(ProxyResolver* proxy_resolver)
+ explicit HttpFetcher(ProxyResolver* proxy_resolver)
: post_data_set_(false),
http_response_code_(0),
delegate_(nullptr),
@@ -95,6 +95,12 @@
// TransferTerminated() will be called when the transfer is actually done.
virtual void TerminateTransfer() = 0;
+ // Add or update a custom header to be sent with every request. If the same
+ // |header_name| is passed twice, the second |header_value| would override the
+ // previous value.
+ virtual void SetHeader(const std::string& header_name,
+ const std::string& header_value) = 0;
+
// If data is coming in too quickly, you can call Pause() to pause the
// transfer. The delegate will not have ReceivedBytes() called while
// an HttpFetcher is paused.
diff --git a/common/http_fetcher_unittest.cc b/common/http_fetcher_unittest.cc
index 5450958..0d4b5da 100644
--- a/common/http_fetcher_unittest.cc
+++ b/common/http_fetcher_unittest.cc
@@ -370,12 +370,12 @@
namespace {
class HttpFetcherTestDelegate : public HttpFetcherDelegate {
public:
- HttpFetcherTestDelegate() :
- is_expect_error_(false), times_transfer_complete_called_(0),
- times_transfer_terminated_called_(0), times_received_bytes_called_(0) {}
+ HttpFetcherTestDelegate() = default;
void ReceivedBytes(HttpFetcher* /* fetcher */,
- const void* /* bytes */, size_t /* length */) override {
+ const void* bytes,
+ size_t length) override {
+ data.append(reinterpret_cast<const char*>(bytes), length);
// Update counters
times_received_bytes_called_++;
}
@@ -397,12 +397,15 @@
}
// Are we expecting an error response? (default: no)
- bool is_expect_error_;
+ bool is_expect_error_{false};
// Counters for callback invocations.
- int times_transfer_complete_called_;
- int times_transfer_terminated_called_;
- int times_received_bytes_called_;
+ int times_transfer_complete_called_{0};
+ int times_transfer_terminated_called_{0};
+ int times_received_bytes_called_{0};
+
+ // The received data bytes.
+ string data;
};
@@ -473,6 +476,41 @@
CHECK_EQ(delegate.times_transfer_terminated_called_, 0);
}
+TYPED_TEST(HttpFetcherTest, ExtraHeadersInRequestTest) {
+ if (this->test_.IsMock())
+ return;
+
+ HttpFetcherTestDelegate delegate;
+ unique_ptr<HttpFetcher> fetcher(this->test_.NewSmallFetcher());
+ fetcher->set_delegate(&delegate);
+ fetcher->SetHeader("User-Agent", "MyTest");
+ fetcher->SetHeader("user-agent", "Override that header");
+ fetcher->SetHeader("Authorization", "Basic user:passwd");
+
+ // Invalid headers.
+ fetcher->SetHeader("X-Foo", "Invalid\nHeader\nIgnored");
+ fetcher->SetHeader("X-Bar: ", "I do not know how to parse");
+
+ // Hide Accept header normally added by default.
+ fetcher->SetHeader("Accept", "");
+
+ PythonHttpServer server;
+ int port = server.GetPort();
+ ASSERT_TRUE(server.started_);
+
+ StartTransfer(fetcher.get(), LocalServerUrlForPath(port, "/echo-headers"));
+ this->loop_.Run();
+
+ EXPECT_NE(string::npos,
+ delegate.data.find("user-agent: Override that header\r\n"));
+ EXPECT_NE(string::npos,
+ delegate.data.find("Authorization: Basic user:passwd\r\n"));
+
+ EXPECT_EQ(string::npos, delegate.data.find("\nAccept:"));
+ EXPECT_EQ(string::npos, delegate.data.find("X-Foo: Invalid"));
+ EXPECT_EQ(string::npos, delegate.data.find("X-Bar: I do not"));
+}
+
namespace {
class PausingHttpFetcherTestDelegate : public HttpFetcherDelegate {
public:
diff --git a/common/libcurl_http_fetcher.cc b/common/libcurl_http_fetcher.cc
index 789f46e..b2dffc2 100644
--- a/common/libcurl_http_fetcher.cc
+++ b/common/libcurl_http_fetcher.cc
@@ -129,24 +129,32 @@
CHECK_EQ(curl_easy_setopt(curl_handle_, CURLOPT_POSTFIELDSIZE,
post_data_.size()),
CURLE_OK);
+ }
+ // Setup extra HTTP headers.
+ if (curl_http_headers_) {
+ curl_slist_free_all(curl_http_headers_);
+ curl_http_headers_ = nullptr;
+ }
+ for (const auto& header : extra_headers_) {
+ // curl_slist_append() copies the string.
+ curl_http_headers_ =
+ curl_slist_append(curl_http_headers_, header.second.c_str());
+ }
+ if (post_data_set_) {
// Set the Content-Type HTTP header, if one was specifically set.
- CHECK(!curl_http_headers_);
if (post_content_type_ != kHttpContentTypeUnspecified) {
- const string content_type_attr =
- base::StringPrintf("Content-Type: %s",
- GetHttpContentTypeString(post_content_type_));
- curl_http_headers_ = curl_slist_append(nullptr,
- content_type_attr.c_str());
- CHECK(curl_http_headers_);
- CHECK_EQ(
- curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER,
- curl_http_headers_),
- CURLE_OK);
+ const string content_type_attr = base::StringPrintf(
+ "Content-Type: %s", GetHttpContentTypeString(post_content_type_));
+ curl_http_headers_ =
+ curl_slist_append(curl_http_headers_, content_type_attr.c_str());
} else {
LOG(WARNING) << "no content type set, using libcurl default";
}
}
+ CHECK_EQ(
+ curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, curl_http_headers_),
+ CURLE_OK);
if (bytes_downloaded_ > 0 || download_length_) {
// Resume from where we left off.
@@ -318,6 +326,17 @@
}
}
+void LibcurlHttpFetcher::SetHeader(const string& header_name,
+ const string& header_value) {
+ string header_line = header_name + ": " + header_value;
+ // Avoid the space if no data on the right side of the semicolon.
+ if (header_value.empty())
+ header_line = header_name + ":";
+ TEST_AND_RETURN(header_line.find('\n') == string::npos);
+ TEST_AND_RETURN(header_name.find(':') == string::npos);
+ extra_headers_[base::ToLowerASCII(header_name)] = header_line;
+}
+
void LibcurlHttpFetcher::CurlPerformOnce() {
CHECK(transfer_in_progress_);
int running_handles = 0;
diff --git a/common/libcurl_http_fetcher.h b/common/libcurl_http_fetcher.h
index 218e6cb..d126171 100644
--- a/common/libcurl_http_fetcher.h
+++ b/common/libcurl_http_fetcher.h
@@ -57,6 +57,10 @@
// cannot be resumed.
void TerminateTransfer() override;
+ // Pass the headers to libcurl.
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override;
+
// Suspend the transfer by calling curl_easy_pause(CURLPAUSE_ALL).
void Pause() override;
@@ -181,6 +185,9 @@
CURL* curl_handle_{nullptr};
struct curl_slist* curl_http_headers_{nullptr};
+ // The extra headers that will be sent on each request.
+ std::map<std::string, std::string> extra_headers_;
+
// Lists of all read(0)/write(1) file descriptors that we're waiting on from
// the message loop. libcurl may open/close descriptors and switch their
// directions so maintain two separate lists so that watch conditions can be
diff --git a/common/mock_http_fetcher.cc b/common/mock_http_fetcher.cc
index f3fa70d..d0348f1 100644
--- a/common/mock_http_fetcher.cc
+++ b/common/mock_http_fetcher.cc
@@ -20,6 +20,7 @@
#include <base/bind.h>
#include <base/logging.h>
+#include <base/strings/string_util.h>
#include <base/time/time.h>
#include <gtest/gtest.h>
@@ -117,6 +118,11 @@
delegate_->TransferTerminated(this);
}
+void MockHttpFetcher::SetHeader(const std::string& header_name,
+ const std::string& header_value) {
+ extra_headers_[base::ToLowerASCII(header_name)] = header_value;
+}
+
void MockHttpFetcher::Pause() {
CHECK(!paused_);
paused_ = true;
diff --git a/common/mock_http_fetcher.h b/common/mock_http_fetcher.h
index 90d34dd..e56318e 100644
--- a/common/mock_http_fetcher.h
+++ b/common/mock_http_fetcher.h
@@ -17,6 +17,7 @@
#ifndef UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
#define UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
+#include <map>
#include <string>
#include <vector>
@@ -87,6 +88,9 @@
// The transfer cannot be resumed.
void TerminateTransfer() override;
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override;
+
// Suspend the mock transfer.
void Pause() override;
@@ -125,6 +129,9 @@
// The number of bytes we've sent so far
size_t sent_size_;
+ // The extra headers set.
+ std::map<std::string, std::string> extra_headers_;
+
// The TaskId of the timeout callback. After each chunk of data sent, we
// time out for 0s just to make sure that run loop services other clients.
brillo::MessageLoop::TaskId timeout_id_;
diff --git a/common/multi_range_http_fetcher.h b/common/multi_range_http_fetcher.h
index 8158a22..8a91ead 100644
--- a/common/multi_range_http_fetcher.h
+++ b/common/multi_range_http_fetcher.h
@@ -80,6 +80,11 @@
// State change: Downloading -> Pending transfer ended
void TerminateTransfer() override;
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override {
+ base_fetcher_->SetHeader(header_name, header_value);
+ }
+
void Pause() override { base_fetcher_->Pause(); }
void Unpause() override { base_fetcher_->Unpause(); }
diff --git a/test_http_server.cc b/test_http_server.cc
index 98e7a6d..2955e79 100644
--- a/test_http_server.cc
+++ b/test_http_server.cc
@@ -72,13 +72,12 @@
};
struct HttpRequest {
- HttpRequest()
- : start_offset(0), end_offset(0), return_code(kHttpResponseOk) {}
+ string raw_headers;
string host;
string url;
- off_t start_offset;
- off_t end_offset; // non-inclusive, zero indicates unspecified.
- HttpResponseCode return_code;
+ off_t start_offset{0};
+ off_t end_offset{0}; // non-inclusive, zero indicates unspecified.
+ HttpResponseCode return_code{kHttpResponseOk};
};
bool ParseRequest(int fd, HttpRequest* request) {
@@ -96,6 +95,7 @@
LOG(INFO) << "got headers:\n--8<------8<------8<------8<----\n"
<< headers
<< "\n--8<------8<------8<------8<----";
+ request->raw_headers = headers;
// Break header into lines.
vector<string> lines;
@@ -452,6 +452,13 @@
}
}
+// Returns a valid response echoing in the body of the response all the headers
+// sent by the client.
+void HandleEchoHeaders(int fd, const HttpRequest& request) {
+ WriteHeaders(fd, 0, request.raw_headers.size(), kHttpResponseOk);
+ WriteString(fd, request.raw_headers);
+}
+
void HandleHang(int fd) {
LOG(INFO) << "Hanging until the other side of the connection is closed.";
char c;
@@ -512,8 +519,8 @@
LOG(INFO) << "pid(" << getpid() << "): handling url " << url;
if (url == "/quitquitquit") {
HandleQuit(fd);
- } else if (base::StartsWith(url, "/download/",
- base::CompareCase::SENSITIVE)) {
+ } else if (base::StartsWith(
+ url, "/download/", base::CompareCase::SENSITIVE)) {
const UrlTerms terms(url, 2);
HandleGet(fd, request, terms.GetSizeT(1));
} else if (base::StartsWith(url, "/flaky/", base::CompareCase::SENSITIVE)) {
@@ -528,6 +535,8 @@
base::CompareCase::SENSITIVE)) {
const UrlTerms terms(url, 3);
HandleErrorIfOffset(fd, request, terms.GetSizeT(1), terms.GetInt(2));
+ } else if (url == "/echo-headers") {
+ HandleEchoHeaders(fd, request);
} else if (url == "/hang") {
HandleHang(fd);
} else {
diff --git a/update_attempter_android.cc b/update_attempter_android.cc
index 6f88ee7..bc9a698 100644
--- a/update_attempter_android.cc
+++ b/update_attempter_android.cc
@@ -171,6 +171,13 @@
BuildUpdateActions();
SetupDownload();
+ // Setup extra headers.
+ HttpFetcher* fetcher = download_action_->http_fetcher();
+ if (!headers[kPayloadPropertyAuthorization].empty())
+ fetcher->SetHeader("Authorization", headers[kPayloadPropertyAuthorization]);
+ if (!headers[kPayloadPropertyUserAgent].empty())
+ fetcher->SetHeader("User-Agent", headers[kPayloadPropertyUserAgent]);
+
cpu_limiter_.StartLimiter();
SetStatusAndNotify(UpdateStatus::UPDATE_AVAILABLE);
ongoing_update_ = true;