[res] Make type iteration more efficient in aapt2

The iterator type that is only used in aapt2 currently
used to perform a binary search for every dereferencing call of
a sparse type table, while it knows for sure the position of the
next element in constant time.

+ Add a test for a sparse type iteration

Test: atest libandroidfw_tests
Flag: EXEMPT bugfix / refactor
Change-Id: I70f74b1a1cfb4deffd1592a87c31501388eaf46c
diff --git a/libs/androidfw/Android.bp b/libs/androidfw/Android.bp
index 1bc15d7..cc4a29b 100644
--- a/libs/androidfw/Android.bp
+++ b/libs/androidfw/Android.bp
@@ -199,6 +199,7 @@
         // This is to suppress warnings/errors from gtest
         "-Wno-unnamed-type-template-args",
     ],
+    require_root: true,
     srcs: [
         // Helpers/infra for testing.
         "tests/CommonHelpers.cpp",
diff --git a/libs/androidfw/TypeWrappers.cpp b/libs/androidfw/TypeWrappers.cpp
index 70d14a1..9704634 100644
--- a/libs/androidfw/TypeWrappers.cpp
+++ b/libs/androidfw/TypeWrappers.cpp
@@ -16,8 +16,6 @@
 
 #include <androidfw/TypeWrappers.h>
 
-#include <algorithm>
-
 namespace android {
 
 TypeVariant::TypeVariant(const ResTable_type* data) : data(data), mLength(dtohl(data->entryCount)) {
@@ -31,30 +29,44 @@
             ALOGE("Type's entry indices extend beyond its boundaries");
             mLength = 0;
         } else {
-          mLength = ResTable_sparseTypeEntry{entryIndices[entryCount - 1]}.idx + 1;
+          mLength = dtohs(ResTable_sparseTypeEntry{entryIndices[entryCount - 1]}.idx) + 1;
         }
     }
 }
 
 TypeVariant::iterator& TypeVariant::iterator::operator++() {
-    mIndex++;
+    ++mIndex;
     if (mIndex > mTypeVariant->mLength) {
         mIndex = mTypeVariant->mLength;
     }
+
+    const ResTable_type* type = mTypeVariant->data;
+    if ((type->flags & ResTable_type::FLAG_SPARSE) == 0) {
+      return *this;
+    }
+
+    // Need to adjust |mSparseIndex| as well if we've passed its current element.
+    const uint32_t entryCount = dtohl(type->entryCount);
+    const auto entryIndices = reinterpret_cast<const uint32_t*>(
+        reinterpret_cast<uintptr_t>(type) + dtohs(type->header.headerSize));
+    if (mSparseIndex >= entryCount) {
+      return *this; // done
+    }
+    const auto element = (const ResTable_sparseTypeEntry*)(entryIndices + mSparseIndex);
+    if (mIndex > dtohs(element->idx)) {
+      ++mSparseIndex;
+    }
+
     return *this;
 }
 
