Merge "Transpose the encoding matrix"
diff --git a/rebootescrow/aidl/default/HadamardUtils.cpp b/rebootescrow/aidl/default/HadamardUtils.cpp
index 5853d2d..8ee77e1 100644
--- a/rebootescrow/aidl/default/HadamardUtils.cpp
+++ b/rebootescrow/aidl/default/HadamardUtils.cpp
@@ -16,6 +16,8 @@
 
 #include <HadamardUtils.h>
 
+#include <limits>
+
 #include <android-base/logging.h>
 
 namespace aidl {
@@ -24,99 +26,52 @@
 namespace rebootescrow {
 namespace hadamard {
 
-constexpr auto BYTE_LENGTH = 8u;
+static inline void or_bit(std::vector<uint8_t>* input, size_t bit, uint8_t val) {
+    (*input)[bit / BYTE_LENGTH] |= (val & 1u) << (bit % BYTE_LENGTH);  // bit 0 is a byte's LSB
+}
 
-std::vector<uint8_t> BitsetToBytes(const std::bitset<ENCODE_LENGTH>& encoded_bits) {
-    CHECK_EQ(0, (encoded_bits.size() % BYTE_LENGTH));
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < encoded_bits.size(); i += 8) {
-        uint8_t current = 0;
-        // Set each byte starting from the LSB.
-        for (size_t j = 0; j < BYTE_LENGTH; j++) {
-            CHECK_LE(i + j, encoded_bits.size());
-            if (encoded_bits[i + j]) {
-                current |= (1u << j);
-            }
+static inline uint8_t read_bit(const std::vector<uint8_t>& input, size_t bit) {
+    return (input[bit / BYTE_LENGTH] >> (bit % BYTE_LENGTH)) & 1u;  // bit 0 is a byte's LSB
+}
+
+// Encodes a KEY_SIZE_IN_BYTES-byte key with an error correcting code.
+//
+// The code is an augmented Hadamard code with k=15: it maps a 16-bit input
+// to a 2^15-bit output. The 32-byte key is split into 16 little-endian
+// 16-bit codewords, and each codeword is encoded to 2^15 bits.
+//
+// To better defend against clustered errors, the encoded codewords are
+// striped together: output bit (j * KEY_CODEWORDS + i) is bit j of the
+// encoding of codeword i. Thus if a single 512-byte (2^12-bit) DRAM line is
+// lost, instead of losing 2^12 bits from the encoding of one codeword, we
+// lose 2^8 bits from the encoding of each of the 16 codewords.
+std::vector<uint8_t> EncodeKey(const std::vector<uint8_t>& input) {
+    CHECK_EQ(input.size(), KEY_SIZE_IN_BYTES);
+    static_assert(OUTPUT_SIZE_BYTES == 64 * 1024);
+    std::vector<uint8_t> result(OUTPUT_SIZE_BYTES, 0);
+    for (size_t i = 0; i < KEY_CODEWORDS; i++) {
+        const size_t lo = i * CODEWORD_BYTES;  // index of the codeword's low byte
+        const uint16_t word = input[lo + 1] << BYTE_LENGTH | input[lo];
+        for (size_t j = 0; j < ENCODE_LENGTH; j++) {
+            // Output bit j is the parity of the word bits selected by (j | 2^15).
+            uint16_t parity = word & (ENCODE_LENGTH | j);
+            for (unsigned shift = BYTE_LENGTH; shift != 0; shift >>= 1) {
+                parity ^= parity >> shift;
+            }
+            or_bit(&result, j * KEY_CODEWORDS + i, parity & 1);
         }
-        result.push_back(current);
     }
     return result;
 }
 
-std::bitset<ENCODE_LENGTH> BytesToBitset(const std::vector<uint8_t>& encoded) {
-    CHECK_EQ(ENCODE_LENGTH, encoded.size() * BYTE_LENGTH);
-
-    std::bitset<ENCODE_LENGTH> result;
-    size_t offset = 0;
-    for (const auto& byte : encoded) {
-        // Set each byte starting from the LSB.
-        for (size_t j = 0; j < BYTE_LENGTH; j++) {
-            result[offset + j] = byte & (1u << j);
-        }
-        offset += BYTE_LENGTH;
-    }
-    return result;
-}
-
-// The encoding is equivalent to multiply the word with the generator matrix (and take the module
-// of 2). Here is an example of encoding a number with 3 bits. The encoded length is thus
-// 2^(3-1) = 4 bits.
-//              |1 1 1 1|     |0|
-//  |0 1 1|  *  |0 0 1 1|  =  |1|
-//              |0 1 0 1|     |1|
-//                            |0|
-std::bitset<ENCODE_LENGTH> EncodeWord(uint16_t word) {
-    std::bitset<ENCODE_LENGTH> result;
-    for (uint64_t i = ENCODE_LENGTH; i < 2 * ENCODE_LENGTH; i++) {
-        uint32_t wi = word & i;
-        // Sum all the bits in the word and check its parity.
-        wi ^= wi >> 8u;
-        wi ^= wi >> 4u;
-        wi ^= wi >> 2u;
-        wi ^= wi >> 1u;
-        result[i - ENCODE_LENGTH] = wi & 1u;
-    }
-    return result;
-}
-
-std::vector<uint8_t> EncodeKey(const std::vector<uint8_t>& key) {
-    CHECK_EQ(KEY_SIZE_IN_BYTES, key.size());
-
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < key.size(); i += 2) {
-        uint16_t word = static_cast<uint16_t>(key[i + 1]) << BYTE_LENGTH | key[i];
-        auto encoded_bits = EncodeWord(word);
-        auto byte_array = BitsetToBytes(encoded_bits);
-        std::move(byte_array.begin(), byte_array.end(), std::back_inserter(result));
-    }
-    return result;
-}
-
-std::vector<uint8_t> DecodeKey(const std::vector<uint8_t>& encoded) {
-    CHECK_EQ(0, (encoded.size() * 8) % ENCODE_LENGTH);
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < encoded.size(); i += ENCODE_LENGTH / 8) {
-        auto current =
-                std::vector<uint8_t>{encoded.begin() + i, encoded.begin() + i + ENCODE_LENGTH / 8};
-        auto bits = BytesToBitset(current);
-        auto candidates = DecodeWord(bits);
-        CHECK(!candidates.empty());
-        // TODO(xunchang) Do we want to try other candidates?
-        uint16_t val = candidates.top().second;
-        result.push_back(val & 0xffu);
-        result.push_back(val >> BYTE_LENGTH);
-    }
-
-    return result;
-}
-
-std::priority_queue<std::pair<int32_t, uint16_t>> DecodeWord(
-        const std::bitset<ENCODE_LENGTH>& encoded) {
+// Decodes one codeword and returns the 16-bit candidate with the highest score.
+// Codewords are striped, so this takes the whole input plus the stripe index.
+static uint16_t DecodeWord(size_t word, const std::vector<uint8_t>& encoded) {
     std::vector<int32_t> scores;
     scores.reserve(ENCODE_LENGTH);
-    // Convert 0 -> -1 in the encoded bits. e.g [0, 1, 1, 0] -> [-1, 1, 1, -1]
+    // Convert x -> (-1)^x in the encoded bits. e.g [1, 0, 0, 1] -> [-1, 1, 1, -1]
     for (uint32_t i = 0; i < ENCODE_LENGTH; i++) {
-        scores.push_back(2 * encoded[i] - 1);
+        scores.push_back(1 - 2 * read_bit(encoded, i * KEY_CODEWORDS + word));
     }
 
     // Multiply the hadamard matrix by the transformed input.
@@ -135,19 +90,31 @@
             }
         }
     }
+    // Largest |score| wins: index i if positive, i | (1 << CODE_K) if negative (even for i == 0).
+    auto hiscore = std::numeric_limits<int32_t>::min();
+    uint16_t winner = 0;
+    // TODO(b/146520538): this needs to be constant time
+    for (size_t i = 0; i < ENCODE_LENGTH; i++) {
+        const bool negative = scores[i] < 0;
+        const int32_t magnitude = negative ? -scores[i] : scores[i];
 
-    // Assign the corresponding score to each index; larger score indicates higher probability. e.g.
-    // value 3, encoding [0, 1, 1, 0] -> score: 4
-    // value 7, encoding [1, 0, 0, 1] (3's complement) -> score: -4
-    std::priority_queue<std::pair<int32_t, uint16_t>> candidates;
-    // TODO(xunchang) limit the candidate size since we don't need all of them?
-    for (uint32_t i = 0; i < scores.size(); i++) {
-        candidates.emplace(-scores[i], i);
-        candidates.emplace(scores[i], (1u << CODE_K) | i);
+        if (magnitude > hiscore) {
+            winner = static_cast<uint16_t>(negative ? i | (1u << CODE_K) : i);
+            hiscore = magnitude;
+        }
     }
+    return winner;
+}
 
-    CHECK_EQ(2 * ENCODE_LENGTH, candidates.size());
-    return candidates;
+std::vector<uint8_t> DecodeKey(const std::vector<uint8_t>& encoded) {
+    CHECK_EQ(OUTPUT_SIZE_BYTES, encoded.size());
+    std::vector<uint8_t> key(KEY_SIZE_IN_BYTES, 0);
+    for (size_t word = 0; word < KEY_CODEWORDS; word++) {
+        const uint16_t decoded = DecodeWord(word, encoded);
+        key[word * CODEWORD_BYTES] = decoded & 0xffu;  // low byte first (little-endian)
+        key[word * CODEWORD_BYTES + 1] = decoded >> BYTE_LENGTH;
+    }
+    return key;
 }
 
 }  // namespace hadamard
diff --git a/rebootescrow/aidl/default/HadamardUtils.h b/rebootescrow/aidl/default/HadamardUtils.h
index 21cfc78..85e635f 100644
--- a/rebootescrow/aidl/default/HadamardUtils.h
+++ b/rebootescrow/aidl/default/HadamardUtils.h
@@ -18,9 +18,6 @@
 
 #include <stdint.h>
 
-#include <bitset>
-#include <queue>
-#include <utility>
 #include <vector>
 
 namespace aidl {
@@ -29,18 +26,14 @@
 namespace rebootescrow {
 namespace hadamard {
 
-constexpr uint32_t CODE_K = 15;
+constexpr auto BYTE_LENGTH = 8u;  // bits per byte
+constexpr auto CODEWORD_BYTES = 2u;  // uint16_t
+constexpr auto CODEWORD_BITS = CODEWORD_BYTES * BYTE_LENGTH;  // 16 bits per codeword
+constexpr uint32_t CODE_K = CODEWORD_BITS - 1;  // 15; an encoded codeword has 2^CODE_K bits
 constexpr uint32_t ENCODE_LENGTH = 1u << CODE_K;
-constexpr auto KEY_SIZE_IN_BYTES = 32u;
-
-// Encodes a 2 bytes word with hadamard code. The encoding expands a word of k+1 bits to a 2^k
-// bitset. Returns the encoded bitset.
-std::bitset<ENCODE_LENGTH> EncodeWord(uint16_t word);
-
-// Decodes the input bitset, and returns a sorted list of pair with (score, value). The value with
-// a higher score indicates a greater likehood.
-std::priority_queue<std::pair<int32_t, uint16_t>> DecodeWord(
-        const std::bitset<ENCODE_LENGTH>& encoded);
+constexpr auto KEY_CODEWORDS = 16u;  // number of 16-bit codewords in a key
+constexpr auto KEY_SIZE_IN_BYTES = KEY_CODEWORDS * CODEWORD_BYTES;  // 32 bytes
+constexpr auto OUTPUT_SIZE_BYTES = KEY_CODEWORDS * ENCODE_LENGTH / BYTE_LENGTH;  // 64 KiB
 
 // Encodes a key that has a size of KEY_SIZE_IN_BYTES. Returns a byte array representation of the
 // encoded bitset. So a 32 bytes key will expand to 16*(2^15) bits = 64KiB.
@@ -49,12 +42,6 @@
 // Given a byte array representation of the encoded keys, decodes it and return the result.
 std::vector<uint8_t> DecodeKey(const std::vector<uint8_t>& encoded);
 
-// Converts a bitset of length |ENCODE_LENGTH| to a byte array.
-std::vector<uint8_t> BitsetToBytes(const std::bitset<ENCODE_LENGTH>& encoded_bits);
-
-// Converts a byte array of encoded words back to the bitset.
-std::bitset<ENCODE_LENGTH> BytesToBitset(const std::vector<uint8_t>& encoded);
-
 }  // namespace hadamard
 }  // namespace rebootescrow
 }  // namespace hardware
diff --git a/rebootescrow/aidl/default/HadamardUtilsTest.cpp b/rebootescrow/aidl/default/HadamardUtilsTest.cpp
index e397e76..1c9a2fb 100644
--- a/rebootescrow/aidl/default/HadamardUtilsTest.cpp
+++ b/rebootescrow/aidl/default/HadamardUtilsTest.cpp
@@ -17,110 +17,35 @@
 #include <stdint.h>
 #include <random>
 
-#include <bitset>
-#include <utility>
-#include <vector>
-
 #include <gtest/gtest.h>
 
 #include <HadamardUtils.h>
 
 using namespace aidl::android::hardware::rebootescrow::hadamard;
 
-class HadamardTest : public testing::Test {
-  protected:
-    void SetUp() override {
-        auto ones = std::bitset<ENCODE_LENGTH>{}.set();
-        // Expects 0x4000 to encode as top half as ones, and lower half as zeros. i.e.
-        // [1, 1 .. 1, 0, 0 .. 0]
-        expected_half_size_ = ones << half_size_;
+class HadamardTest : public testing::Test {};  // empty fixture: tests keep no shared state
 
-        // Expects 0x1 to encode as interleaved 1 and 0s  i.e. [1, 0, 1, 0 ..]
-        expected_one_ = ones;
-        for (uint32_t i = ENCODE_LENGTH / 2; i >= 1; i /= 2) {
-            expected_one_ ^= (expected_one_ >> i);
+static void AddError(std::vector<uint8_t>* data) {
+    for (uint8_t& byte : *data) {
+        for (size_t bit = 0; bit < BYTE_LENGTH; bit++) {
+            if (random() % 100 < 47) {  // flip each bit with roughly 47% probability
+                byte ^= (1 << bit);
+            }
         }
     }
-
-    uint16_t half_size_ = ENCODE_LENGTH / 2;
-    std::bitset<ENCODE_LENGTH> expected_one_;
-    std::bitset<ENCODE_LENGTH> expected_half_size_;
-};
-
-static void AddError(std::bitset<ENCODE_LENGTH>* corrupted_bits) {
-    // The hadamard code has a hamming distance of ENCODE_LENGTH/2. So we should always be able to
-    // correct the data if less than a quarter of the encoded bits are corrupted.
-    auto corrupted_max = 0.24f * corrupted_bits->size();
-    auto corrupted_num = 0;
-    for (size_t i = 0; i < corrupted_bits->size() && corrupted_num < corrupted_max; i++) {
-        if (random() % 2 == 0) {
-            (*corrupted_bits)[i] = !(*corrupted_bits)[i];
-            corrupted_num += 1;
-        }
-    }
-}
-
-static void EncodeAndDecodeKeys(const std::vector<uint8_t>& key) {
-    auto encoded = EncodeKey(key);
-    ASSERT_EQ(64 * 1024, encoded.size());
-    auto decoded = DecodeKey(encoded);
-    ASSERT_EQ(key, std::vector<uint8_t>(decoded.begin(), decoded.begin() + key.size()));
-}
-
-TEST_F(HadamardTest, Encode_smoke) {
-    ASSERT_EQ(expected_half_size_, EncodeWord(half_size_));
-    ASSERT_EQ(expected_one_, EncodeWord(1));
-    // Check the complement of 1.
-    ASSERT_EQ(~expected_one_, EncodeWord(1u << CODE_K | 1u));
-}
-
-TEST_F(HadamardTest, Decode_smoke) {
-    auto candidate = DecodeWord(expected_half_size_);
-    auto expected = std::pair<int32_t, uint16_t>{ENCODE_LENGTH, half_size_};
-    ASSERT_EQ(expected, candidate.top());
-
-    candidate = DecodeWord(expected_one_);
-    expected = std::pair<int32_t, uint16_t>{ENCODE_LENGTH, 1};
-    ASSERT_EQ(expected, candidate.top());
 }
 
 TEST_F(HadamardTest, Decode_error_correction) {
     constexpr auto iteration = 10;
     for (int i = 0; i < iteration; i++) {
-        uint16_t word = random() % (ENCODE_LENGTH * 2);
-        auto corrupted_bits = EncodeWord(word);
-        AddError(&corrupted_bits);
-
-        auto candidate = DecodeWord(corrupted_bits);
-        ASSERT_EQ(word, candidate.top().second);
+        std::vector<uint8_t> key;  // fresh random key for every iteration
+        for (size_t j = 0; j < KEY_SIZE_IN_BYTES; j++) {  // size_t: the bound is unsigned
+            key.emplace_back(random() & 0xff);
+        }
+        auto encoded = EncodeKey(key);
+        ASSERT_EQ(OUTPUT_SIZE_BYTES, encoded.size());
+        AddError(&encoded);
+        auto decoded = DecodeKey(encoded);
+        ASSERT_EQ(key, decoded);  // the noisy encoding must still decode to the exact key
     }
 }
-
-TEST_F(HadamardTest, BytesToBitset_smoke) {
-    auto bytes = BitsetToBytes(expected_one_);
-
-    auto read_back = BytesToBitset(bytes);
-    ASSERT_EQ(expected_one_, read_back);
-}
-
-TEST_F(HadamardTest, EncodeAndDecodeKey) {
-    std::vector<uint8_t> KEY_1{
-            0xA5, 0x00, 0xFF, 0x01, 0xA5, 0x5a, 0xAA, 0x55, 0x00, 0xD3, 0x2A,
-            0x8C, 0x2E, 0x83, 0x0E, 0x65, 0x9E, 0x8D, 0xC6, 0xAC, 0x1E, 0x83,
-            0x21, 0xB3, 0x95, 0x02, 0x89, 0x64, 0x64, 0x92, 0x12, 0x1F,
-    };
-    std::vector<uint8_t> KEY_2{
-            0xFF, 0x00, 0x00, 0xAA, 0x5A, 0x19, 0x20, 0x71, 0x9F, 0xFB, 0xDA,
-            0xB6, 0x2D, 0x06, 0xD5, 0x49, 0x7E, 0xEF, 0x63, 0xAC, 0x18, 0xFF,
-            0x5A, 0xA3, 0x40, 0xBB, 0x64, 0xFA, 0x67, 0xC1, 0x10, 0x18,
-    };
-
-    EncodeAndDecodeKeys(KEY_1);
-    EncodeAndDecodeKeys(KEY_2);
-
-    std::vector<uint8_t> key;
-    for (uint8_t i = 0; i < KEY_SIZE_IN_BYTES; i++) {
-        key.push_back(i);
-    };
-    EncodeAndDecodeKeys(key);
-}