Revert "DO NOT MERGE: Fix up checksums instead of recalculating them."

This reverts commit 489e108988036facb25c59d59eb5250cf076fd3a.

Change-Id: I39e24afd8e9f1c862c0b7eea872c4fe31240aecf
diff --git a/checksum.c b/checksum.c
index 099be6a..a4dc9b8 100644
--- a/checksum.c
+++ b/checksum.c
@@ -49,25 +49,17 @@
   return checksum;
 }
 
-/* function: ip_checksum_fold
- * folds a 32-bit partial checksum into 16 bits
+/* function: ip_checksum_finish
+ * close the checksum
  * temp_sum - sum from ip_checksum_add
- * returns: the folded checksum in network byte order
  */
-uint16_t ip_checksum_fold(uint32_t temp_sum) {
+uint16_t ip_checksum_finish(uint32_t temp_sum) {
   while(temp_sum > 0xffff)
     temp_sum = (temp_sum >> 16) + (temp_sum & 0xFFFF);
 
-  return temp_sum;
-}
+  temp_sum = (~temp_sum) & 0xffff;
 
-/* function: ip_checksum_finish
- * folds and closes the checksum
- * temp_sum - sum from ip_checksum_add
- * returns: a header checksum value in network byte order
- */
-uint16_t ip_checksum_finish(uint32_t temp_sum) {
-  return ~ip_checksum_fold(temp_sum);
+  return temp_sum;
 }
 
 /* function: ip_checksum
@@ -121,23 +113,3 @@
 
   return current;
 }
-
-/* function: ip_checksum_adjust
- * calculates a new checksum given a previous checksum and the old and new pseudo-header checksums
- * checksum    - the header checksum in the original packet in network byte order
- * old_hdr_sum - the pseudo-header checksum of the original packet
- * new_hdr_sum - the pseudo-header checksum of the translated packet
- * returns: the new header checksum in network byte order
- */
-uint16_t ip_checksum_adjust(uint16_t checksum, uint32_t old_hdr_sum, uint32_t new_hdr_sum) {
-  // Algorithm suggested in RFC 1624.
-  // http://tools.ietf.org/html/rfc1624#section-3
-  checksum = ~checksum;
-  uint16_t folded_sum = ip_checksum_fold(checksum + new_hdr_sum);
-  uint16_t folded_old = ip_checksum_fold(old_hdr_sum);
-  if (folded_sum > folded_old) {
-    return ~(folded_sum - folded_old);
-  } else {
-    return ~(folded_sum - folded_old - 1);  // end-around borrow
-  }
-}
diff --git a/checksum.h b/checksum.h
index 44921f0..473f5f5 100644
--- a/checksum.h
+++ b/checksum.h
@@ -25,6 +25,4 @@
 uint32_t ipv6_pseudo_header_checksum(uint32_t current, const struct ip6_hdr *ip6, uint16_t len);
 uint32_t ipv4_pseudo_header_checksum(uint32_t current, const struct iphdr *ip, uint16_t len);
 
-uint16_t ip_checksum_adjust(uint16_t checksum, uint32_t old_hdr_sum, uint32_t new_hdr_sum);
-
 #endif /* __CHECKSUM_H__ */
diff --git a/ipv4.c b/ipv4.c
index 1d5b0b2..b5cbf80 100644
--- a/ipv4.c
+++ b/ipv4.c
@@ -70,7 +70,7 @@
   uint8_t nxthdr;
   const char *next_header;
   size_t len_left;
-  uint32_t old_sum, new_sum;
+  uint32_t checksum;
   int iov_len;
 
   if(len < sizeof(struct iphdr)) {
@@ -121,17 +121,14 @@
   out[pos].iov_len = sizeof(struct ip6_hdr);
 
   // Calculate the pseudo-header checksum.
-  old_sum = ipv4_pseudo_header_checksum(0, header, len_left);
-  new_sum = ipv6_pseudo_header_checksum(0, ip6_targ, len_left);
+  checksum = ipv6_pseudo_header_checksum(0, ip6_targ, len_left);
 
   if (nxthdr == IPPROTO_ICMPV6) {
-    iov_len = icmp_packet(out, pos + 1, (const struct icmphdr *) next_header, new_sum, len_left);
+    iov_len = icmp_packet(out, pos + 1, (const struct icmphdr *) next_header, checksum, len_left);
   } else if (nxthdr == IPPROTO_TCP) {
-    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, old_sum, new_sum,
-                         len_left);
+    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, checksum, len_left);
   } else if (nxthdr == IPPROTO_UDP) {
-    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, old_sum, new_sum,
-                         len_left);
+    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, checksum, len_left);
   } else if (nxthdr == IPPROTO_GRE) {
     iov_len = generic_packet(out, pos + 1, next_header, len_left);
   } else {
diff --git a/ipv6.c b/ipv6.c
index e4a73fe..79303ec 100644
--- a/ipv6.c
+++ b/ipv6.c
@@ -88,7 +88,7 @@
   uint8_t protocol;
   const char *next_header;
   size_t len_left;
-  uint32_t old_sum, new_sum;
+  uint32_t checksum;
   int iov_len;
 
   if(len < sizeof(struct ip6_hdr)) {
@@ -133,17 +133,16 @@
   out[pos].iov_len = sizeof(struct iphdr);
 
   // Calculate the pseudo-header checksum.
-  old_sum = ipv6_pseudo_header_checksum(0, ip6, len_left);
-  new_sum = ipv4_pseudo_header_checksum(0, ip_targ, len_left);
+  checksum = ipv4_pseudo_header_checksum(0, ip_targ, len_left);
 
   // does not support IPv6 extension headers, this will drop any packet with them
   if (protocol == IPPROTO_ICMP) {
     iov_len = icmp6_packet(out, pos + 1, (const struct icmp6_hdr *) next_header, len_left);
   } else if (ip6->ip6_nxt == IPPROTO_TCP) {
-    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, old_sum, new_sum,
+    iov_len = tcp_packet(out, pos + 1, (const struct tcphdr *) next_header, checksum,
                          len_left);
   } else if (ip6->ip6_nxt == IPPROTO_UDP) {
-    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, old_sum, new_sum,
+    iov_len = udp_packet(out, pos + 1, (const struct udphdr *) next_header, checksum,
                          len_left);
   } else if (ip6->ip6_nxt == IPPROTO_GRE) {
     iov_len = generic_packet(out, pos + 1, next_header, len_left);
diff --git a/translate.c b/translate.c
index 9a0f1b5..00ea0b9 100644
--- a/translate.c
+++ b/translate.c
@@ -208,10 +208,12 @@
     // The pseudo-header checksum was calculated on the transport length of the original IPv4
     // packet that we were asked to translate. This transport length is 20 bytes smaller than it
     // needs to be, because the ICMP error contains an IPv4 header, which we will be translating to
-    // an IPv6 header, which is 20 bytes longer. Fix it up here.
+    // an IPv6 header, which is 20 bytes longer. Fix it up here. This is simpler than the
+    // alternative, which is to always update the pseudo-header checksum in all UDP/TCP/ICMP
+    // translation functions (rather than pre-calculating it when translating the IPv4 header).
     // We only need to do this for ICMP->ICMPv6, not ICMPv6->ICMP, because ICMP does not use the
     // pseudo-header when calculating its checksum (as the IPv4 header has its own checksum).
-    checksum = checksum + htons(20);
+    checksum = htonl(ntohl(checksum) + 20);
   } else if (icmp6_type == ICMP6_ECHO_REQUEST || icmp6_type == ICMP6_ECHO_REPLY) {
     // Ping packet.
     icmp6_targ->icmp6_id = icmp->un.echo.id;
@@ -296,12 +298,10 @@
  * takes a udp packet and sets it up for translation
  * out      - output packet
  * udp      - pointer to udp header in packet
- * old_sum  - pseudo-header checksum of old header
- * new_sum  - pseudo-header checksum of new header
+ * checksum - pseudo-header checksum
  * len      - size of ip payload
  */
-int udp_packet(clat_packet out, int pos, const struct udphdr *udp,
-               uint32_t old_sum, uint32_t new_sum, size_t len) {
+int udp_packet(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum, size_t len) {
   const char *payload;
   size_t payload_size;
 
@@ -313,7 +313,7 @@
   payload = (const char *) (udp + 1);
   payload_size = len - sizeof(struct udphdr);
 
-  return udp_translate(out, pos, udp, old_sum, new_sum, payload, payload_size);
+  return udp_translate(out, pos, udp, checksum, payload, payload_size);
 }
 
 /* function: tcp_packet
@@ -324,8 +324,7 @@
  * len      - size of ip payload
  * returns: the highest position in the output clat_packet that's filled in
  */
-int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp,
-               uint32_t old_sum, uint32_t new_sum, size_t len) {
+int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp, uint32_t checksum, size_t len) {
   const char *payload;
   size_t payload_size, header_size;
 
@@ -348,21 +347,20 @@
   payload = ((const char *) tcp) + header_size;
   payload_size = len - header_size;
 
-  return tcp_translate(out, pos, tcp, header_size, old_sum, new_sum, payload, payload_size);
+  return tcp_translate(out, pos, tcp, header_size, checksum, payload, payload_size);
 }
 
 /* function: udp_translate
  * common between ipv4/ipv6 - setup checksum and send udp packet
  * out          - output packet
  * udp          - udp header
- * old_sum      - pseudo-header checksum of old header
- * new_sum      - pseudo-header checksum of new header
+ * checksum     - pseudo-header checksum
  * payload      - tcp payload
  * payload_size - size of payload
  * returns: the highest position in the output clat_packet that's filled in
  */
-int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t old_sum,
-                  uint32_t new_sum, const char *payload, size_t payload_size) {
+int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum,
+                  const char *payload, size_t payload_size) {
   struct udphdr *udp_targ = out[pos].iov_base;
 
   memcpy(udp_targ, udp, sizeof(struct udphdr));
@@ -371,22 +369,8 @@
   out[CLAT_POS_PAYLOAD].iov_base = (char *) payload;
   out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  if (udp_targ->check) {
-    udp_targ->check = ip_checksum_adjust(udp->check, old_sum, new_sum);
-  } else {
-    // Zero checksums are special. RFC 768 says, "An all zero transmitted checksum value means that
-    // the transmitter generated no checksum (for debugging or for higher level protocols that
-    // don't care)." However, in IPv6 zero UDP checksums were only permitted by RFC 6935 (2013). So
-    // for safety we recompute it.
-    udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
-    udp_targ->check = packet_checksum(new_sum, out, pos);
-  }
-
-  // RFC 768: "If the computed checksum is zero, it is transmitted as all ones (the equivalent
-  // in one's complement arithmetic)."
-  if (!udp_targ->check) {
-    udp_targ->check = 0xffff;
-  }
+  udp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  udp_targ->check = packet_checksum(checksum, out, pos);
 
   return CLAT_POS_PAYLOAD + 1;
 }
@@ -405,7 +389,7 @@
  * TODO: hosts without pmtu discovery - non DF packets will rely on fragmentation (unimplemented)
  */
 int tcp_translate(clat_packet out, int pos, const struct tcphdr *tcp, size_t header_size,
-                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size) {
+                  uint32_t checksum, const char *payload, size_t payload_size) {
   struct tcphdr *tcp_targ = out[pos].iov_base;
   out[pos].iov_len = header_size;
 
@@ -422,7 +406,8 @@
   out[CLAT_POS_PAYLOAD].iov_base = (char *)payload;
   out[CLAT_POS_PAYLOAD].iov_len = payload_size;
 
-  tcp_targ->check = ip_checksum_adjust(tcp->check, old_sum, new_sum);
+  tcp_targ->check = 0;  // Checksum field must be 0 when calculating checksum.
+  tcp_targ->check = packet_checksum(checksum, out, pos);
 
   return CLAT_POS_PAYLOAD + 1;
 }
diff --git a/translate.h b/translate.h
index cfb7bbb..9f1ac15 100644
--- a/translate.h
+++ b/translate.h
@@ -61,14 +61,12 @@
 int generic_packet(clat_packet out, int pos, const char *payload, size_t len);
 
 // Translate TCP and UDP packets.
-int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp,
-               uint32_t old_sum, uint32_t new_sum, size_t len);
-int udp_packet(clat_packet out, int pos, const struct udphdr *udp,
-               uint32_t old_sum, uint32_t new_sum, size_t len);
+int tcp_packet(clat_packet out, int pos, const struct tcphdr *tcp, uint32_t checksum, size_t len);
+int udp_packet(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum, size_t len);
 
 int tcp_translate(clat_packet out, int pos, const struct tcphdr *tcp, size_t header_size,
-                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size);
-int udp_translate(clat_packet out, int pos, const struct udphdr *udp,
-                  uint32_t old_sum, uint32_t new_sum, const char *payload, size_t payload_size);
+                  uint32_t checksum, const char *payload, size_t payload_size);
+int udp_translate(clat_packet out, int pos, const struct udphdr *udp, uint32_t checksum,
+                  const char *payload, size_t payload_size);
 
 #endif /* __TRANSLATE_H__ */