Transpose the encoding matrix

Stripe together the encodings from each of the 16 codewords, so that
if a 512-byte DRAM line is knocked out, it affects 256 bits from each
codeword rather than 4096 bits from a single encoded codeword.

Rather than using std::bitset, we directly set and read bits in
the std::vector<uint8_t>, because the striping means that copying it
will now cost not 4k in allocation but 64k.

Decode directly to a word, without using list decoding. It seems
we don't need list decoding for the error rates that matter here,
and we never completed the implementation of it anyway.

Declare and test only the full interface, now that it doesn't decompose
quite so neatly.

Bug: 63928581
Test: atest HadamardTest
Change-Id: If022d3f4a8d6fccdf68119d4666f83ce5005bccb
diff --git a/rebootescrow/aidl/default/HadamardUtils.cpp b/rebootescrow/aidl/default/HadamardUtils.cpp
index 5853d2d..8ee77e1 100644
--- a/rebootescrow/aidl/default/HadamardUtils.cpp
+++ b/rebootescrow/aidl/default/HadamardUtils.cpp
@@ -16,6 +16,8 @@
 
 #include <HadamardUtils.h>
 
+#include <limits>
+
 #include <android-base/logging.h>
 
 namespace aidl {
@@ -24,99 +26,52 @@
 namespace rebootescrow {
 namespace hadamard {
 
-constexpr auto BYTE_LENGTH = 8u;
+static inline void or_bit(std::vector<uint8_t>* input, size_t bit, uint8_t val) {
+    (*input)[bit >> 3] |= (val & 1u) << (bit & 7);
+}
 
-std::vector<uint8_t> BitsetToBytes(const std::bitset<ENCODE_LENGTH>& encoded_bits) {
-    CHECK_EQ(0, (encoded_bits.size() % BYTE_LENGTH));
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < encoded_bits.size(); i += 8) {
-        uint8_t current = 0;
-        // Set each byte starting from the LSB.
-        for (size_t j = 0; j < BYTE_LENGTH; j++) {
-            CHECK_LE(i + j, encoded_bits.size());
-            if (encoded_bits[i + j]) {
-                current |= (1u << j);
-            }
+static inline uint8_t read_bit(const std::vector<uint8_t>& input, size_t bit) {
+    return (input[bit >> 3] >> (bit & 7)) & 1u;
+}
+
+// Apply an error correcting encoding.
+//
+// The error correcting code used is an augmented Hadamard code with
+// k=15, so it takes a 16-bit input and produces a 2^15-bit output.
+// We break the 32-byte key into 16 16-bit codewords and encode
+// each codeword to a 2^15-bit output.
+//
+// To better defend against clustered errors, we stripe together the encoded
+// codewords. Thus if a single 512-byte DRAM line is lost, instead of losing
+// 2^12 bits from the encoding of a single code word, we lose 2^8 bits
+// from the encoding of each of the 16 codewords.
+std::vector<uint8_t> EncodeKey(const std::vector<uint8_t>& input) {
+    CHECK_EQ(input.size(), KEY_SIZE_IN_BYTES);
+    std::vector<uint8_t> result(OUTPUT_SIZE_BYTES, 0);
+    static_assert(OUTPUT_SIZE_BYTES == 64 * 1024);
+    for (size_t i = 0; i < KEY_CODEWORDS; i++) {
+        uint16_t word = input[i * 2 + 1] << 8 | input[i * 2];
+        for (size_t j = 0; j < ENCODE_LENGTH; j++) {
+            uint16_t wi = word & (j + ENCODE_LENGTH);
+            // Sum all the bits in the word and check its parity.
+            wi ^= wi >> 8u;
+            wi ^= wi >> 4u;
+            wi ^= wi >> 2u;
+            wi ^= wi >> 1u;
+            or_bit(&result, (j * KEY_CODEWORDS) + i, wi & 1);
         }
-        result.push_back(current);
     }
     return result;
 }
 
-std::bitset<ENCODE_LENGTH> BytesToBitset(const std::vector<uint8_t>& encoded) {
-    CHECK_EQ(ENCODE_LENGTH, encoded.size() * BYTE_LENGTH);
-
-    std::bitset<ENCODE_LENGTH> result;
-    size_t offset = 0;
-    for (const auto& byte : encoded) {
-        // Set each byte starting from the LSB.
-        for (size_t j = 0; j < BYTE_LENGTH; j++) {
-            result[offset + j] = byte & (1u << j);
-        }
-        offset += BYTE_LENGTH;
-    }
-    return result;
-}
-
-// The encoding is equivalent to multiply the word with the generator matrix (and take the module
-// of 2). Here is an example of encoding a number with 3 bits. The encoded length is thus
-// 2^(3-1) = 4 bits.
-//              |1 1 1 1|     |0|
-//  |0 1 1|  *  |0 0 1 1|  =  |1|
-//              |0 1 0 1|     |1|
-//                            |0|
-std::bitset<ENCODE_LENGTH> EncodeWord(uint16_t word) {
-    std::bitset<ENCODE_LENGTH> result;
-    for (uint64_t i = ENCODE_LENGTH; i < 2 * ENCODE_LENGTH; i++) {
-        uint32_t wi = word & i;
-        // Sum all the bits in the word and check its parity.
-        wi ^= wi >> 8u;
-        wi ^= wi >> 4u;
-        wi ^= wi >> 2u;
-        wi ^= wi >> 1u;
-        result[i - ENCODE_LENGTH] = wi & 1u;
-    }
-    return result;
-}
-
-std::vector<uint8_t> EncodeKey(const std::vector<uint8_t>& key) {
-    CHECK_EQ(KEY_SIZE_IN_BYTES, key.size());
-
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < key.size(); i += 2) {
-        uint16_t word = static_cast<uint16_t>(key[i + 1]) << BYTE_LENGTH | key[i];
-        auto encoded_bits = EncodeWord(word);
-        auto byte_array = BitsetToBytes(encoded_bits);
-        std::move(byte_array.begin(), byte_array.end(), std::back_inserter(result));
-    }
-    return result;
-}
-
-std::vector<uint8_t> DecodeKey(const std::vector<uint8_t>& encoded) {
-    CHECK_EQ(0, (encoded.size() * 8) % ENCODE_LENGTH);
-    std::vector<uint8_t> result;
-    for (size_t i = 0; i < encoded.size(); i += ENCODE_LENGTH / 8) {
-        auto current =
-                std::vector<uint8_t>{encoded.begin() + i, encoded.begin() + i + ENCODE_LENGTH / 8};
-        auto bits = BytesToBitset(current);
-        auto candidates = DecodeWord(bits);
-        CHECK(!candidates.empty());
-        // TODO(xunchang) Do we want to try other candidates?
-        uint16_t val = candidates.top().second;
-        result.push_back(val & 0xffu);
-        result.push_back(val >> BYTE_LENGTH);
-    }
-
-    return result;
-}
-
-std::priority_queue<std::pair<int32_t, uint16_t>> DecodeWord(
-        const std::bitset<ENCODE_LENGTH>& encoded) {
+// Decode a single codeword. Because of the way codewords are striped together
+// this takes the entire input, plus an offset telling it which word to decode.
+static uint16_t DecodeWord(size_t word, const std::vector<uint8_t>& encoded) {
     std::vector<int32_t> scores;
     scores.reserve(ENCODE_LENGTH);
-    // Convert 0 -> -1 in the encoded bits. e.g [0, 1, 1, 0] -> [-1, 1, 1, -1]
+    // Convert x -> (-1)^x in the encoded bits. e.g. [1, 0, 0, 1] -> [-1, 1, 1, -1]
     for (uint32_t i = 0; i < ENCODE_LENGTH; i++) {
-        scores.push_back(2 * encoded[i] - 1);
+        scores.push_back(1 - 2 * read_bit(encoded, i * KEY_CODEWORDS + word));
     }
 
     // Multiply the hadamard matrix by the transformed input.
@@ -135,19 +90,31 @@
             }
         }
     }
+    auto hiscore = std::numeric_limits<int32_t>::min();
+    uint16_t winner;
+    // TODO(b/146520538): this needs to be constant time
+    for (size_t i = 0; i < ENCODE_LENGTH; i++) {
+        if (scores[i] > hiscore) {
+            winner = i;
+            hiscore = scores[i];
 
-    // Assign the corresponding score to each index; larger score indicates higher probability. e.g.
-    // value 3, encoding [0, 1, 1, 0] -> score: 4
-    // value 7, encoding [1, 0, 0, 1] (3's complement) -> score: -4
-    std::priority_queue<std::pair<int32_t, uint16_t>> candidates;
-    // TODO(xunchang) limit the candidate size since we don't need all of them?
-    for (uint32_t i = 0; i < scores.size(); i++) {
-        candidates.emplace(-scores[i], i);
-        candidates.emplace(scores[i], (1u << CODE_K) | i);
+        } else if (-scores[i] > hiscore) {
+            winner = i | (1 << CODE_K);
+            hiscore = -scores[i];
+        }
     }
+    return winner;
+}
 
-    CHECK_EQ(2 * ENCODE_LENGTH, candidates.size());
-    return candidates;
+std::vector<uint8_t> DecodeKey(const std::vector<uint8_t>& encoded) {
+    CHECK_EQ(OUTPUT_SIZE_BYTES, encoded.size());
+    std::vector<uint8_t> result(KEY_SIZE_IN_BYTES, 0);
+    for (size_t i = 0; i < KEY_CODEWORDS; i++) {
+        uint16_t val = DecodeWord(i, encoded);
+        result[i * CODEWORD_BYTES] = val & 0xffu;
+        result[i * CODEWORD_BYTES + 1] = val >> 8u;
+    }
+    return result;
 }
 
 }  // namespace hadamard