Parse and use extra HTTP headers when downloading the payload.

The Android OTA backend requires passing an Authorization HTTP header in
order to download some payloads. This patch allows specifying such a
header when initiating a payload download from Android.

Bug: 27047110
TEST=Added unittests to check the headers sent.

(cherry picked from commit fdd6dec9c4be2fbd667cf874c4cc6f4ffecaeef9)

Change-Id: I59d38d79a7b7a8975d105c611c692522b6c33707
diff --git a/common/constants.cc b/common/constants.cc
index b15c3f4..3b7aa6e 100644
--- a/common/constants.cc
+++ b/common/constants.cc
@@ -91,5 +91,7 @@
 const char kPayloadPropertyFileHash[] = "FILE_HASH";
 const char kPayloadPropertyMetadataSize[] = "METADATA_SIZE";
 const char kPayloadPropertyMetadataHash[] = "METADATA_HASH";
+const char kPayloadPropertyAuthorization[] = "AUTHORIZATION";
+const char kPayloadPropertyUserAgent[] = "USER_AGENT";
 
 }  // namespace chromeos_update_engine
diff --git a/common/constants.h b/common/constants.h
index 62f61ce..d001329 100644
--- a/common/constants.h
+++ b/common/constants.h
@@ -93,6 +93,8 @@
 extern const char kPayloadPropertyFileHash[];
 extern const char kPayloadPropertyMetadataSize[];
 extern const char kPayloadPropertyMetadataHash[];
+extern const char kPayloadPropertyAuthorization[];
+extern const char kPayloadPropertyUserAgent[];
 
 // A download source is any combination of protocol and server (that's of
 // interest to us when looking at UMA metrics) using which we may download
diff --git a/common/http_fetcher.h b/common/http_fetcher.h
index 11e8e9f..d2499eb 100644
--- a/common/http_fetcher.h
+++ b/common/http_fetcher.h
@@ -44,7 +44,7 @@
   // |proxy_resolver| is the resolver that will be consulted for proxy
   // settings. It may be null, in which case direct connections will
   // be used. Does not take ownership of the resolver.
-  HttpFetcher(ProxyResolver* proxy_resolver)
+  explicit HttpFetcher(ProxyResolver* proxy_resolver)
       : post_data_set_(false),
         http_response_code_(0),
         delegate_(nullptr),
@@ -95,6 +95,12 @@
   // TransferTerminated() will be called when the transfer is actually done.
   virtual void TerminateTransfer() = 0;
 
+  // Add or update a custom header to be sent with every request. Header names
+  // are matched case-insensitively: if the same |header_name| is passed twice,
+  // the second |header_value| overrides the previous value.
+  virtual void SetHeader(const std::string& header_name,
+                         const std::string& header_value) = 0;
+
   // If data is coming in too quickly, you can call Pause() to pause the
   // transfer. The delegate will not have ReceivedBytes() called while
   // an HttpFetcher is paused.
diff --git a/common/http_fetcher_unittest.cc b/common/http_fetcher_unittest.cc
index aaa538c..bd723d7 100644
--- a/common/http_fetcher_unittest.cc
+++ b/common/http_fetcher_unittest.cc
@@ -377,12 +377,12 @@
 namespace {
 class HttpFetcherTestDelegate : public HttpFetcherDelegate {
  public:
-  HttpFetcherTestDelegate() :
-      is_expect_error_(false), times_transfer_complete_called_(0),
-      times_transfer_terminated_called_(0), times_received_bytes_called_(0) {}
+  HttpFetcherTestDelegate() = default;
 
   void ReceivedBytes(HttpFetcher* /* fetcher */,
-                     const void* /* bytes */, size_t /* length */) override {
+                     const void* bytes,
+                     size_t length) override {
+    data.append(reinterpret_cast<const char*>(bytes), length);
     // Update counters
     times_received_bytes_called_++;
   }
@@ -404,12 +404,15 @@
   }
 
   // Are we expecting an error response? (default: no)
-  bool is_expect_error_;
+  bool is_expect_error_{false};
 
   // Counters for callback invocations.
-  int times_transfer_complete_called_;
-  int times_transfer_terminated_called_;
-  int times_received_bytes_called_;
+  int times_transfer_complete_called_{0};
+  int times_transfer_terminated_called_{0};
+  int times_received_bytes_called_{0};
+
+  // The received data bytes.
+  string data;
 };
 
 
