Merge changes Ic69fc75e,I6373d251 into main am: 8a88d85145

Original change: https://android-review.googlesource.com/c/platform/packages/modules/Connectivity/+/2749654

Change-Id: I647baa696ac6e0f99aacb5de35148ad74f5a5498
Signed-off-by: Automerger Merge Worker <android-build-automerger-merge-worker@system.gserviceaccount.com>
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
index ec63e41..9b1b72d 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandler.cpp
@@ -18,6 +18,7 @@
 
 #include "netdbpf/NetworkTraceHandler.h"
 
+#include <android-base/macros.h>
 #include <arpa/inet.h>
 #include <bpf/BpfUtils.h>
 #include <log/log.h>
@@ -75,9 +76,35 @@
   uint32_t bytes = 0;
 };
 
-#define AGG_FIELDS(x)                                              \
-  (x).ifindex, (x).uid, (x).tag, (x).sport, (x).dport, (x).egress, \
-      (x).ipProto, (x).tcpFlags
+BundleKey::BundleKey(const PacketTrace& pkt)  // timestamp/length unused.
+    : ifindex(pkt.ifindex),
+      uid(pkt.uid),
+      tag(pkt.tag),
+      egress(pkt.egress),
+      ipProto(pkt.ipProto),
+      ipVersion(pkt.ipVersion) {
+  switch (ipProto) {  // Unlisted protocols leave all optionals unset.
+    case IPPROTO_TCP:
+      tcpFlags = pkt.tcpFlags;
+      FALLTHROUGH_INTENDED;  // TCP also carries ports.
+    case IPPROTO_DCCP:
+    case IPPROTO_UDP:
+    case IPPROTO_UDPLITE:
+    case IPPROTO_SCTP:  // Local is src on egress, dst on ingress.
+      localPort = ntohs(pkt.egress ? pkt.sport : pkt.dport);
+      remotePort = ntohs(pkt.egress ? pkt.dport : pkt.sport);
+      break;
+    case IPPROTO_ICMP:  // Type/code are read from sport/dport.
+    case IPPROTO_ICMPV6:
+      icmpType = ntohs(pkt.sport);  // Narrowed to uint8_t.
+      icmpCode = ntohs(pkt.dport);  // Narrowed to uint8_t.
+      break;
+  }
+}
+
+#define AGG_FIELDS(x)                                                    \
+  (x).ifindex, (x).uid, (x).tag, (x).egress, (x).ipProto, (x).ipVersion, \
+      (x).tcpFlags, (x).localPort, (x).remotePort, (x).icmpType, (x).icmpCode
 
 std::size_t BundleHash::operator()(const BundleKey& a) const {
   std::size_t seed = 0;
@@ -179,7 +206,7 @@
       dst->set_timestamp(pkt.timestampNs);
       auto* event = dst->set_network_packet();
       event->set_length(pkt.length);
-      Fill(pkt, event);
+      Fill(BundleKey(pkt), event);
     }
     return;
   }
@@ -187,14 +214,13 @@
   uint64_t minTs = std::numeric_limits<uint64_t>::max();
   std::unordered_map<BundleKey, BundleDetails, BundleHash, BundleEq> bundles;
   for (const PacketTrace& pkt : packets) {
-    BundleKey key = pkt;
+    BundleKey key(pkt);
 
     // Dropping fields should remove them from the output and remove them from
-    // the aggregation key. In order to do the latter without changing the hash
-    // function, set the dropped fields to zero.
-    if (mDropTcpFlags) key.tcpFlags = 0;
-    if (mDropLocalPort) (key.egress ? key.sport : key.dport) = 0;
-    if (mDropRemotePort) (key.egress ? key.dport : key.sport) = 0;
+    // the aggregation key. Reset the optionals to indicate omission.
+    if (mDropTcpFlags) key.tcpFlags.reset();
+    if (mDropLocalPort) key.localPort.reset();
+    if (mDropRemotePort) key.remotePort.reset();
 
     minTs = std::min(minTs, pkt.timestampNs);
 
@@ -245,22 +271,18 @@
   }
 }
 
