Merge "Parse and use extra HTTP headers when downloading the payload." into nyc-dev
diff --git a/common/constants.cc b/common/constants.cc
index fc6df37..f138ce3 100644
--- a/common/constants.cc
+++ b/common/constants.cc
@@ -94,5 +94,7 @@
 const char kPayloadPropertyFileHash[] = "FILE_HASH";
 const char kPayloadPropertyMetadataSize[] = "METADATA_SIZE";
 const char kPayloadPropertyMetadataHash[] = "METADATA_HASH";
+const char kPayloadPropertyAuthorization[] = "AUTHORIZATION";
+const char kPayloadPropertyUserAgent[] = "USER_AGENT";
 
 }  // namespace chromeos_update_engine
diff --git a/common/constants.h b/common/constants.h
index 25d587b..f0d589d 100644
--- a/common/constants.h
+++ b/common/constants.h
@@ -97,6 +97,8 @@
 extern const char kPayloadPropertyFileHash[];
 extern const char kPayloadPropertyMetadataSize[];
 extern const char kPayloadPropertyMetadataHash[];
+extern const char kPayloadPropertyAuthorization[];
+extern const char kPayloadPropertyUserAgent[];
 
 // A download source is any combination of protocol and server (that's of
 // interest to us when looking at UMA metrics) using which we may download
diff --git a/common/http_fetcher.h b/common/http_fetcher.h
index 11e8e9f..d2499eb 100644
--- a/common/http_fetcher.h
+++ b/common/http_fetcher.h
@@ -44,7 +44,7 @@
   // |proxy_resolver| is the resolver that will be consulted for proxy
   // settings. It may be null, in which case direct connections will
   // be used. Does not take ownership of the resolver.
-  HttpFetcher(ProxyResolver* proxy_resolver)
+  explicit HttpFetcher(ProxyResolver* proxy_resolver)
       : post_data_set_(false),
         http_response_code_(0),
         delegate_(nullptr),
@@ -95,6 +95,12 @@
   // TransferTerminated() will be called when the transfer is actually done.
   virtual void TerminateTransfer() = 0;
 
+  // Add or update a custom header to be sent with every request. If the same
+  // |header_name| is passed twice, the second |header_value| overrides the
+  // previous one; header names are matched case-insensitively.
+  virtual void SetHeader(const std::string& header_name,
+                         const std::string& header_value) = 0;
+
   // If data is coming in too quickly, you can call Pause() to pause the
   // transfer. The delegate will not have ReceivedBytes() called while
   // an HttpFetcher is paused.
diff --git a/common/http_fetcher_unittest.cc b/common/http_fetcher_unittest.cc
index 5450958..0d4b5da 100644
--- a/common/http_fetcher_unittest.cc
+++ b/common/http_fetcher_unittest.cc
@@ -370,12 +370,12 @@
 namespace {
 class HttpFetcherTestDelegate : public HttpFetcherDelegate {
  public:
-  HttpFetcherTestDelegate() :
-      is_expect_error_(false), times_transfer_complete_called_(0),
-      times_transfer_terminated_called_(0), times_received_bytes_called_(0) {}
+  HttpFetcherTestDelegate() = default;
 
   void ReceivedBytes(HttpFetcher* /* fetcher */,
-                     const void* /* bytes */, size_t /* length */) override {
+                     const void* bytes,
+                     size_t length) override {
+    data.append(reinterpret_cast<const char*>(bytes), length);
     // Update counters
     times_received_bytes_called_++;
   }
@@ -397,12 +397,15 @@
   }
 
   // Are we expecting an error response? (default: no)
-  bool is_expect_error_;
+  bool is_expect_error_{false};
 
   // Counters for callback invocations.
-  int times_transfer_complete_called_;
-  int times_transfer_terminated_called_;
-  int times_received_bytes_called_;
+  int times_transfer_complete_called_{0};
+  int times_transfer_terminated_called_{0};
+  int times_received_bytes_called_{0};
+
+  // The received data bytes.
+  string data;
 };
 
 
