Merge "BpfCoordinator: publish upstream interface mtu to v4 offload rule"
diff --git a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
index 72f83fa..44d3ffc 100644
--- a/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
+++ b/Tethering/src/com/android/networkstack/tethering/BpfCoordinator.java
@@ -27,6 +27,7 @@
 import static android.system.OsConstants.ETH_P_IP;
 import static android.system.OsConstants.ETH_P_IPV6;
 
+import static com.android.net.module.util.NetworkStackConstants.IPV4_MIN_MTU;
 import static com.android.net.module.util.ip.ConntrackMonitor.ConntrackEvent;
 import static com.android.networkstack.tethering.BpfUtils.DOWNSTREAM;
 import static com.android.networkstack.tethering.BpfUtils.UPSTREAM;
@@ -82,6 +83,8 @@
 import java.net.Inet4Address;
 import java.net.Inet6Address;
 import java.net.InetAddress;
+import java.net.NetworkInterface;
+import java.net.SocketException;
 import java.net.UnknownHostException;
 import java.util.ArrayList;
 import java.util.Arrays;
@@ -143,6 +146,8 @@
     static final int NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED = 432_000;
     @VisibleForTesting
     static final int NF_CONNTRACK_UDP_TIMEOUT_STREAM = 180;
+    @VisibleForTesting
+    static final int INVALID_MTU = 0;
 
     // List of TCP port numbers which aren't offloaded because the packets require the netfilter
     // conntrack helper. See also TetherController::setForwardRules in netd.
@@ -263,6 +268,10 @@
     // TODO: Support multi-upstream interfaces.
     private int mLastIPv4UpstreamIfindex = 0;
 
+    // Tracks the IPv4 upstream interface information.
+    @Nullable
+    private UpstreamInfo mIpv4UpstreamInfo = null;
+
     // Runnable that used by scheduling next polling of stats.
     private final Runnable mScheduledPollingStats = () -> {
         updateForwardedStats();
@@ -320,6 +329,19 @@
             return SdkLevel.isAtLeastS();
         }
 
+        /**
+         * Returns the interface's MTU, or INVALID_MTU if it is absent or cannot be queried.
+         */
+        public int getNetworkInterfaceMtu(@NonNull String iface) {
+            try {
+                final NetworkInterface nif = NetworkInterface.getByName(iface);
+                if (nif != null) return nif.getMTU();
+            } catch (SocketException e) {
+                Log.e(TAG, "Could not get MTU for interface " + iface, e);
+            }
+            return INVALID_MTU;
+        }
+
         /** Get downstream4 BPF map. */
         @Nullable public IBpfMap<Tether4Key, Tether4Value> getBpfDownstream4Map() {
             if (!isAtLeastS()) return null;
@@ -868,6 +890,7 @@
         if (!isUsingBpf()) return;
 
         int upstreamIndex = 0;
+        int mtu = INVALID_MTU;
 
         // This will not work on a network that is using 464xlat because hasIpv4Address will not be
         // true.
@@ -877,6 +900,17 @@
             final String ifaceName = ns.linkProperties.getInterfaceName();
             final InterfaceParams params = mDeps.getInterfaceParams(ifaceName);
             final boolean isVcn = isVcnInterface(ifaceName);
+            mtu = ns.linkProperties.getMtu();
+            if (mtu == INVALID_MTU) {
+                // Get mtu via kernel if mtu is not found in LinkProperties.
+                mtu = mDeps.getNetworkInterfaceMtu(ifaceName);
+            }
+
+            // Use default mtu if can't find any.
+            if (mtu == INVALID_MTU) mtu = NetworkStackConstants.ETHER_MTU;
+            // Clamp to minimum ipv4 mtu
+            if (mtu < IPV4_MIN_MTU) mtu = IPV4_MIN_MTU;
+
             if (!isVcn && params != null && !params.hasMacAddress /* raw ip upstream only */) {
                 upstreamIndex = params.index;
             }
@@ -905,8 +939,11 @@
         // after the upstream is lost do not incorrectly add rules pointing at the upstream.
         if (upstreamIndex == 0) {
             mIpv4UpstreamIndices.clear();
+            mIpv4UpstreamInfo = null;
             return;
         }
+
+        mIpv4UpstreamInfo = new UpstreamInfo(upstreamIndex, mtu);
         Collection<InetAddress> addresses = ns.linkProperties.getAddresses();
         for (final InetAddress addr: addresses) {
             if (isValidUpstreamIpv4Address(addr)) {
@@ -1051,6 +1088,9 @@
         }
         pw.decreaseIndent();
 
+        pw.println("IPv4 Upstream Information: "
+                + (mIpv4UpstreamInfo != null ? mIpv4UpstreamInfo : "<empty>"));
+
         pw.println();
         pw.println("Forwarding counters:");
         pw.increaseIndent();
@@ -1258,10 +1298,10 @@
 
         final String ageStr = (value.lastUsed == 0) ? "-"
                 : String.format("%dms", (now - value.lastUsed) / 1_000_000);
-        return String.format("%s [%s] %d(%s) %s:%d -> %d(%s) %s:%d -> %s:%d [%s] %s",
+        return String.format("%s [%s] %d(%s) %s:%d -> %d(%s) %s:%d -> %s:%d [%s] %d %s",
                 l4protoToString(key.l4proto), key.dstMac, key.iif, getIfName(key.iif),
                 src4, key.srcPort, value.oif, getIfName(value.oif),
-                public4, publicPort, dst4, value.dstPort, value.ethDstMac, ageStr);
+                public4, publicPort, dst4, value.dstPort, value.ethDstMac, value.pmtu, ageStr);
     }
 
     private void dumpIpv4ForwardingRuleMap(long now, boolean downstream,
@@ -1283,13 +1323,13 @@
         try (IBpfMap<Tether4Key, Tether4Value> upstreamMap = mDeps.getBpfUpstream4Map();
                 IBpfMap<Tether4Key, Tether4Value> downstreamMap = mDeps.getBpfDownstream4Map()) {
             pw.println("IPv4 Upstream: proto [inDstMac] iif(iface) src -> nat -> "
-                    + "dst [outDstMac] age");
+                    + "dst [outDstMac] pmtu age");
             pw.increaseIndent();
             dumpIpv4ForwardingRuleMap(now, UPSTREAM, upstreamMap, pw);
             pw.decreaseIndent();
 
             pw.println("IPv4 Downstream: proto [inDstMac] iif(iface) src -> nat -> "
-                    + "dst [outDstMac] age");
+                    + "dst [outDstMac] pmtu age");
             pw.increaseIndent();
             dumpIpv4ForwardingRuleMap(now, DOWNSTREAM, downstreamMap, pw);
             pw.decreaseIndent();
@@ -1540,6 +1580,42 @@
         }
     }
 