-void NetworkTraceHandler::Fill(const PacketTrace& src,
+void NetworkTraceHandler::Fill(const BundleKey& src,
                                NetworkPacketEvent* event) {
   event->set_direction(src.egress ? TrafficDirection::DIR_EGRESS
                                   : TrafficDirection::DIR_INGRESS);
   event->set_uid(src.uid);
   event->set_tag(src.tag);
 
-  if (!mDropLocalPort) {
-    event->set_local_port(ntohs(src.egress ? src.sport : src.dport));
-  }
-  if (!mDropRemotePort) {
-    event->set_remote_port(ntohs(src.egress ? src.dport : src.sport));
-  }
-  if (!mDropTcpFlags) {
-    event->set_tcp_flags(src.tcpFlags);
-  }
+  if (src.tcpFlags.has_value()) event->set_tcp_flags(*src.tcpFlags);
+  if (src.localPort.has_value()) event->set_local_port(*src.localPort);
+  if (src.remotePort.has_value()) event->set_remote_port(*src.remotePort);
+  if (src.icmpType.has_value()) event->set_icmp_type(*src.icmpType);
+  if (src.icmpCode.has_value()) event->set_icmp_code(*src.icmpCode);
 
   event->set_ip_proto(src.ipProto);
 
diff --git a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
index f2c1a86..0c4f049 100644
--- a/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
+++ b/service-t/native/libs/libnetworkstats/NetworkTraceHandlerTest.cpp
@@ -113,7 +113,7 @@
           .length = 100,
           .uid = 10,
           .tag = 123,
-          .ipProto = 6,
+          .ipProto = IPPROTO_TCP,
           .tcpFlags = 1,
       },
   };
@@ -138,12 +138,14 @@
           .sport = htons(8080),
           .dport = htons(443),
           .egress = true,
+          .ipProto = IPPROTO_TCP,
       },
       PacketTrace{
           .timestampNs = 2,
           .sport = htons(443),
           .dport = htons(8080),
           .egress = false,
+          .ipProto = IPPROTO_TCP,
       },
   };
 
@@ -161,6 +163,42 @@
               TrafficDirection::DIR_INGRESS);
 }
 
+TEST_F(NetworkTraceHandlerTest, WriteIcmpTypeAndCode) {
+  std::vector<PacketTrace> input = {
+      PacketTrace{
+          .timestampNs = 1,
+          .sport = htons(11),  // type
+          .dport = htons(22),  // code
+          .egress = true,
+          .ipProto = IPPROTO_ICMP,
+      },
+      PacketTrace{
+          .timestampNs = 2,
+          .sport = htons(33),  // type
+          .dport = htons(44),  // code
+          .egress = false,
+          .ipProto = IPPROTO_ICMPV6,
+      },
+  };
+
+  std::vector<TracePacket> events;
+  ASSERT_TRUE(TraceAndSortPackets(input, &events));
+  // ICMP events carry type/code and must not set any port field.
+  ASSERT_EQ(events.size(), 2);
+  EXPECT_FALSE(events[0].network_packet().has_local_port());
+  EXPECT_FALSE(events[0].network_packet().has_remote_port());
+  EXPECT_THAT(events[0].network_packet().icmp_type(), 11);
+  EXPECT_THAT(events[0].network_packet().icmp_code(), 22);
+  EXPECT_THAT(events[0].network_packet().direction(),
+              TrafficDirection::DIR_EGRESS);
+  EXPECT_FALSE(events[1].network_packet().has_local_port());
+  EXPECT_FALSE(events[1].network_packet().has_remote_port());
+  EXPECT_THAT(events[1].network_packet().icmp_type(), 33);
+  EXPECT_THAT(events[1].network_packet().icmp_code(), 44);
+  EXPECT_THAT(events[1].network_packet().direction(),
+              TrafficDirection::DIR_INGRESS);
+}
+
 TEST_F(NetworkTraceHandlerTest, BasicBundling) {
   // TODO: remove this once bundling becomes default. Until then, set arbitrary
   // aggregation threshold to enable bundling.
@@ -168,12 +206,12 @@
   config.set_aggregation_threshold(10);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.uid = 123, .timestampNs = 2, .length = 200},
