Merge "Parse and use extra HTTP headers when downloading the payload." into nyc-dev
diff --git a/common/constants.cc b/common/constants.cc
index fc6df37..f138ce3 100644
--- a/common/constants.cc
+++ b/common/constants.cc
@@ -94,5 +94,7 @@
const char kPayloadPropertyFileHash[] = "FILE_HASH";
const char kPayloadPropertyMetadataSize[] = "METADATA_SIZE";
const char kPayloadPropertyMetadataHash[] = "METADATA_HASH";
+const char kPayloadPropertyAuthorization[] = "AUTHORIZATION";
+const char kPayloadPropertyUserAgent[] = "USER_AGENT";
} // namespace chromeos_update_engine
diff --git a/common/constants.h b/common/constants.h
index 25d587b..f0d589d 100644
--- a/common/constants.h
+++ b/common/constants.h
@@ -97,6 +97,8 @@
extern const char kPayloadPropertyFileHash[];
extern const char kPayloadPropertyMetadataSize[];
extern const char kPayloadPropertyMetadataHash[];
+extern const char kPayloadPropertyAuthorization[];
+extern const char kPayloadPropertyUserAgent[];
// A download source is any combination of protocol and server (that's of
// interest to us when looking at UMA metrics) using which we may download
diff --git a/common/http_fetcher.h b/common/http_fetcher.h
index 11e8e9f..d2499eb 100644
--- a/common/http_fetcher.h
+++ b/common/http_fetcher.h
@@ -44,7 +44,7 @@
// |proxy_resolver| is the resolver that will be consulted for proxy
// settings. It may be null, in which case direct connections will
// be used. Does not take ownership of the resolver.
- HttpFetcher(ProxyResolver* proxy_resolver)
+ explicit HttpFetcher(ProxyResolver* proxy_resolver)
: post_data_set_(false),
http_response_code_(0),
delegate_(nullptr),
@@ -95,6 +95,12 @@
// TransferTerminated() will be called when the transfer is actually done.
virtual void TerminateTransfer() = 0;
+ // Add or update a custom header to be sent with every request. If the same
+ // |header_name| is passed twice, the second |header_value| overrides the
+ // previous value. Header names are matched case-insensitively, so
+ // "User-Agent" and "user-agent" refer to the same header. An empty
+ // |header_value| asks the fetcher to suppress that header (e.g. to hide one
+ // the underlying library adds by default). Implementations may silently
+ // ignore headers that cannot be sent safely, such as names containing ':' or
+ // names/values containing line breaks.
+ virtual void SetHeader(const std::string& header_name,
+ const std::string& header_value) = 0;
+
// If data is coming in too quickly, you can call Pause() to pause the
// transfer. The delegate will not have ReceivedBytes() called while
// an HttpFetcher is paused.
diff --git a/common/http_fetcher_unittest.cc b/common/http_fetcher_unittest.cc
index 5450958..0d4b5da 100644
--- a/common/http_fetcher_unittest.cc
+++ b/common/http_fetcher_unittest.cc
@@ -370,12 +370,12 @@
namespace {
class HttpFetcherTestDelegate : public HttpFetcherDelegate {
public:
- HttpFetcherTestDelegate() :
- is_expect_error_(false), times_transfer_complete_called_(0),
- times_transfer_terminated_called_(0), times_received_bytes_called_(0) {}
+ HttpFetcherTestDelegate() = default;
void ReceivedBytes(HttpFetcher* /* fetcher */,
- const void* /* bytes */, size_t /* length */) override {
+ const void* bytes,
+ size_t length) override {
+ data.append(reinterpret_cast<const char*>(bytes), length);
// Update counters
times_received_bytes_called_++;
}
@@ -397,12 +397,15 @@
}
// Are we expecting an error response? (default: no)
- bool is_expect_error_;
+ bool is_expect_error_{false};
// Counters for callback invocations.
- int times_transfer_complete_called_;
- int times_transfer_terminated_called_;
- int times_received_bytes_called_;
+ int times_transfer_complete_called_{0};
+ int times_transfer_terminated_called_{0};
+ int times_received_bytes_called_{0};
+
+ // The received data bytes.
+ string data;
};
@@ -473,6 +476,41 @@
CHECK_EQ(delegate.times_transfer_terminated_called_, 0);
}
+// Checks that headers set through SetHeader() reach the server, that a later
+// value replaces an earlier one whose name differs only in case, and that
+// malformed or emptied headers are not sent. The /echo-headers handler returns
+// the raw request headers as the response body.
+TYPED_TEST(HttpFetcherTest, ExtraHeadersInRequestTest) {
+ if (this->test_.IsMock())
+ return;
+
+ HttpFetcherTestDelegate delegate;
+ unique_ptr<HttpFetcher> fetcher(this->test_.NewSmallFetcher());
+ fetcher->set_delegate(&delegate);
+ fetcher->SetHeader("User-Agent", "MyTest");
+ fetcher->SetHeader("user-agent", "Override that header");
+ fetcher->SetHeader("Authorization", "Basic user:passwd");
+
+ // Invalid headers.
+ fetcher->SetHeader("X-Foo", "Invalid\nHeader\nIgnored");
+ fetcher->SetHeader("X-Bar: ", "I do not know how to parse");
+
+ // Hide Accept header normally added by default.
+ fetcher->SetHeader("Accept", "");
+
+ PythonHttpServer server;
+ int port = server.GetPort();
+ ASSERT_TRUE(server.started_);
+
+ StartTransfer(fetcher.get(), LocalServerUrlForPath(port, "/echo-headers"));
+ this->loop_.Run();
+
+ EXPECT_NE(string::npos,
+ delegate.data.find("user-agent: Override that header\r\n"));
+ EXPECT_NE(string::npos,
+ delegate.data.find("Authorization: Basic user:passwd\r\n"));
+ // The first User-Agent value must have been replaced, not sent alongside
+ // the override.
+ EXPECT_EQ(string::npos, delegate.data.find("MyTest"));
+
+ EXPECT_EQ(string::npos, delegate.data.find("\nAccept:"));
+ EXPECT_EQ(string::npos, delegate.data.find("X-Foo: Invalid"));
+ EXPECT_EQ(string::npos, delegate.data.find("X-Bar: I do not"));
+}
+
namespace {
class PausingHttpFetcherTestDelegate : public HttpFetcherDelegate {
public:
diff --git a/common/libcurl_http_fetcher.cc b/common/libcurl_http_fetcher.cc
index 789f46e..b2dffc2 100644
--- a/common/libcurl_http_fetcher.cc
+++ b/common/libcurl_http_fetcher.cc
@@ -129,24 +129,32 @@
CHECK_EQ(curl_easy_setopt(curl_handle_, CURLOPT_POSTFIELDSIZE,
post_data_.size()),
CURLE_OK);
+ }
+ // Setup extra HTTP headers. Rebuild the list from scratch on every resume.
+ if (curl_http_headers_) {
+ curl_slist_free_all(curl_http_headers_);
+ curl_http_headers_ = nullptr;
+ }
+ for (const auto& header : extra_headers_) {
+ // curl_slist_append() copies the string. It returns nullptr on failure;
+ // assigning that back would leak the list built so far and silently drop
+ // every header, so fail loudly instead.
+ curl_http_headers_ =
+ curl_slist_append(curl_http_headers_, header.second.c_str());
+ CHECK(curl_http_headers_);
+ }
+ if (post_data_set_) {
// Set the Content-Type HTTP header, if one was specifically set.
- CHECK(!curl_http_headers_);
if (post_content_type_ != kHttpContentTypeUnspecified) {
- const string content_type_attr =
- base::StringPrintf("Content-Type: %s",
- GetHttpContentTypeString(post_content_type_));
- curl_http_headers_ = curl_slist_append(nullptr,
- content_type_attr.c_str());
- CHECK(curl_http_headers_);
- CHECK_EQ(
- curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER,
- curl_http_headers_),
- CURLE_OK);
+ const string content_type_attr = base::StringPrintf(
+ "Content-Type: %s", GetHttpContentTypeString(post_content_type_));
+ curl_http_headers_ =
+ curl_slist_append(curl_http_headers_, content_type_attr.c_str());
+ CHECK(curl_http_headers_);
} else {
LOG(WARNING) << "no content type set, using libcurl default";
}
}
+ // Always set, even when the list is nullptr: any previous list was freed
+ // above, and a null list tells libcurl to send no custom headers.
+ CHECK_EQ(
+ curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, curl_http_headers_),
+ CURLE_OK);
if (bytes_downloaded_ > 0 || download_length_) {
// Resume from where we left off.
@@ -318,6 +326,17 @@
}
}
+// Stores the line "Name: value" for |header_name|, replacing any previously
+// set header whose name differs only in case. Headers that would corrupt the
+// request are dropped: an empty name, a ':' in the name, or a CR/LF anywhere
+// in the name or value.
+void LibcurlHttpFetcher::SetHeader(const string& header_name,
+ const string& header_value) {
+ // A ':' in the name would make the server parse a different header name.
+ TEST_AND_RETURN(!header_name.empty());
+ TEST_AND_RETURN(header_name.find(':') == string::npos);
+ string header_line = header_name + ": " + header_value;
+ // Avoid the space if no data on the right side of the colon; libcurl reads a
+ // bare "Name:" as "do not send this header" (disabling its own default).
+ if (header_value.empty())
+ header_line = header_name + ":";
+ // Either CR or LF could split the line into extra headers (header injection).
+ TEST_AND_RETURN(header_line.find_first_of("\r\n") == string::npos);
+ extra_headers_[base::ToLowerASCII(header_name)] = header_line;
+}
+
void LibcurlHttpFetcher::CurlPerformOnce() {
CHECK(transfer_in_progress_);
int running_handles = 0;
diff --git a/common/libcurl_http_fetcher.h b/common/libcurl_http_fetcher.h
index 218e6cb..d126171 100644
--- a/common/libcurl_http_fetcher.h
+++ b/common/libcurl_http_fetcher.h
@@ -57,6 +57,10 @@
// cannot be resumed.
void TerminateTransfer() override;
+ // Pass the headers to libcurl.
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override;
+
// Suspend the transfer by calling curl_easy_pause(CURLPAUSE_ALL).
void Pause() override;
@@ -181,6 +185,9 @@
CURL* curl_handle_{nullptr};
struct curl_slist* curl_http_headers_{nullptr};
+ // The extra headers that will be sent on each request.
+ std::map<std::string, std::string> extra_headers_;
+
// Lists of all read(0)/write(1) file descriptors that we're waiting on from
// the message loop. libcurl may open/close descriptors and switch their
// directions so maintain two separate lists so that watch conditions can be
diff --git a/common/mock_http_fetcher.cc b/common/mock_http_fetcher.cc
index f3fa70d..d0348f1 100644
--- a/common/mock_http_fetcher.cc
+++ b/common/mock_http_fetcher.cc
@@ -20,6 +20,7 @@
#include <base/bind.h>
#include <base/logging.h>
+#include <base/strings/string_util.h>
#include <base/time/time.h>
#include <gtest/gtest.h>
@@ -117,6 +118,11 @@
delegate_->TransferTerminated(this);
}
+// Records |header_value| under the lowercased |header_name|, so names that
+// differ only in case replace one another. No validation of the header is
+// performed here.
+void MockHttpFetcher::SetHeader(const std::string& header_name,
+ const std::string& header_value) {
+ const std::string lowered_name = base::ToLowerASCII(header_name);
+ extra_headers_[lowered_name] = header_value;
+}
+
void MockHttpFetcher::Pause() {
CHECK(!paused_);
paused_ = true;
diff --git a/common/mock_http_fetcher.h b/common/mock_http_fetcher.h
index 90d34dd..e56318e 100644
--- a/common/mock_http_fetcher.h
+++ b/common/mock_http_fetcher.h
@@ -17,6 +17,7 @@
#ifndef UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
#define UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
+#include <map>
#include <string>
#include <vector>
@@ -87,6 +88,9 @@
// The transfer cannot be resumed.
void TerminateTransfer() override;
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override;
+
// Suspend the mock transfer.
void Pause() override;
@@ -125,6 +129,9 @@
// The number of bytes we've sent so far
size_t sent_size_;
+ // The extra headers set.
+ std::map<std::string, std::string> extra_headers_;
+
// The TaskId of the timeout callback. After each chunk of data sent, we
// time out for 0s just to make sure that run loop services other clients.
brillo::MessageLoop::TaskId timeout_id_;
diff --git a/common/multi_range_http_fetcher.h b/common/multi_range_http_fetcher.h
index 8158a22..8a91ead 100644
--- a/common/multi_range_http_fetcher.h
+++ b/common/multi_range_http_fetcher.h
@@ -80,6 +80,11 @@
// State change: Downloading -> Pending transfer ended
void TerminateTransfer() override;
+ // Forwarded unchanged to |base_fetcher_|, like Pause()/Unpause() below.
+ void SetHeader(const std::string& header_name,
+ const std::string& header_value) override {
+ base_fetcher_->SetHeader(header_name, header_value);
+ }
+
void Pause() override { base_fetcher_->Pause(); }
void Unpause() override { base_fetcher_->Unpause(); }
diff --git a/test_http_server.cc b/test_http_server.cc
index 98e7a6d..2955e79 100644
--- a/test_http_server.cc
+++ b/test_http_server.cc
@@ -72,13 +72,12 @@
};
struct HttpRequest {
- HttpRequest()
- : start_offset(0), end_offset(0), return_code(kHttpResponseOk) {}
+ // Unparsed request header text as captured by ParseRequest(); echoed back to
+ // the client by HandleEchoHeaders().
+ string raw_headers;
string host;
string url;
- off_t start_offset;
- off_t end_offset; // non-inclusive, zero indicates unspecified.
- HttpResponseCode return_code;
+ off_t start_offset{0};
+ off_t end_offset{0}; // non-inclusive, zero indicates unspecified.
+ HttpResponseCode return_code{kHttpResponseOk};
};
bool ParseRequest(int fd, HttpRequest* request) {
@@ -96,6 +95,7 @@
LOG(INFO) << "got headers:\n--8<------8<------8<------8<----\n"
<< headers
<< "\n--8<------8<------8<------8<----";
+ request->raw_headers = headers;
// Break header into lines.
vector<string> lines;
@@ -452,6 +452,13 @@
}
}
+// Handler for "/echo-headers": replies with kHttpResponseOk and a body that is
+// the client's raw request headers, so tests can inspect exactly which headers
+// were sent.
+void HandleEchoHeaders(int fd, const HttpRequest& request) {
+ const string& echoed_body = request.raw_headers;
+ WriteHeaders(fd, 0, echoed_body.size(), kHttpResponseOk);
+ WriteString(fd, echoed_body);
+}
+
void HandleHang(int fd) {
LOG(INFO) << "Hanging until the other side of the connection is closed.";
char c;
@@ -512,8 +519,8 @@
LOG(INFO) << "pid(" << getpid() << "): handling url " << url;
if (url == "/quitquitquit") {
HandleQuit(fd);
- } else if (base::StartsWith(url, "/download/",
- base::CompareCase::SENSITIVE)) {
+ } else if (base::StartsWith(
+ url, "/download/", base::CompareCase::SENSITIVE)) {
const UrlTerms terms(url, 2);
HandleGet(fd, request, terms.GetSizeT(1));
} else if (base::StartsWith(url, "/flaky/", base::CompareCase::SENSITIVE)) {
@@ -528,6 +535,8 @@
base::CompareCase::SENSITIVE)) {
const UrlTerms terms(url, 3);
HandleErrorIfOffset(fd, request, terms.GetSizeT(1), terms.GetInt(2));
+ } else if (url == "/echo-headers") {
+ HandleEchoHeaders(fd, request);
} else if (url == "/hang") {
HandleHang(fd);
} else {
diff --git a/update_attempter_android.cc b/update_attempter_android.cc
index 6f88ee7..bc9a698 100644
--- a/update_attempter_android.cc
+++ b/update_attempter_android.cc
@@ -171,6 +171,13 @@
BuildUpdateActions();
SetupDownload();
+ // Forward the optional AUTHORIZATION and USER_AGENT payload properties to
+ // the download fetcher as HTTP headers; missing or empty ones are skipped.
+ HttpFetcher* fetcher = download_action_->http_fetcher();
+ const auto& auth_header_value = headers[kPayloadPropertyAuthorization];
+ if (!auth_header_value.empty())
+ fetcher->SetHeader("Authorization", auth_header_value);
+ const auto& user_agent_value = headers[kPayloadPropertyUserAgent];
+ if (!user_agent_value.empty())
+ fetcher->SetHeader("User-Agent", user_agent_value);
+
cpu_limiter_.StartLimiter();
SetStatusAndNotify(UpdateStatus::UPDATE_AVAILABLE);
ongoing_update_ = true;