@@ -473,6 +476,41 @@
   CHECK_EQ(delegate.times_transfer_terminated_called_, 0);
 }
 
+TYPED_TEST(HttpFetcherTest, ExtraHeadersInRequestTest) {
+  if (this->test_.IsMock())
+    return;
+
+  HttpFetcherTestDelegate delegate;
+  unique_ptr<HttpFetcher> fetcher(this->test_.NewSmallFetcher());
+  fetcher->set_delegate(&delegate);
+  fetcher->SetHeader("User-Agent", "MyTest");
+  fetcher->SetHeader("user-agent", "Override that header");
+  fetcher->SetHeader("Authorization", "Basic user:passwd");
+
+  // Invalid headers: SetHeader() must reject these so they are never sent.
+  fetcher->SetHeader("X-Foo", "Invalid\nHeader\nIgnored");
+  fetcher->SetHeader("X-Bar: ", "I do not know how to parse");
+
+  // An empty value suppresses the Accept header otherwise sent by default.
+  fetcher->SetHeader("Accept", "");
+
+  PythonHttpServer server;
+  int port = server.GetPort();
+  ASSERT_TRUE(server.started_);
+
+  StartTransfer(fetcher.get(), LocalServerUrlForPath(port, "/echo-headers"));
+  this->loop_.Run();
+
+  EXPECT_NE(string::npos,
+            delegate.data.find("user-agent: Override that header\r\n"));
+  EXPECT_NE(string::npos,
+            delegate.data.find("Authorization: Basic user:passwd\r\n"));
+  EXPECT_EQ(string::npos, delegate.data.find("MyTest"));
+  EXPECT_EQ(string::npos, delegate.data.find("\nAccept:"));
+  EXPECT_EQ(string::npos, delegate.data.find("X-Foo: Invalid"));
+  EXPECT_EQ(string::npos, delegate.data.find("X-Bar: I do not"));
+}
+
 namespace {
 class PausingHttpFetcherTestDelegate : public HttpFetcherDelegate {
  public:
diff --git a/common/libcurl_http_fetcher.cc b/common/libcurl_http_fetcher.cc
index 789f46e..b2dffc2 100644
--- a/common/libcurl_http_fetcher.cc
+++ b/common/libcurl_http_fetcher.cc
@@ -129,24 +129,32 @@
     CHECK_EQ(curl_easy_setopt(curl_handle_, CURLOPT_POSTFIELDSIZE,
                               post_data_.size()),
              CURLE_OK);
+  }
 
+  // Setup extra HTTP headers.
+  if (curl_http_headers_) {
+    curl_slist_free_all(curl_http_headers_);
+    curl_http_headers_ = nullptr;
+  }
+  for (const auto& header : extra_headers_) {
+    // curl_slist_append() copies the string.
+    curl_http_headers_ =
+        curl_slist_append(curl_http_headers_, header.second.c_str());
+  }
+  if (post_data_set_) {
     // Set the Content-Type HTTP header, if one was specifically set.
-    CHECK(!curl_http_headers_);
     if (post_content_type_ != kHttpContentTypeUnspecified) {
-      const string content_type_attr =
-        base::StringPrintf("Content-Type: %s",
-                           GetHttpContentTypeString(post_content_type_));
-      curl_http_headers_ = curl_slist_append(nullptr,
-                                             content_type_attr.c_str());
-      CHECK(curl_http_headers_);
-      CHECK_EQ(
-          curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER,
-                           curl_http_headers_),
-          CURLE_OK);
+      const string content_type_attr = base::StringPrintf(
+          "Content-Type: %s", GetHttpContentTypeString(post_content_type_));
+      curl_http_headers_ =
+          curl_slist_append(curl_http_headers_, content_type_attr.c_str());
     } else {
       LOG(WARNING) << "no content type set, using libcurl default";
     }
   }