+    /**
+     * Immutable IPv4 upstream information: the upstream interface index and the mtu which is
+     * published to the IPv4 offload rules.
+     */
+    private static final class UpstreamInfo {
+        // TODO: add clat interface information
+        /** Interface index of the upstream. */
+        public final int ifIndex;
+        /** Mtu to use for the upstream interface. */
+        public final int mtu;
+
+        private UpstreamInfo(final int ifIndex, final int mtu) {
+            this.ifIndex = ifIndex;
+            this.mtu = mtu;
+        }
+
+        // Overridden together with hashCode() so that instances which carry the same ifIndex
+        // and mtu compare equal as well as hash alike.
+        @Override
+        public boolean equals(Object o) {
+            if (!(o instanceof UpstreamInfo)) return false;
+            final UpstreamInfo that = (UpstreamInfo) o;
+            return ifIndex == that.ifIndex && mtu == that.mtu;
+        }
+
+        @Override
+        public int hashCode() {
+            return Objects.hash(ifIndex, mtu);
+        }
+
+        @Override
+        public String toString() {
+            return String.format("ifIndex: %d, mtu: %d", ifIndex, mtu);
+        }
+    }
+
     /**
      * A BPF tethering stats provider to provide network statistics to the system.
      * Note that this class' data may only be accessed on the handler thread.
@@ -1711,20 +1773,20 @@
 
         @NonNull
         private Tether4Value makeTetherUpstream4Value(@NonNull ConntrackEvent e,
-                int upstreamIndex) {
-            return new Tether4Value(upstreamIndex,
+                @NonNull UpstreamInfo upstreamInfo) {
+            return new Tether4Value(upstreamInfo.ifIndex,
                     NULL_MAC_ADDRESS /* ethDstMac (rawip) */,
                     NULL_MAC_ADDRESS /* ethSrcMac (rawip) */, ETH_P_IP,