-      PacketTrace{.uid = 123, .timestampNs = 1, .length = 100},
-      PacketTrace{.uid = 123, .timestampNs = 4, .length = 300},
+      PacketTrace{.timestampNs = 2, .length = 200, .uid = 123},
+      PacketTrace{.timestampNs = 1, .length = 100, .uid = 123},
+      PacketTrace{.timestampNs = 4, .length = 300, .uid = 123},
 
-      PacketTrace{.uid = 456, .timestampNs = 2, .length = 400},
-      PacketTrace{.uid = 456, .timestampNs = 4, .length = 100},
+      PacketTrace{.timestampNs = 2, .length = 400, .uid = 456},
+      PacketTrace{.timestampNs = 4, .length = 100, .uid = 456},
   };
 
   std::vector<TracePacket> events;
@@ -203,12 +241,12 @@
   config.set_aggregation_threshold(3);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.uid = 123, .timestampNs = 2, .length = 200},
-      PacketTrace{.uid = 123, .timestampNs = 1, .length = 100},
-      PacketTrace{.uid = 123, .timestampNs = 4, .length = 300},
+      PacketTrace{.timestampNs = 2, .length = 200, .uid = 123},
+      PacketTrace{.timestampNs = 1, .length = 100, .uid = 123},
+      PacketTrace{.timestampNs = 4, .length = 300, .uid = 123},
 
-      PacketTrace{.uid = 456, .timestampNs = 2, .length = 400},
-      PacketTrace{.uid = 456, .timestampNs = 4, .length = 100},
+      PacketTrace{.timestampNs = 2, .length = 400, .uid = 456},
+      PacketTrace{.timestampNs = 4, .length = 100, .uid = 456},
   };
 
   std::vector<TracePacket> events;