-static bool keyCompare(uint32_t entry, uint16_t index) {
-  return dtohs(ResTable_sparseTypeEntry{entry}.idx) < index;
-}
-
 const ResTable_entry* TypeVariant::iterator::operator*() const {
-    const ResTable_type* type = mTypeVariant->data;
     if (mIndex >= mTypeVariant->mLength) {
-        return NULL;
+        return nullptr;
     }
 
-    const uint32_t entryCount = dtohl(mTypeVariant->data->entryCount);
+    const ResTable_type* type = mTypeVariant->data;
+    const uint32_t entryCount = dtohl(type->entryCount);
     const uintptr_t containerEnd = reinterpret_cast<uintptr_t>(type)
             + dtohl(type->header.size);
     const uint32_t* const entryIndices = reinterpret_cast<const uint32_t*>(
@@ -63,18 +75,19 @@
                                     sizeof(uint16_t) : sizeof(uint32_t);
     if (reinterpret_cast<uintptr_t>(entryIndices) + (indexSize * entryCount) > containerEnd) {
         ALOGE("Type's entry indices extend beyond its boundaries");
-        return NULL;
+        return nullptr;
     }
 
     uint32_t entryOffset;
     if (type->flags & ResTable_type::FLAG_SPARSE) {
-      auto iter = std::lower_bound(entryIndices, entryIndices + entryCount, mIndex, keyCompare);
-      if (iter == entryIndices + entryCount
-              || dtohs(ResTable_sparseTypeEntry{*iter}.idx) != mIndex) {
-        return NULL;
+      if (mSparseIndex >= entryCount) {
+        return nullptr;
       }
-
-      entryOffset = static_cast<uint32_t>(dtohs(ResTable_sparseTypeEntry{*iter}.offset)) * 4u;
+      const auto element = (const ResTable_sparseTypeEntry*)(entryIndices + mSparseIndex);
+      if (dtohs(element->idx) != mIndex) {
+        return nullptr;
+      }
+      entryOffset = static_cast<uint32_t>(dtohs(element->offset)) * 4u;
     } else if (type->flags & ResTable_type::FLAG_OFFSET16) {
       auto entryIndices16 = reinterpret_cast<const uint16_t*>(entryIndices);
       entryOffset = offset_from16(entryIndices16[mIndex]);
@@ -83,25 +96,25 @@
     }
 
     if (entryOffset == ResTable_type::NO_ENTRY) {
-        return NULL;
+        return nullptr;
     }
 
     if ((entryOffset & 0x3) != 0) {
         ALOGE("Index %u points to entry with unaligned offset 0x%08x", mIndex, entryOffset);
-        return NULL;
+        return nullptr;
     }
 
     const ResTable_entry* entry = reinterpret_cast<const ResTable_entry*>(
             reinterpret_cast<uintptr_t>(type) + dtohl(type->entriesStart) + entryOffset);
     if (reinterpret_cast<uintptr_t>(entry) > containerEnd - sizeof(*entry)) {
         ALOGE("Entry offset at index %u points outside the Type's boundaries", mIndex);
-        return NULL;
+        return nullptr;
     } else if (reinterpret_cast<uintptr_t>(entry) + entry->size() > containerEnd) {
         ALOGE("Entry at index %u extends beyond Type's boundaries", mIndex);
-        return NULL;
+        return nullptr;
     } else if (entry->size() < sizeof(*entry)) {
         ALOGE("Entry at index %u is too small (%zu)", mIndex, entry->size());
-        return NULL;
+        return nullptr;
     }
     return entry;
 }
diff --git a/libs/androidfw/include/androidfw/TypeWrappers.h b/libs/androidfw/include/androidfw/TypeWrappers.h
index fb2fad6..db641b7 100644
--- a/libs/androidfw/include/androidfw/TypeWrappers.h
+++ b/libs/androidfw/include/androidfw/TypeWrappers.h
@@ -27,24 +27,14 @@
 
     class iterator {
     public:
-        iterator& operator=(const iterator& rhs) {
-            mTypeVariant = rhs.mTypeVariant;
-            mIndex = rhs.mIndex;
-            return *this;
-        }
-
         bool operator==(const iterator& rhs) const {
             return mTypeVariant == rhs.mTypeVariant && mIndex == rhs.mIndex;
         }
 
-        bool operator!=(const iterator& rhs) const {
-            return mTypeVariant != rhs.mTypeVariant || mIndex != rhs.mIndex;
-        }
-
         iterator operator++(int) {
-            uint32_t prevIndex = mIndex;
+            iterator prev = *this;
             operator++();
-            return iterator(mTypeVariant, prevIndex);
+            return prev;
         }
 
         const ResTable_entry* operator->() const {
@@ -60,18 +50,26 @@
 
     private:
         friend struct TypeVariant;
-        iterator(const TypeVariant* tv, uint32_t index)
-            : mTypeVariant(tv), mIndex(index) {}
+
+        enum class Kind { Begin, End };
+        iterator(const TypeVariant* tv, Kind kind)
+            : mTypeVariant(tv) {
+          mSparseIndex = mIndex = kind == Kind::Begin ? 0 : tv->mLength;
+          // mSparseIndex here is technically past the number of sparse entries, but it is still
+          // ok as it is enough to infer that this is the end iterator.
+        }
+
         const TypeVariant* mTypeVariant;
         uint32_t mIndex;
+        uint32_t mSparseIndex;
     };
 
     iterator beginEntries() const {
-        return iterator(this, 0);
+        return iterator(this, iterator::Kind::Begin);
     }
 
     iterator endEntries() const {
-        return iterator(this, mLength);
+        return iterator(this, iterator::Kind::End);
     }
 
     const ResTable_type* data;
diff --git a/libs/androidfw/tests/TypeWrappers_test.cpp b/libs/androidfw/tests/TypeWrappers_test.cpp
index ed30904..d66e058 100644
--- a/libs/androidfw/tests/TypeWrappers_test.cpp
+++ b/libs/androidfw/tests/TypeWrappers_test.cpp
@@ -14,28 +14,42 @@
  * limitations under the License.
  */
 
-#include <algorithm>
 #include <androidfw/ResourceTypes.h>
 #include <androidfw/TypeWrappers.h>
-#include <utils/String8.h>
+#include <androidfw/Util.h>
+
+#include <optional>
+#include <vector>
 
 #include <gtest/gtest.h>
 
 namespace android {
 
-// create a ResTable_type in memory with a vector of Res_value*
-static ResTable_type* createTypeTable(std::vector<Res_value*>& values,
-                             bool compact_entry = false,
-                             bool short_offsets = false)
+using ResValueVector = std::vector<std::optional<Res_value>>;
+
+// create a ResTable_type in memory
+static util::unique_cptr<ResTable_type> createTypeTable(
+    const ResValueVector& in_values, bool compact_entry, bool short_offsets, bool sparse)
 {
+    ResValueVector sparse_values;
+    if (sparse) {
+      std::ranges::copy_if(in_values, std::back_inserter(sparse_values),
+                           [](auto&& v) { return v.has_value(); });
+    }
+    const ResValueVector& values = sparse ? sparse_values : in_values;
+
     ResTable_type t{};
     t.header.type = RES_TABLE_TYPE_TYPE;
     t.header.headerSize = sizeof(t);
     t.header.size = sizeof(t);
     t.id = 1;
-    t.flags = short_offsets ? ResTable_type::FLAG_OFFSET16 : 0;
+    t.flags = sparse
+                  ? ResTable_type::FLAG_SPARSE
+                  : short_offsets ? ResTable_type::FLAG_OFFSET16 : 0;
 
-    t.header.size += values.size() * (short_offsets ? sizeof(uint16_t) : sizeof(uint32_t));
+    t.header.size += values.size() *
+                     (sparse ? sizeof(ResTable_sparseTypeEntry) :
+                         short_offsets ? sizeof(uint16_t) : sizeof(uint32_t));
     t.entriesStart = t.header.size;
     t.entryCount = values.size();
 
@@ -53,9 +67,18 @@
     memcpy(p_header, &t, sizeof(t));
 
     size_t i = 0, entry_offset = 0;
-    uint32_t k = 0;
-    for (auto const& v : values) {
-        if (short_offsets) {
+    uint32_t sparse_index = 0;
+
+    for (auto const& v : in_values) {
+        if (sparse) {
+            if (!v) {
+                ++i;
+                continue;
+            }
+            const auto p = reinterpret_cast<ResTable_sparseTypeEntry*>(p_offsets) + sparse_index++;
+            p->idx = i;
+            p->offset = (entry_offset >> 2) & 0xffffu;
+        } else if (short_offsets) {
             uint16_t *p = reinterpret_cast<uint16_t *>(p_offsets) + i;
             *p = v ? (entry_offset >> 2) & 0xffffu : 0xffffu;
         } else {
@@ -83,62 +106,92 @@
         }
         i++;
     }
-    return reinterpret_cast<ResTable_type*>(data);
+    return util::unique_cptr<ResTable_type>{reinterpret_cast<ResTable_type*>(data)};
 }
 
 TEST(TypeVariantIteratorTest, shouldIterateOverTypeWithoutErrors) {
-    std::vector<Res_value *> values;
+    ResValueVector values;
 
-    Res_value *v1 = new Res_value{};
-    values.push_back(v1);
-
-    values.push_back(nullptr);
-
-    Res_value *v2 = new Res_value{};
-    values.push_back(v2);
-
-    Res_value *v3 = new Res_value{ sizeof(Res_value), 0, Res_value::TYPE_STRING, 0x12345678};
-    values.push_back(v3);
+    values.push_back(std::nullopt);
+    values.push_back(Res_value{});
+    values.push_back(std::nullopt);
+    values.push_back(Res_value{});
+    values.push_back(Res_value{ sizeof(Res_value), 0, Res_value::TYPE_STRING, 0x12345678});
+    values.push_back(std::nullopt);
+    values.push_back(std::nullopt);
+    values.push_back(std::nullopt);
+    values.push_back(Res_value{ sizeof(Res_value), 0, Res_value::TYPE_STRING, 0x87654321});
 
     // test for combinations of compact_entry and short_offsets
-    for (size_t i = 0; i < 4; i++) {
-        bool compact_entry = i & 0x1, short_offsets = i & 0x2;
-        ResTable_type* data = createTypeTable(values, compact_entry, short_offsets);
-        TypeVariant v(data);
+    for (size_t i = 0; i < 8; i++) {
+        bool compact_entry = i & 0x1, short_offsets = i & 0x2, sparse = i & 0x4;
+        auto data = createTypeTable(values, compact_entry, short_offsets, sparse);
+        TypeVariant v(data.get());
 
         TypeVariant::iterator iter = v.beginEntries();
         ASSERT_EQ(uint32_t(0), iter.index());
-        ASSERT_TRUE(NULL != *iter);
-        ASSERT_EQ(uint32_t(0), iter->key());
+        ASSERT_TRUE(NULL == *iter);
         ASSERT_NE(v.endEntries(), iter);
 
-        iter++;
+        ++iter;
 
         ASSERT_EQ(uint32_t(1), iter.index());
-        ASSERT_TRUE(NULL == *iter);
+        ASSERT_TRUE(NULL != *iter);
+        ASSERT_EQ(uint32_t(1), iter->key());
         ASSERT_NE(v.endEntries(), iter);
 
         iter++;
 
         ASSERT_EQ(uint32_t(2), iter.index());
+        ASSERT_TRUE(NULL == *iter);
+        ASSERT_NE(v.endEntries(), iter);
+
+        ++iter;
+
+        ASSERT_EQ(uint32_t(3), iter.index());
         ASSERT_TRUE(NULL != *iter);
-        ASSERT_EQ(uint32_t(2), iter->key());
+        ASSERT_EQ(uint32_t(3), iter->key());
         ASSERT_NE(v.endEntries(), iter);
 
         iter++;
 
-        ASSERT_EQ(uint32_t(3), iter.index());
+        ASSERT_EQ(uint32_t(4), iter.index());
         ASSERT_TRUE(NULL != *iter);
         ASSERT_EQ(iter->is_compact(), compact_entry);
-        ASSERT_EQ(uint32_t(3), iter->key());
+        ASSERT_EQ(uint32_t(4), iter->key());
         ASSERT_EQ(uint32_t(0x12345678), iter->value().data);
         ASSERT_EQ(Res_value::TYPE_STRING, iter->value().dataType);
 
+        ++iter;
+
+        ASSERT_EQ(uint32_t(5), iter.index());
+        ASSERT_TRUE(NULL == *iter);
+        ASSERT_NE(v.endEntries(), iter);
+
+        ++iter;
+
+        ASSERT_EQ(uint32_t(6), iter.index());
+        ASSERT_TRUE(NULL == *iter);
+        ASSERT_NE(v.endEntries(), iter);
+
+        ++iter;
+
+        ASSERT_EQ(uint32_t(7), iter.index());
+        ASSERT_TRUE(NULL == *iter);
+        ASSERT_NE(v.endEntries(), iter);
+
         iter++;
 
-        ASSERT_EQ(v.endEntries(), iter);
+        ASSERT_EQ(uint32_t(8), iter.index());
+        ASSERT_TRUE(NULL != *iter);
+        ASSERT_EQ(iter->is_compact(), compact_entry);
+        ASSERT_EQ(uint32_t(8), iter->key());
+        ASSERT_EQ(uint32_t(0x87654321), iter->value().data);
+        ASSERT_EQ(Res_value::TYPE_STRING, iter->value().dataType);
 
-        free(data);
+        ++iter;
+
+        ASSERT_EQ(v.endEntries(), iter);
     }
 }