-                    NetworkStackConstants.ETHER_MTU, toIpv4MappedAddressBytes(e.tupleReply.dstIp),
+                    upstreamInfo.mtu, toIpv4MappedAddressBytes(e.tupleReply.dstIp),
                     toIpv4MappedAddressBytes(e.tupleReply.srcIp), e.tupleReply.dstPort,
                     e.tupleReply.srcPort, 0 /* lastUsed, filled by bpf prog only */);
         }
 
         @NonNull
         private Tether4Value makeTetherDownstream4Value(@NonNull ConntrackEvent e,
-                @NonNull ClientInfo c, int upstreamIndex) {
+                @NonNull ClientInfo c, @NonNull UpstreamInfo upstreamInfo) {
             return new Tether4Value(c.downstreamIfindex,
-                    c.clientMac, c.downstreamMac, ETH_P_IP, NetworkStackConstants.ETHER_MTU,
+                    c.clientMac, c.downstreamMac, ETH_P_IP, upstreamInfo.mtu,
                     toIpv4MappedAddressBytes(e.tupleOrig.dstIp),
                     toIpv4MappedAddressBytes(e.tupleOrig.srcIp),
                     e.tupleOrig.dstPort, e.tupleOrig.srcPort,
@@ -1773,9 +1835,11 @@
                 return;
             }
 
-            final Tether4Value upstream4Value = makeTetherUpstream4Value(e, upstreamIndex);
+            if (mIpv4UpstreamInfo == null || mIpv4UpstreamInfo.ifIndex != upstreamIndex) return;
+
+            final Tether4Value upstream4Value = makeTetherUpstream4Value(e, mIpv4UpstreamInfo);
             final Tether4Value downstream4Value = makeTetherDownstream4Value(e, tetherClient,
-                    upstreamIndex);
+                    mIpv4UpstreamInfo);
 
             maybeAddDevMap(upstreamIndex, tetherClient.downstreamIfindex);
             maybeSetLimit(upstreamIndex);
diff --git a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
index ace5f15..1978e99 100644
--- a/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
+++ b/Tethering/tests/unit/src/com/android/networkstack/tethering/BpfCoordinatorTest.java
@@ -32,6 +32,7 @@
 
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.doReturn;
 import static com.android.dx.mockito.inline.extended.ExtendedMockito.staticMockMarker;
+import static com.android.net.module.util.NetworkStackConstants.IPV4_MIN_MTU;
 import static com.android.net.module.util.ip.ConntrackMonitor.ConntrackEvent;
 import static com.android.net.module.util.netlink.ConntrackMessage.DYING_MASK;
 import static com.android.net.module.util.netlink.ConntrackMessage.ESTABLISHED_MASK;
@@ -41,6 +42,7 @@
 import static com.android.net.module.util.netlink.NetlinkConstants.IPCTNL_MSG_CT_DELETE;
 import static com.android.net.module.util.netlink.NetlinkConstants.IPCTNL_MSG_CT_NEW;
 import static com.android.networkstack.tethering.BpfCoordinator.CONNTRACK_TIMEOUT_UPDATE_INTERVAL_MS;
+import static com.android.networkstack.tethering.BpfCoordinator.INVALID_MTU;
 import static com.android.networkstack.tethering.BpfCoordinator.NF_CONNTRACK_TCP_TIMEOUT_ESTABLISHED;
 import static com.android.networkstack.tethering.BpfCoordinator.NF_CONNTRACK_UDP_TIMEOUT_STREAM;
 import static com.android.networkstack.tethering.BpfCoordinator.NON_OFFLOADED_UPSTREAM_IPV4_TCP_PORTS;
@@ -283,6 +285,11 @@
             private int mDstPort = REMOTE_PORT;
             private long mLastUsed = 0;
 
