Improve checksum calculation and address checking

1. Add a function that calculates the checksum of all the packet
   components starting from the specified position. This
   simplifies the code a bit and makes it easier to translate
   nested packets like ICMP error messages.

2. Don't hardcode IP source and destination addresses; translate
   both from the addresses in the original header instead. This
   is required to translate ICMP error messages.

Bug: 8276725
Change-Id: I2cae45683ae3943e508608fd0a140180dbc60823
diff --git a/translate.c b/translate.c
index c0bd59a..4092bcc 100644
--- a/translate.c
+++ b/translate.c
@@ -33,21 +33,81 @@
 #include "logging.h"
 #include "debug.h"
 
-/* function: payload_length
- * calculates the total length of the packet components after pos
+/* function: packet_checksum
+ * calculates the checksum over all the packet components starting from pos
+ * checksum - partial checksum accumulated over the packet components before pos (not yet finished)
+ * packet   - packet to calculate the checksum of
+ * pos      - position to start counting from
+ * returns  - the completed 16-bit checksum, ready to write into a checksum header field
+ */
+uint16_t packet_checksum(uint32_t checksum, clat_packet packet, int pos) {
+  int i;
+  for (i = pos; i < CLAT_POS_MAX; i++) {
+    if (packet[i].iov_len > 0) {
+      checksum = ip_checksum_add(checksum, packet[i].iov_base, packet[i].iov_len);
+    }
+  }
+  return ip_checksum_finish(checksum);
+}
+
+/* function: packet_length
+ * returns the total length of all the packet components after pos
  * packet - packet to calculate the length of
  * pos    - position to start counting from
  * returns: the total length of the packet components after pos
  */
