FTL: Parametrize SmallMap key equality

Bug: 185536303
Test: ftl_test
Change-Id: I8cc8287c1dfcaf60350039323492de1ddb2101a9
diff --git a/include/ftl/initializer_list.h b/include/ftl/initializer_list.h
index 769c09f..2102c25 100644
--- a/include/ftl/initializer_list.h
+++ b/include/ftl/initializer_list.h
@@ -16,6 +16,7 @@
 
 #pragma once
 
+#include <functional>
 #include <tuple>
 #include <utility>
 
@@ -65,18 +66,18 @@
   std::tuple<Types...> tuple;
 };
 
-template <typename K, typename V>
+template <typename K, typename V, typename KeyEqual = std::equal_to<K>>
 struct KeyValue {};
 
 // Shorthand for key-value pairs that assigns the first argument to the key, and the rest to the
 // value. The specialization is on KeyValue rather than std::pair, so that ftl::init::list works
 // with the latter.
-template <typename K, typename V, std::size_t... Sizes, typename... Types>
-struct InitializerList<KeyValue<K, V>, std::index_sequence<Sizes...>, Types...> {
+template <typename K, typename V, typename E, std::size_t... Sizes, typename... Types>
+struct InitializerList<KeyValue<K, V, E>, std::index_sequence<Sizes...>, Types...> {
   // Accumulate the three arguments to std::pair's piecewise constructor.
   template <typename... Args>
   [[nodiscard]] constexpr auto operator()(K&& k, Args&&... args) && -> InitializerList<
-      KeyValue<K, V>, std::index_sequence<Sizes..., 3>, Types..., std::piecewise_construct_t,
+      KeyValue<K, V, E>, std::index_sequence<Sizes..., 3>, Types..., std::piecewise_construct_t,
       std::tuple<K&&>, std::tuple<Args&&...>> {
     return {std::tuple_cat(
         std::move(tuple),
@@ -94,9 +95,9 @@
   return InitializerList<T>{}(std::forward<Args>(args)...);
 }
 
-template <typename K, typename V, typename... Args>
+template <typename K, typename V, typename E = std::equal_to<K>, typename... Args>
 [[nodiscard]] constexpr auto map(Args&&... args) {
-  return list<KeyValue<K, V>>(std::forward<Args>(args)...);
+  return list<KeyValue<K, V, E>>(std::forward<Args>(args)...);
 }
 
 template <typename K, typename V>
diff --git a/include/ftl/small_map.h b/include/ftl/small_map.h
index 988431a..2effaa4 100644
--- a/include/ftl/small_map.h
+++ b/include/ftl/small_map.h
@@ -61,7 +61,7 @@
 //
 //   assert(map == SmallMap(ftl::init::map(-1, "xyz")(0, "nil")(42, "???")(123, "abc")));
 //
-template <typename K, typename V, std::size_t N>
+template <typename K, typename V, std::size_t N, typename KeyEqual = std::equal_to<K>>
 class SmallMap final {
   using Map = SmallVector<std::pair<const K, V>, N>;
 
@@ -153,7 +153,7 @@
   auto get(const key_type& key, F f) const
       -> std::conditional_t<std::is_void_v<R>, bool, std::optional<R>> {
     for (auto& [k, v] : *this) {
-      if (k == key) {
+      if (KeyEqual{}(k, key)) {
         if constexpr (std::is_void_v<R>) {
           f(v);
           return true;
@@ -245,7 +245,8 @@
 
  private:
   iterator find(const key_type& key, iterator first) {
-    return std::find_if(first, end(), [&key](const auto& pair) { return pair.first == key; });
+    return std::find_if(first, end(),
+                        [&key](const auto& pair) { return KeyEqual{}(pair.first, key); });
   }
 
   bool erase(const key_type& key, iterator first) {
@@ -267,13 +268,13 @@
 };
 
 // Deduction guide for in-place constructor.
-template <typename K, typename V, std::size_t... Sizes, typename... Types>
-SmallMap(InitializerList<KeyValue<K, V>, std::index_sequence<Sizes...>, Types...>&&)
-    -> SmallMap<K, V, sizeof...(Sizes)>;
+template <typename K, typename V, typename E, std::size_t... Sizes, typename... Types>
+SmallMap(InitializerList<KeyValue<K, V, E>, std::index_sequence<Sizes...>, Types...>&&)
+    -> SmallMap<K, V, sizeof...(Sizes), E>;
 
 // Returns whether the key-value pairs of two maps are equal.
-template <typename K, typename V, std::size_t N, typename Q, typename W, std::size_t M>
-bool operator==(const SmallMap<K, V, N>& lhs, const SmallMap<Q, W, M>& rhs) {
+template <typename K, typename V, std::size_t N, typename Q, typename W, std::size_t M, typename E>
+bool operator==(const SmallMap<K, V, N, E>& lhs, const SmallMap<Q, W, M, E>& rhs) {
   if (lhs.size() != rhs.size()) return false;
 
   for (const auto& [k, v] : lhs) {
@@ -287,8 +288,8 @@
 }
 
 // TODO: Remove in C++20.
-template <typename K, typename V, std::size_t N, typename Q, typename W, std::size_t M>
-inline bool operator!=(const SmallMap<K, V, N>& lhs, const SmallMap<Q, W, M>& rhs) {
+template <typename K, typename V, std::size_t N, typename Q, typename W, std::size_t M, typename E>
+inline bool operator!=(const SmallMap<K, V, N, E>& lhs, const SmallMap<Q, W, M, E>& rhs) {
   return !(lhs == rhs);
 }
 
diff --git a/libs/ftl/small_map_test.cpp b/libs/ftl/small_map_test.cpp
index c292108..ee650e5 100644
--- a/libs/ftl/small_map_test.cpp
+++ b/libs/ftl/small_map_test.cpp
@@ -362,4 +362,24 @@
   EXPECT_TRUE(map.dynamic());
 }
 
+TEST(SmallMap, KeyEqual) {
+  struct KeyEqual {
+    bool operator()(int lhs, int rhs) const { return lhs % 10 == rhs % 10; }
+  };
+
+  SmallMap<int, char, 1, KeyEqual> map;
+
+  EXPECT_TRUE(map.try_emplace(3, '3').second);
+  EXPECT_FALSE(map.try_emplace(13, '3').second);
+
+  EXPECT_TRUE(map.try_emplace(22, '2').second);
+  EXPECT_TRUE(map.contains(42));
+
+  EXPECT_TRUE(map.try_emplace(111, '1').second);
+  EXPECT_EQ(map.get(321), '1');
+
+  map.erase(123);
+  EXPECT_EQ(map, SmallMap(ftl::init::map<int, char, KeyEqual>(1, '1')(2, '2')));
+}
+
 }  // namespace android::test