+            public Builder setPmtu(short pmtu) {
+                mPmtu = pmtu;  // Handed to the Tether4Value constructor by build().
+                return this;
+            }
+
             public Tether4Value build() {
                 return new Tether4Value(mOif, mEthDstMac, mEthSrcMac, mEthProto, mPmtu,
                         mSrc46, mDst46, mSrcPort, mDstPort, mLastUsed);
@@ -303,6 +310,11 @@
             private int mDstPort = PRIVATE_PORT;
             private long mLastUsed = 0;
 
+            public Builder setPmtu(short pmtu) {
+                mPmtu = pmtu;  // Handed to the Tether4Value constructor by build().
+                return this;
+            }
+
             public Tether4Value build() {
                 return new Tether4Value(mOif, mEthDstMac, mEthSrcMac, mEthProto, mPmtu,
                         mSrc46, mDst46, mSrcPort, mDstPort, mLastUsed);
@@ -375,6 +387,7 @@
     private HashMap<IpServer, HashMap<Inet4Address, ClientInfo>> mTetherClients;
 
     private long mElapsedRealtimeNanos = 0;
+    private int mMtu = NetworkStackConstants.ETHER_MTU;
     private final ArgumentCaptor<ArrayList> mStringArrayCaptor =
             ArgumentCaptor.forClass(ArrayList.class);
     private final TestLooper mTestLooper = new TestLooper();
@@ -430,6 +443,10 @@
                         return mElapsedRealtimeNanos;
                     }
 
+                    public int getNetworkInterfaceMtu(@NonNull String iface) {
+                        return mMtu;  // Test-controlled; also fed to LinkProperties#setMtu.
+                    }
+
                     @Nullable
                     public IBpfMap<Tether4Key, Tether4Value> getBpfDownstream4Map() {
                         return mBpfDownstream4Map;
@@ -1518,6 +1535,7 @@
         final LinkProperties lp = new LinkProperties();
         lp.setInterfaceName(upstreamInfo.interfaceParams.name);
         lp.addLinkAddress(new LinkAddress(upstreamInfo.address, 32 /* prefix length */));
+        lp.setMtu(mMtu);
         final NetworkCapabilities capabilities = new NetworkCapabilities()
                 .addTransportType(upstreamInfo.transportType);
         coordinator.updateUpstreamNetworkState(new UpstreamNetworkState(lp, capabilities,
@@ -2195,4 +2213,72 @@
 
         verifyDump(coordinator);
     }
+
+    private void verifyAddTetherOffloadRule4Mtu(final int ifaceMtu, final boolean isKernelMtu,
+            final int expectedMtu) throws Exception {
+        // BpfCoordinator#updateUpstreamNetworkState gets mtu from LinkProperties. If not found,
+        // try to get from kernel, which this test stubs through mDeps.getNetworkInterfaceMtu().
+        if (isKernelMtu) {
+            // LinkProperties mtu is invalid, so ifaceMtu can only come from the kernel lookup.
+            mMtu = INVALID_MTU;
+            doReturn(ifaceMtu).when(mDeps).getNetworkInterfaceMtu(any());
+        } else {
+            // LinkProperties carries ifaceMtu (may be INVALID_MTU) and the kernel mtu is invalid.
+            mMtu = ifaceMtu;
+            doReturn(INVALID_MTU).when(mDeps).getNetworkInterfaceMtu(any());
+        }
+
+        final BpfCoordinator coordinator = makeBpfCoordinator();
+        initBpfCoordinatorForRule4(coordinator);
+
+        final Tether4Key expectedUpstream4KeyTcp = new TestUpstream4Key.Builder()
+                .setProto(IPPROTO_TCP)
+                .build();
+        final Tether4Key expectedDownstream4KeyTcp = new TestDownstream4Key.Builder()
+                .setProto(IPPROTO_TCP)
+                .build();
+        final Tether4Value expectedUpstream4ValueTcp = new TestUpstream4Value.Builder()
+                .setPmtu((short) expectedMtu)
+                .build();
+        final Tether4Value expectedDownstream4ValueTcp = new TestDownstream4Value.Builder()
+                .setPmtu((short) expectedMtu)
+                .build();
+
+        mConsumer.accept(new TestConntrackEvent.Builder()
+                .setMsgType(IPCTNL_MSG_CT_NEW)
+                .setProto(IPPROTO_TCP)
+                .build());
+        verify(mBpfUpstream4Map)
+                .insertEntry(eq(expectedUpstream4KeyTcp), eq(expectedUpstream4ValueTcp));
+        verify(mBpfDownstream4Map)
+                .insertEntry(eq(expectedDownstream4KeyTcp), eq(expectedDownstream4ValueTcp));
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testAddTetherOffloadRule4LowMtuFromLinkProperties() throws Exception {
+        verifyAddTetherOffloadRule4Mtu(IPV4_MIN_MTU /* ifaceMtu */, false /* isKernelMtu */,
+                IPV4_MIN_MTU /* expectedMtu */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testAddTetherOffloadRule4LowMtuFromKernel() throws Exception {
+        verifyAddTetherOffloadRule4Mtu(IPV4_MIN_MTU /* ifaceMtu */, true /* isKernelMtu */,
+                IPV4_MIN_MTU /* expectedMtu */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testAddTetherOffloadRule4LessThanIpv4MinMtu() throws Exception {
+        verifyAddTetherOffloadRule4Mtu(IPV4_MIN_MTU - 1 /* ifaceMtu */, false /* isKernelMtu */,
+                IPV4_MIN_MTU /* expectedMtu */);
+    }
+
+    @Test
+    @IgnoreUpTo(Build.VERSION_CODES.R)
+    public void testAddTetherOffloadRule4InvalidMtu() throws Exception {
+        verifyAddTetherOffloadRule4Mtu(INVALID_MTU /* ifaceMtu */, false /* isKernelMtu */,
+                NetworkStackConstants.ETHER_MTU /* expectedMtu */);
+    }
 }