-uint16_t payload_length(clat_packet packet, int pos) {
+uint16_t packet_length(clat_packet packet, int pos) {
   size_t len = 0;
   int i;
-  for (i = pos + 1; i < POS_MAX; i++) {
+  for (i = pos + 1; i < CLAT_POS_MAX; i++) {
     len += packet[i].iov_len;
   }
   return len;
 }
 
+/* function: is_in_plat_subnet
+ * returns true iff the given IPv6 address is non-NULL and in the plat subnet (assumed to be a /96).
+ * addr6 - IPv6 address to check; may be NULL
+ */
+int is_in_plat_subnet(const struct in6_addr *addr6) {
+  // Assumes a /96 plat subnet.
+  return (addr6 != NULL) && (memcmp(addr6, &Global_Clatd_Config.plat_subnet, 12) == 0);
+}
+
+/* function: ipv6_addr_to_ipv4_addr
+ * return the corresponding ipv4 address for the given ipv6 address
+ * addr6 - ipv6 address
+ * returns: the IPv4 address embedded in addr6 if it is in the plat subnet, else our own IPv4 address
+ */
+uint32_t ipv6_addr_to_ipv4_addr(const struct in6_addr *addr6) {
+
+  if (is_in_plat_subnet(addr6)) {
+    // Assumes a /96 plat subnet.
+    return addr6->s6_addr32[3];
+  } else {
+    // Currently this can only be our own address; other packets are dropped by ipv6_packet.
+    return Global_Clatd_Config.ipv4_local_subnet.s_addr;
+  }
+}
+
+/* function: ipv4_addr_to_ipv6_addr
+ * return the corresponding ipv6 address for the given ipv4 address
+ * addr4 - ipv4 address (network byte order)
+ */
+struct in6_addr ipv4_addr_to_ipv6_addr(uint32_t addr4) {
+  struct in6_addr addr6;
+  // Both addresses are in network byte order (addr4 comes from a network packet, and the config
+  // file entry is read using inet_pton).
+  if (addr4 == Global_Clatd_Config.ipv4_local_subnet.s_addr) {
+    return Global_Clatd_Config.ipv6_local_subnet;
+  } else {
+    // Assumes a /96 plat subnet.
+    addr6 = Global_Clatd_Config.plat_subnet;
+    addr6.s6_addr32[3] = addr4;
+    return addr6;
+  }
+}
+
 /* function: fill_tun_header
  * fill in the header for the tun fd
  * tun_header - tunnel header, already allocated
@@ -58,16 +118,6 @@
   tun_header->proto = htons(proto);
 }
 
-/* function: ipv6_src_to_ipv4_src
- * return the corresponding ipv4 address for the given ipv6 address
- * sourceaddr - ipv6 source address
- * returns: the IPv4 address
- */
-uint32_t ipv6_src_to_ipv4_src(const struct in6_addr *sourceaddr) {
-  // assumes a /96 plat subnet
-  return sourceaddr->s6_addr32[3];
-}
-
 /* function: fill_ip_header
  * generate an ipv4 header from an ipv6 header
  * ip_targ     - (ipv4) target packet header, source: original ipv4 addr, dest: local subnet addr
@@ -89,22 +139,8 @@
   ip->protocol = protocol;
   ip->check = 0;
 
-  ip->saddr = ipv6_src_to_ipv4_src(&old_header->ip6_src);
-  ip->daddr = Global_Clatd_Config.ipv4_local_subnet.s_addr;
-}
-
-/* function: ipv4_dst_to_ipv6_dst
- * return the corresponding ipv6 address for the given ipv4 address
- * destination - ipv4 destination address (network byte order)
- */
-struct in6_addr ipv4_dst_to_ipv6_dst(uint32_t destination) {
-  struct in6_addr v6_destination;
-
-  // assumes a /96 plat subnet
-  v6_destination = Global_Clatd_Config.plat_subnet;
-  v6_destination.s6_addr32[3] = destination;
-
-  return v6_destination;
+  ip->saddr = ipv6_addr_to_ipv4_addr(&old_header->ip6_src);
+  ip->daddr = ipv6_addr_to_ipv4_addr(&old_header->ip6_dst);
 }
 
 /* function: fill_ip6_header
@@ -123,8 +159,8 @@
   ip6->ip6_nxt = protocol;
   ip6->ip6_hlim = old_header->ttl;
 
-  ip6->ip6_src = Global_Clatd_Config.ipv6_local_subnet;
-  ip6->ip6_dst = ipv4_dst_to_ipv6_dst(old_header->daddr);
+  ip6->ip6_src = ipv4_addr_to_ipv6_addr(old_header->saddr);
+  ip6->ip6_dst = ipv4_addr_to_ipv6_addr(old_header->daddr);
 }
 
 /* function: icmp_to_icmp6
@@ -152,16 +188,14 @@
   icmp6_targ->icmp6_id = icmp->un.echo.id;
   icmp6_targ->icmp6_seq = icmp->un.echo.sequence;
 
-  icmp6_targ->icmp6_cksum = 0;
-  checksum = ip_checksum_add(checksum, icmp6_targ, sizeof(struct icmp6_hdr));
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  icmp6_targ->icmp6_cksum = ip_checksum_finish(checksum);
-
   out[pos].iov_len = sizeof(struct icmp6_hdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  icmp6_targ->icmp6_cksum = 0;  // Checksum field must be 0 when calculating checksum.
+  icmp6_targ->icmp6_cksum = packet_checksum(checksum, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: icmp6_to_icmp
@@ -189,16 +223,14 @@
   icmp_targ->un.echo.id = icmp6->icmp6_id;
   icmp_targ->un.echo.sequence = icmp6->icmp6_seq;
 
-  icmp_targ->checksum = 0;
-  checksum = ip_checksum_add(0, icmp_targ, sizeof(struct icmphdr));
-  checksum = ip_checksum_add(checksum, (void *)payload, payload_size);
-  icmp_targ->checksum = ip_checksum_finish(checksum);
-
   out[pos].iov_len = sizeof(struct icmphdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  icmp_targ->checksum = 0;  // Checksum field must be 0 when calculating checksum.
+  icmp_targ->checksum = packet_checksum(0, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: udp_packet
@@ -271,17 +303,15 @@
   struct udphdr *udp_targ = out[pos].iov_base;
 
   memcpy(udp_targ, udp, sizeof(struct udphdr));
-  udp_targ->check = 0; // reset checksum, to be calculated
-
-  checksum = ip_checksum_add(checksum, udp_targ, sizeof(struct udphdr));
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  udp_targ->check = ip_checksum_finish(checksum);
 
   out[pos].iov_len = sizeof(struct udphdr);
-  out[POS_PAYLOAD].iov_base = (char *) payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  return POS_PAYLOAD + 1;
+  udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  udp_targ->check = packet_checksum(checksum, out, pos);
+
+  return CLAT_POS_PAYLOAD + 1;
 }
 
 /* function: tcp_translate
@@ -312,13 +342,11 @@
 
   memcpy(tcp_targ, tcp, header_size);
 
-  tcp_targ->check = 0;
-  checksum = ip_checksum_add(checksum, tcp_targ, header_size);
-  checksum = ip_checksum_add(checksum, payload, payload_size);
-  tcp_targ->check = ip_checksum_finish(checksum);
+  out[CLAT_POS_PAYLOAD].iov_base = (char *)payload;
+  out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  out[POS_PAYLOAD].iov_base = (char *)payload;
-  out[POS_PAYLOAD].iov_len = payload_size;
+  tcp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  tcp_targ->check = packet_checksum(checksum, out, pos);
 
-  return POS_PAYLOAD + 1;
+  return CLAT_POS_PAYLOAD + 1;
 }