+  CHECK_EQ(
+      curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, curl_http_headers_),
+      CURLE_OK);
 
   if (bytes_downloaded_ > 0 || download_length_) {
     // Resume from where we left off.
@@ -318,6 +326,17 @@
   }
 }
 
+void LibcurlHttpFetcher::SetHeader(const string& header_name,
+                                   const string& header_value) {
+  // An empty value yields "Name:" (no trailing space after the colon).
+  string header_line = header_value.empty() ? header_name + ":"
+                                            : header_name + ": " + header_value;
+  // libcurl terminates each header with CRLF itself; reject CR/LF injection.
+  TEST_AND_RETURN(header_line.find_first_of("\r\n") == string::npos);
+  TEST_AND_RETURN(header_name.find(':') == string::npos);
+  extra_headers_[base::ToLowerASCII(header_name)] = header_line;
+}
+
 void LibcurlHttpFetcher::CurlPerformOnce() {
   CHECK(transfer_in_progress_);
   int running_handles = 0;
diff --git a/common/libcurl_http_fetcher.h b/common/libcurl_http_fetcher.h
index 218e6cb..d126171 100644
--- a/common/libcurl_http_fetcher.h
+++ b/common/libcurl_http_fetcher.h
@@ -57,6 +57,10 @@
   // cannot be resumed.
   void TerminateTransfer() override;
 
+  // Pass the headers to libcurl.
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override;
+
   // Suspend the transfer by calling curl_easy_pause(CURLPAUSE_ALL).
   void Pause() override;
 
@@ -181,6 +185,9 @@
   CURL* curl_handle_{nullptr};
   struct curl_slist* curl_http_headers_{nullptr};
 
+  // The extra headers that will be sent on each request.
+  std::map<std::string, std::string> extra_headers_;
+
   // Lists of all read(0)/write(1) file descriptors that we're waiting on from
   // the message loop. libcurl may open/close descriptors and switch their
   // directions so maintain two separate lists so that watch conditions can be
diff --git a/common/mock_http_fetcher.cc b/common/mock_http_fetcher.cc
index f3fa70d..d0348f1 100644
--- a/common/mock_http_fetcher.cc
+++ b/common/mock_http_fetcher.cc
@@ -20,6 +20,7 @@
 
 #include <base/bind.h>
 #include <base/logging.h>
+#include <base/strings/string_util.h>
 #include <base/time/time.h>
 #include <gtest/gtest.h>
 
@@ -117,6 +118,11 @@
   delegate_->TransferTerminated(this);
 }
 