@@ -480,6 +483,41 @@
   CHECK_EQ(delegate.times_transfer_terminated_called_, 0);
 }
 
+TYPED_TEST(HttpFetcherTest, ExtraHeadersInRequestTest) {
+  if (this->test_.IsMock())
+    return;
+
+  HttpFetcherTestDelegate delegate;
+  unique_ptr<HttpFetcher> fetcher(this->test_.NewSmallFetcher());
+  fetcher->set_delegate(&delegate);
+  fetcher->SetHeader("User-Agent", "MyTest");
+  fetcher->SetHeader("user-agent", "Override that header");
+  fetcher->SetHeader("Authorization", "Basic user:passwd");
+
+  // Invalid headers.
+  fetcher->SetHeader("X-Foo", "Invalid\nHeader\nIgnored");
+  fetcher->SetHeader("X-Bar: ", "I do not know how to parse");
+
+  // Hide Accept header normally added by default.
+  fetcher->SetHeader("Accept", "");
+
+  PythonHttpServer server;
+  int port = server.GetPort();
+  ASSERT_TRUE(server.started_);
+
+  StartTransfer(fetcher.get(), LocalServerUrlForPath(port, "/echo-headers"));
+  this->loop_.Run();
+
+  EXPECT_NE(string::npos,
+            delegate.data.find("user-agent: Override that header\r\n"));
+  EXPECT_NE(string::npos,
+            delegate.data.find("Authorization: Basic user:passwd\r\n"));
+
+  EXPECT_EQ(string::npos, delegate.data.find("\nAccept:"));
+  EXPECT_EQ(string::npos, delegate.data.find("X-Foo: Invalid"));
+  EXPECT_EQ(string::npos, delegate.data.find("X-Bar: I do not"));
+}
+
 namespace {
 class PausingHttpFetcherTestDelegate : public HttpFetcherDelegate {
  public:
diff --git a/common/libcurl_http_fetcher.cc b/common/libcurl_http_fetcher.cc
index 761b74e..725bdd4 100644
--- a/common/libcurl_http_fetcher.cc
+++ b/common/libcurl_http_fetcher.cc
@@ -128,24 +128,32 @@
     CHECK_EQ(curl_easy_setopt(curl_handle_, CURLOPT_POSTFIELDSIZE,
                               post_data_.size()),
              CURLE_OK);
+  }
 
+  // Setup extra HTTP headers.
+  if (curl_http_headers_) {
+    curl_slist_free_all(curl_http_headers_);
+    curl_http_headers_ = nullptr;
+  }
+  for (const auto& header : extra_headers_) {
+    // curl_slist_append() copies the string.
+    curl_http_headers_ =
+        curl_slist_append(curl_http_headers_, header.second.c_str());
+  }
+  if (post_data_set_) {
     // Set the Content-Type HTTP header, if one was specifically set.
-    CHECK(!curl_http_headers_);
     if (post_content_type_ != kHttpContentTypeUnspecified) {
-      const string content_type_attr =
-        base::StringPrintf("Content-Type: %s",
-                           GetHttpContentTypeString(post_content_type_));
-      curl_http_headers_ = curl_slist_append(nullptr,
-                                             content_type_attr.c_str());
-      CHECK(curl_http_headers_);
-      CHECK_EQ(
-          curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER,
-                           curl_http_headers_),
-          CURLE_OK);
+      const string content_type_attr = base::StringPrintf(
+          "Content-Type: %s", GetHttpContentTypeString(post_content_type_));
+      curl_http_headers_ =
+          curl_slist_append(curl_http_headers_, content_type_attr.c_str());
     } else {
       LOG(WARNING) << "no content type set, using libcurl default";
     }
   }
+  CHECK_EQ(
+      curl_easy_setopt(curl_handle_, CURLOPT_HTTPHEADER, curl_http_headers_),
+      CURLE_OK);
 
   if (bytes_downloaded_ > 0 || download_length_) {
     // Resume from where we left off.
@@ -311,6 +319,17 @@
   }
 }
 