@@ -239,12 +277,17 @@
   __be16 b = htons(10001);
   std::vector<PacketTrace> input = {
       // Recall that local is `src` for egress and `dst` for ingress.
-      PacketTrace{.timestampNs = 1, .length = 2, .egress = true, .sport = a},
-      PacketTrace{.timestampNs = 2, .length = 4, .egress = false, .dport = a},
-      PacketTrace{.timestampNs = 3, .length = 6, .egress = true, .sport = b},
-      PacketTrace{.timestampNs = 4, .length = 8, .egress = false, .dport = b},
+      PacketTrace{.timestampNs = 1, .length = 2, .sport = a, .egress = true},
+      PacketTrace{.timestampNs = 2, .length = 4, .dport = a, .egress = false},
+      PacketTrace{.timestampNs = 3, .length = 6, .sport = b, .egress = true},
+      PacketTrace{.timestampNs = 4, .length = 8, .dport = b, .egress = false},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
   ASSERT_EQ(events.size(), 2);
@@ -274,12 +317,17 @@
   __be16 b = htons(80);
   std::vector<PacketTrace> input = {
       // Recall that remote is `dst` for egress and `src` for ingress.
-      PacketTrace{.timestampNs = 1, .length = 2, .egress = true, .dport = a},
-      PacketTrace{.timestampNs = 2, .length = 4, .egress = false, .sport = a},
-      PacketTrace{.timestampNs = 3, .length = 6, .egress = true, .dport = b},
-      PacketTrace{.timestampNs = 4, .length = 8, .egress = false, .sport = b},
+      PacketTrace{.timestampNs = 1, .length = 2, .dport = a, .egress = true},
+      PacketTrace{.timestampNs = 2, .length = 4, .sport = a, .egress = false},
+      PacketTrace{.timestampNs = 3, .length = 6, .dport = b, .egress = true},
+      PacketTrace{.timestampNs = 4, .length = 8, .sport = b, .egress = false},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
   ASSERT_EQ(events.size(), 2);
@@ -306,12 +354,17 @@
   config.set_aggregation_threshold(10);
 
   std::vector<PacketTrace> input = {
-      PacketTrace{.timestampNs = 1, .uid = 123, .length = 1, .tcpFlags = 1},
-      PacketTrace{.timestampNs = 2, .uid = 123, .length = 2, .tcpFlags = 2},
-      PacketTrace{.timestampNs = 3, .uid = 456, .length = 3, .tcpFlags = 1},
-      PacketTrace{.timestampNs = 4, .uid = 456, .length = 4, .tcpFlags = 2},
+      PacketTrace{.timestampNs = 1, .length = 1, .uid = 123, .tcpFlags = 1},
+      PacketTrace{.timestampNs = 2, .length = 2, .uid = 123, .tcpFlags = 2},
+      PacketTrace{.timestampNs = 3, .length = 3, .uid = 456, .tcpFlags = 1},
+      PacketTrace{.timestampNs = 4, .length = 4, .uid = 456, .tcpFlags = 2},
   };
 
+  // Set common fields.
+  for (PacketTrace& pkt : input) {
+    pkt.ipProto = IPPROTO_TCP;
+  }
+
   std::vector<TracePacket> events;
   ASSERT_TRUE(TraceAndSortPackets(input, &events, config));
 
diff --git a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
index bc10e68..6bf186a 100644
--- a/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
+++ b/service-t/native/libs/libnetworkstats/include/netdbpf/NetworkTraceHandler.h
@@ -30,15 +30,33 @@
 namespace android {
 namespace bpf {
 
-// BundleKeys are PacketTraces where timestamp and length are ignored.
-using BundleKey = PacketTrace;
+// BundleKey encodes a PacketTrace minus timestamp and length. The key should
+// match many packets over time for interning. For convenience, sport/dport
+// are parsed here as either local/remote port or icmp type/code.
+struct BundleKey {
+  explicit BundleKey(const PacketTrace& pkt);

-// BundleKeys are hashed using all fields except timestamp/length.
+  uint32_t ifindex;
+  uint32_t uid;
+  uint32_t tag;
+
+  bool egress;
+  uint8_t ipProto;
+  uint8_t ipVersion;
+  // Unset optionals are not emitted in the NetworkPacketEvent.
+  std::optional<uint8_t> tcpFlags;     // TCP only.
+  std::optional<uint16_t> localPort;   // Port protocols; host byte order.
+  std::optional<uint16_t> remotePort;  // Port protocols; host byte order.
+  std::optional<uint8_t> icmpType;     // ICMP/ICMPv6 only.
+  std::optional<uint8_t> icmpCode;     // ICMP/ICMPv6 only.
+};
+
+// BundleKeys are hashed using a simple hash combine.
 struct BundleHash {
   std::size_t operator()(const BundleKey& a) const;
 };
 
-// BundleKeys are equal if all fields except timestamp/length are equal.
+// BundleKeys are equal if all fields are equal.
 struct BundleEq {
   bool operator()(const BundleKey& a, const BundleKey& b) const;
 };
@@ -84,13 +102,13 @@
              NetworkTraceHandler::TraceContext& ctx);
 
  private:
-  // Convert a PacketTrace into a Perfetto trace packet.
-  void Fill(const PacketTrace& src,
+  // Fills in contextual information from a bundle without interning.
+  void Fill(const BundleKey& src,
             ::perfetto::protos::pbzero::NetworkPacketEvent* event);
 
   // Fills in contextual information either inline or via interning.
   ::perfetto::protos::pbzero::NetworkPacketBundle* FillWithInterning(
-      NetworkTraceState* state, const BundleKey& key,
+      NetworkTraceState* state, const BundleKey& src,
       ::perfetto::protos::pbzero::TracePacket* dst);
 
   static internal::NetworkTracePoller sPoller;