+void MockHttpFetcher::SetHeader(const std::string& header_name,
+                                const std::string& header_value) {
+  extra_headers_[base::ToLowerASCII(header_name)].assign(header_value);
+}
+
 void MockHttpFetcher::Pause() {
   CHECK(!paused_);
   paused_ = true;
diff --git a/common/mock_http_fetcher.h b/common/mock_http_fetcher.h
index 90d34dd..e56318e 100644
--- a/common/mock_http_fetcher.h
+++ b/common/mock_http_fetcher.h
@@ -17,6 +17,7 @@
 #ifndef UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
 #define UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
 
+#include <map>
 #include <string>
 #include <vector>
 
@@ -87,6 +88,9 @@
   // The transfer cannot be resumed.
   void TerminateTransfer() override;
 
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override;
+
   // Suspend the mock transfer.
   void Pause() override;
 
@@ -125,6 +129,9 @@
   // The number of bytes we've sent so far
   size_t sent_size_;
 
+  // The extra headers set.
+  std::map<std::string, std::string> extra_headers_;
+
   // The TaskId of the timeout callback. After each chunk of data sent, we
   // time out for 0s just to make sure that run loop services other clients.
   brillo::MessageLoop::TaskId timeout_id_;
diff --git a/common/multi_range_http_fetcher.h b/common/multi_range_http_fetcher.h
index 8158a22..8a91ead 100644
--- a/common/multi_range_http_fetcher.h
+++ b/common/multi_range_http_fetcher.h
@@ -80,6 +80,11 @@
   // State change: Downloading -> Pending transfer ended
   void TerminateTransfer() override;
 
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override {
+    base_fetcher_->SetHeader(header_name, header_value);  // Pass through.
+  }
+
   void Pause() override { base_fetcher_->Pause(); }
 
   void Unpause() override { base_fetcher_->Unpause(); }
diff --git a/test_http_server.cc b/test_http_server.cc
index 98e7a6d..2955e79 100644
--- a/test_http_server.cc
+++ b/test_http_server.cc
@@ -72,13 +72,12 @@
 };
 
 struct HttpRequest {
-  HttpRequest()
-      : start_offset(0), end_offset(0), return_code(kHttpResponseOk) {}
+  string raw_headers;  // Unparsed headers, as stored by ParseRequest().
   string host;
   string url;
-  off_t start_offset;
-  off_t end_offset;  // non-inclusive, zero indicates unspecified.
-  HttpResponseCode return_code;
+  off_t start_offset{0};
+  off_t end_offset{0};  // non-inclusive, zero indicates unspecified.
+  HttpResponseCode return_code{kHttpResponseOk};
 };
 
 bool ParseRequest(int fd, HttpRequest* request) {
@@ -96,6 +95,7 @@
   LOG(INFO) << "got headers:\n--8<------8<------8<------8<----\n"
             << headers
             << "\n--8<------8<------8<------8<----";
+  request->raw_headers = headers;
 
   // Break header into lines.
   vector<string> lines;
@@ -452,6 +452,13 @@
   }
 }
 
+// Responds with kHttpResponseOk and a body holding the client's raw request
+// headers, so tests can check exactly which headers were sent.
+void HandleEchoHeaders(int fd, const HttpRequest& request) {
+  WriteHeaders(fd, 0, request.raw_headers.size(), kHttpResponseOk);
+  WriteString(fd, request.raw_headers);
+}
+
 void HandleHang(int fd) {
   LOG(INFO) << "Hanging until the other side of the connection is closed.";
   char c;
@@ -512,8 +519,8 @@
   LOG(INFO) << "pid(" << getpid() <<  "): handling url " << url;
   if (url == "/quitquitquit") {
     HandleQuit(fd);
-  } else if (base::StartsWith(url, "/download/", 
-                              base::CompareCase::SENSITIVE)) {
+  } else if (base::StartsWith(
+                 url, "/download/", base::CompareCase::SENSITIVE)) {
     const UrlTerms terms(url, 2);
     HandleGet(fd, request, terms.GetSizeT(1));
   } else if (base::StartsWith(url, "/flaky/", base::CompareCase::SENSITIVE)) {
@@ -528,6 +535,8 @@
                               base::CompareCase::SENSITIVE)) {
     const UrlTerms terms(url, 3);
     HandleErrorIfOffset(fd, request, terms.GetSizeT(1), terms.GetInt(2));
+  } else if (url == "/echo-headers") {
+    HandleEchoHeaders(fd, request);
   } else if (url == "/hang") {
     HandleHang(fd);
   } else {
diff --git a/update_attempter_android.cc b/update_attempter_android.cc
index 6f88ee7..bc9a698 100644
--- a/update_attempter_android.cc
+++ b/update_attempter_android.cc
@@ -171,6 +171,13 @@
 
   BuildUpdateActions();
   SetupDownload();
+  // Setup extra headers.
+  HttpFetcher* fetcher = download_action_->http_fetcher();
+  if (!headers[kPayloadPropertyAuthorization].empty())
+    fetcher->SetHeader("Authorization", headers[kPayloadPropertyAuthorization]);
+  if (!headers[kPayloadPropertyUserAgent].empty())
+    fetcher->SetHeader("User-Agent", headers[kPayloadPropertyUserAgent]);
+
   cpu_limiter_.StartLimiter();
   SetStatusAndNotify(UpdateStatus::UPDATE_AVAILABLE);
   ongoing_update_ = true;