+void LibcurlHttpFetcher::SetHeader(const string& header_name,
+                                   const string& header_value) {
+  string header_line = header_name + ": " + header_value;
+  // Avoid the space if no data on the right side of the colon.
+  if (header_value.empty())
+    header_line = header_name + ":";
+  TEST_AND_RETURN(header_line.find('\n') == string::npos);
+  TEST_AND_RETURN(header_name.find(':') == string::npos);
+  extra_headers_[base::ToLowerASCII(header_name)] = header_line;
+}
+
 void LibcurlHttpFetcher::CurlPerformOnce() {
   CHECK(transfer_in_progress_);
   int running_handles = 0;
diff --git a/common/libcurl_http_fetcher.h b/common/libcurl_http_fetcher.h
index 66dbb18..5a64236 100644
--- a/common/libcurl_http_fetcher.h
+++ b/common/libcurl_http_fetcher.h
@@ -57,6 +57,10 @@
   // cannot be resumed.
   void TerminateTransfer() override;
 
+  // Pass the headers to libcurl.
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override;
+
   // Suspend the transfer by calling curl_easy_pause(CURLPAUSE_ALL).
   void Pause() override;
 
@@ -181,6 +185,9 @@
   CURL* curl_handle_{nullptr};
   struct curl_slist* curl_http_headers_{nullptr};
 
+  // Extra headers sent on each request: lowercased name -> full header line.
+  std::map<std::string, std::string> extra_headers_;
+
   // Lists of all read(0)/write(1) file descriptors that we're waiting on from
   // the message loop. libcurl may open/close descriptors and switch their
   // directions so maintain two separate lists so that watch conditions can be
diff --git a/common/mock_http_fetcher.cc b/common/mock_http_fetcher.cc
index f3fa70d..d0348f1 100644
--- a/common/mock_http_fetcher.cc
+++ b/common/mock_http_fetcher.cc
@@ -20,6 +20,7 @@
 
 #include <base/bind.h>
 #include <base/logging.h>
+#include <base/strings/string_util.h>
 #include <base/time/time.h>
 #include <gtest/gtest.h>
 
@@ -117,6 +118,11 @@
   delegate_->TransferTerminated(this);
 }
 
+void MockHttpFetcher::SetHeader(const std::string& header_name,
+                                const std::string& header_value) {
+  extra_headers_[base::ToLowerASCII(header_name)] = header_value;
+}
+
 void MockHttpFetcher::Pause() {
   CHECK(!paused_);
   paused_ = true;
diff --git a/common/mock_http_fetcher.h b/common/mock_http_fetcher.h
index 90d34dd..e56318e 100644
--- a/common/mock_http_fetcher.h
+++ b/common/mock_http_fetcher.h
@@ -17,6 +17,7 @@
 #ifndef UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
 #define UPDATE_ENGINE_COMMON_MOCK_HTTP_FETCHER_H_
 
+#include <map>
 #include <string>
 #include <vector>
 
@@ -87,6 +88,9 @@
   // The transfer cannot be resumed.
   void TerminateTransfer() override;
 
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override;
+
   // Suspend the mock transfer.
   void Pause() override;
 
@@ -125,6 +129,9 @@
   // The number of bytes we've sent so far
   size_t sent_size_;
 
+  // The extra headers set via SetHeader(), keyed by lowercased header name.
+  std::map<std::string, std::string> extra_headers_;
+
   // The TaskId of the timeout callback. After each chunk of data sent, we
   // time out for 0s just to make sure that run loop services other clients.
   brillo::MessageLoop::TaskId timeout_id_;
diff --git a/common/multi_range_http_fetcher.h b/common/multi_range_http_fetcher.h
index 8158a22..8a91ead 100644
--- a/common/multi_range_http_fetcher.h
+++ b/common/multi_range_http_fetcher.h
@@ -80,6 +80,11 @@
   // State change: Downloading -> Pending transfer ended
   void TerminateTransfer() override;
 
+  void SetHeader(const std::string& header_name,
+                 const std::string& header_value) override {
+    base_fetcher_->SetHeader(header_name, header_value);
+  }
+
   void Pause() override { base_fetcher_->Pause(); }
 
   void Unpause() override { base_fetcher_->Unpause(); }