create_snapshot: Derive source hash from device

Verity hash extracted from device merkel tree is stored
in a file. Use that to derive the source hash

Bug: 371304627
Test: create_snapshot --source=system.pb --target=$OUT/system.img --merkel_tree
Change-Id: I670318496054d5aa9f81efa5954e4fdd28bc3e19
Signed-off-by: Akilesh Kailash <akailash@google.com>
diff --git a/fs_mgr/libsnapshot/Android.bp b/fs_mgr/libsnapshot/Android.bp
index 850e248..b5b976a 100644
--- a/fs_mgr/libsnapshot/Android.bp
+++ b/fs_mgr/libsnapshot/Android.bp
@@ -492,7 +492,10 @@
     host_supported: true,
     device_supported: false,
 
-    srcs: ["libsnapshot_cow/create_cow.cpp"],
+    srcs: [
+        "libsnapshot_cow/create_cow.cpp",
+        "android/snapshot/snapshot.proto",
+    ],
 
     cflags: [
         "-Wall",
@@ -502,14 +505,21 @@
     static_libs: [
         "liblog",
         "libbase",
+        "libfstab",
         "libext4_utils",
         "libsnapshot_cow",
         "libcrypto",
         "libbrotli",
         "libz",
+        "libdm",
         "liblz4",
         "libzstd",
         "libgflags",
+        "libavb",
+        "libext2_uuid",
+        "libfs_avb",
+        "libcrypto",
+        "libprotobuf-cpp-lite",
     ],
     shared_libs: [
     ],
diff --git a/fs_mgr/libsnapshot/libsnapshot_cow/create_cow.cpp b/fs_mgr/libsnapshot/libsnapshot_cow/create_cow.cpp
index 5497b72..fd4e7da 100644
--- a/fs_mgr/libsnapshot/libsnapshot_cow/create_cow.cpp
+++ b/fs_mgr/libsnapshot/libsnapshot_cow/create_cow.cpp
@@ -8,6 +8,7 @@
 
 #include <condition_variable>
 #include <cstring>
+#include <fstream>
 #include <future>
 #include <iostream>
 #include <limits>
@@ -17,26 +18,27 @@
 #include <unordered_map>
 #include <vector>
 
-#include <android-base/file.h>
-#include <android-base/logging.h>
-#include <android-base/stringprintf.h>
-#include <android-base/unique_fd.h>
-#include <ext4_utils/ext4_utils.h>
-#include <storage_literals/storage_literals.h>
-
 #include <android-base/chrono_utils.h>
+#include <android-base/file.h>
+#include <android-base/hex.h>
+#include <android-base/logging.h>
 #include <android-base/scopeguard.h>
+#include <android-base/stringprintf.h>
 #include <android-base/strings.h>
-
+#include <android-base/unique_fd.h>
+#include <android/snapshot/snapshot.pb.h>
+#include <ext4_utils/ext4_utils.h>
+#include <fs_avb/fs_avb_util.h>
 #include <gflags/gflags.h>
 #include <libsnapshot/cow_writer.h>
-
 #include <openssl/sha.h>
+#include <storage_literals/storage_literals.h>
 
 DEFINE_string(source, "", "Source partition image");
 DEFINE_string(target, "", "Target partition image");
 DEFINE_string(compression, "lz4",
               "Compression algorithm. Default is set to lz4. Available options: lz4, zstd, gz");
+DEFINE_bool(merkel_tree, false, "If true, source image hash is obtained from verity merkel tree");
 
 namespace android {
 namespace snapshot {
@@ -51,7 +53,8 @@
 class CreateSnapshot {
   public:
     CreateSnapshot(const std::string& src_file, const std::string& target_file,
-                   const std::string& patch_file, const std::string& compression);
+                   const std::string& patch_file, const std::string& compression,
+                   const bool& merkel_tree);
     bool CreateSnapshotPatch();
 
   private:
@@ -108,6 +111,14 @@
     bool WriteOrderedSnapshots();
     bool WriteNonOrderedSnapshots();
     bool VerifyMergeOrder();
+
+    bool CalculateDigest(const void* buffer, size_t size, const void* salt, uint32_t salt_length,
+                         uint8_t* digest);
+    bool ParseSourceMerkelTree();
+
+    bool use_merkel_tree_ = false;
+    std::vector<uint8_t> target_salt_;
+    std::vector<uint8_t> source_salt_;
 };
 
 void CreateSnapshotLogger(android::base::LogId, android::base::LogSeverity severity, const char*,
@@ -120,8 +131,12 @@
 }
 
 CreateSnapshot::CreateSnapshot(const std::string& src_file, const std::string& target_file,
-                               const std::string& patch_file, const std::string& compression)
-    : src_file_(src_file), target_file_(target_file), patch_file_(patch_file) {
+                               const std::string& patch_file, const std::string& compression,
+                               const bool& merkel_tree)
+    : src_file_(src_file),
+      target_file_(target_file),
+      patch_file_(patch_file),
+      use_merkel_tree_(merkel_tree) {
     if (!compression.empty()) {
         compression_ = compression;
     }
@@ -156,7 +171,76 @@
     if (!PrepareParse(src_file_, false)) {
         return false;
     }
-    return ParsePartition();
+
+    if (use_merkel_tree_) {
+        return ParseSourceMerkelTree();
+    } else {
+        return ParsePartition();
+    }
+}
+
+bool CreateSnapshot::CalculateDigest(const void* buffer, size_t size, const void* salt,
+                                     uint32_t salt_length, uint8_t* digest) {
+    SHA256_CTX ctx;
+    if (SHA256_Init(&ctx) != 1) {
+        return false;
+    }
+    if (SHA256_Update(&ctx, salt, salt_length) != 1) {
+        return false;
+    }
+    if (SHA256_Update(&ctx, buffer, size) != 1) {
+        return false;
+    }
+    if (SHA256_Final(digest, &ctx) != 1) {
+        return false;
+    }
+    return true;
+}
+
+bool CreateSnapshot::ParseSourceMerkelTree() {
+    std::string fname = android::base::Basename(target_file_.c_str());
+    std::string partitionName = fname.substr(0, fname.find(".img"));
+
+    auto vbmeta = android::fs_mgr::LoadAndVerifyVbmetaByPath(
+            target_file_, partitionName, "", true, false, false, nullptr, nullptr, nullptr);
+    if (vbmeta == nullptr) {
+        LOG(ERROR) << "LoadAndVerifyVbmetaByPath failed for partition: " << partitionName;
+        return false;
+    }
+    auto descriptor = android::fs_mgr::GetHashtreeDescriptor(partitionName, std::move(*vbmeta));
+    if (descriptor == nullptr) {
+        LOG(ERROR) << "GetHashtreeDescriptor failed for partition: " << partitionName;
+        return false;
+    }
+
+    std::fstream input(src_file_, std::ios::in | std::ios::binary);
+    VerityHash hash;
+    if (!hash.ParseFromIstream(&input)) {
+        LOG(ERROR) << "Failed to parse message.";
+        return false;
+    }
+
+    std::string source_salt = hash.salt();
+    source_salt.erase(std::remove(source_salt.begin(), source_salt.end(), '\0'), source_salt.end());
+    if (!android::base::HexToBytes(source_salt, &source_salt_)) {
+        LOG(ERROR) << "HexToBytes conversion failed for source salt: " << source_salt;
+        return false;
+    }
+
+    std::string target_salt = descriptor->salt;
+    if (!android::base::HexToBytes(target_salt, &target_salt_)) {
+        LOG(ERROR) << "HexToBytes conversion failed for target salt: " << target_salt;
+        return false;
+    }
+
+    std::vector<uint8_t> digest(32, 0);
+    for (int i = 0; i < hash.block_hash_size(); i++) {
+        CalculateDigest(hash.block_hash(i).data(), hash.block_hash(i).size(), target_salt_.data(),
+                        target_salt_.size(), digest.data());
+        source_block_hash_[ToHexString(digest.data(), 32)] = i;
+    }
+
+    return true;
 }
 
 /*
@@ -386,10 +470,21 @@
         while (num_blocks) {
             const void* bufptr = (char*)buffer.get() + buffer_offset;
             uint64_t blkindex = foffset / BLOCK_SZ;
+            std::string hash;
 
-            uint8_t checksum[32];
-            SHA256(bufptr, BLOCK_SZ, checksum);
-            std::string hash = ToHexString(checksum, sizeof(checksum));
+            if (create_snapshot_patch_ && use_merkel_tree_) {
+                std::vector<uint8_t> digest(32, 0);
+                CalculateDigest(bufptr, BLOCK_SZ, target_salt_.data(), target_salt_.size(),
+                                digest.data());
+                std::vector<uint8_t> final_digest(32, 0);
+                CalculateDigest(digest.data(), digest.size(), source_salt_.data(),
+                                source_salt_.size(), final_digest.data());
+                hash = ToHexString(final_digest.data(), final_digest.size());
+            } else {
+                uint8_t checksum[32];
+                SHA256(bufptr, BLOCK_SZ, checksum);
+                hash = ToHexString(checksum, sizeof(checksum));
+            }
 
             if (create_snapshot_patch_) {
                 PrepareMergeBlock(bufptr, blkindex, hash);
@@ -497,7 +592,7 @@
     auto parts = android::base::Split(fname, ".");
     std::string snapshotfile = parts[0] + ".patch";
     android::snapshot::CreateSnapshot snapshot(FLAGS_source, FLAGS_target, snapshotfile,
-                                               FLAGS_compression);
+                                               FLAGS_compression, FLAGS_merkel_tree);
 
     if (!snapshot.CreateSnapshotPatch()) {
         LOG(ERROR) << "Snapshot creation failed";