Merge "ClatCoordinatorTest: add test for startClat error handling"
diff --git a/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java b/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
index 94f0c2a..3047a16 100644
--- a/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
+++ b/tests/unit/java/com/android/server/connectivity/ClatCoordinatorTest.java
@@ -30,6 +30,7 @@
 import static com.android.testutils.MiscAsserts.assertThrows;
 
 import static org.junit.Assert.assertEquals;
+import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.assertNull;
 import static org.junit.Assert.fail;
 import static org.mockito.Mockito.argThat;
@@ -37,6 +38,7 @@
 import static org.mockito.Mockito.eq;
 import static org.mockito.Mockito.inOrder;
 import static org.mockito.Mockito.spy;
+import static org.mockito.Mockito.verify;
 
 import android.annotation.NonNull;
 import android.net.INetd;
@@ -105,12 +107,12 @@
     private static final int RAW_SOCK_FD = 535;
     private static final int PACKET_SOCK_FD = 536;
     private static final long RAW_SOCK_COOKIE = 27149;
-    private static final ParcelFileDescriptor TUN_PFD = new ParcelFileDescriptor(
-            new FileDescriptor());
-    private static final ParcelFileDescriptor RAW_SOCK_PFD = new ParcelFileDescriptor(
-            new FileDescriptor());
-    private static final ParcelFileDescriptor PACKET_SOCK_PFD = new ParcelFileDescriptor(
-            new FileDescriptor());
+    private static final ParcelFileDescriptor TUN_PFD =
+            spy(new ParcelFileDescriptor(new FileDescriptor()));
+    private static final ParcelFileDescriptor RAW_SOCK_PFD =
+            spy(new ParcelFileDescriptor(new FileDescriptor()));
+    private static final ParcelFileDescriptor PACKET_SOCK_PFD =
+            spy(new ParcelFileDescriptor(new FileDescriptor()));
 
     private static final String EGRESS_PROG_PATH =
             "/sys/fs/bpf/net_shared/prog_clatd_schedcls_egress4_clat_rawip";
@@ -533,4 +535,169 @@
         assertEquals("1001 /192.0.0.46 -> /2001:db8:0:b11::464 /64:ff9b::/96 1000 ether",
                 dumpStrings[4].trim());
     }
+
+    @Test
+    public void testNotStartClatWithInvalidPrefix() throws Exception {
+        // clatStart() is expected to reject this /64 prefix with an IOException.
+        final ClatCoordinator clat = makeClatCoordinator();
+        final IpPrefix badPrefix = new IpPrefix("2001:db8::/64");
+        assertThrows(IOException.class, () -> clat.clatStart(BASE_IFACE, NETID, badPrefix));
+    }
+
+    private void assertStartClat(final TestDependencies deps) throws Exception {
+        // Expect clatStart() with |deps| to succeed and return a non-null value.
+        assertNotNull(new ClatCoordinator(deps).clatStart(BASE_IFACE, NETID, NAT64_IP_PREFIX));
+    }
+
+    private void assertNotStartClat(final TestDependencies deps) {
+        // The fault injected into |deps| must make clatStart() throw an IOException.
+        final ClatCoordinator failingCoordinator = new ClatCoordinator(deps);
+        assertThrows(IOException.class, () -> failingCoordinator.clatStart(
+                BASE_IFACE, NETID, NAT64_IP_PREFIX));
+    }
+
+    private void checkNotStartClat(final TestDependencies deps, final boolean verifyTunFd,
+            final boolean verifyPacketSockFd, final boolean verifyRawSockFd) throws Exception {
+        // [1] Expect that modified TestDependencies can't start clatd. Check each fd precisely:
+        // closed exactly once if flagged, never otherwise (catches unexpected/extra closes).
+        clearInvocations(TUN_PFD, RAW_SOCK_PFD, PACKET_SOCK_PFD);
+        assertNotStartClat(deps);
+        verify(TUN_PFD, org.mockito.Mockito.times(verifyTunFd ? 1 : 0)).close();
+        verify(PACKET_SOCK_PFD, org.mockito.Mockito.times(verifyPacketSockFd ? 1 : 0)).close();
+        verify(RAW_SOCK_PFD, org.mockito.Mockito.times(verifyRawSockFd ? 1 : 0)).close();
+
+        // [2] Expect that unmodified TestDependencies can start clatd. This proves that the
+        // failure in [1] really came from the fault injected into the modified dependencies.
+        assertStartClat(new TestDependencies());
+    }
+
+    // The following testNotStartClat* tests verify the code paths that unwind a partially
+    // completed clatStart() when one of its steps fails.
+    @Test
+    public void testNotStartClatWithNativeFailureSelectIpv4Address() throws Exception {
+        // Fail selectIpv4Address(); no fd is expected to be closed during unwinding.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public String selectIpv4Address(@NonNull String v4addr, int prefixlen)
+                    throws IOException {
+                throw new IOException();
+            }
+        }, false /* verifyTunFd */, false /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureGenerateIpv6Address() throws Exception {
+        // Fail generateIpv6Address(); no fd is expected to be closed during unwinding.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public String generateIpv6Address(@NonNull String iface, @NonNull String v4,
+                    @NonNull String prefix64) throws IOException {
+                throw new IOException();
+            }
+        }, false /* verifyTunFd */, false /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureCreateTunInterface() throws Exception {
+        // Fail createTunInterface(); no fd is expected to be closed during unwinding.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public int createTunInterface(@NonNull String tuniface) throws IOException {
+                throw new IOException();
+            }
+        }, false /* verifyTunFd */, false /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureDetectMtu() throws Exception {
+        // Fail detectMtu(); only the tun fd is expected to be closed during unwinding.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public int detectMtu(@NonNull String platSubnet, int platSuffix, int mark)
+                    throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, false /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureOpenPacketSocket() throws Exception {
+        // Fail openPacketSocket(); only the tun fd is expected to be closed during unwinding.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public int openPacketSocket() throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, false /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureOpenRawSocket6() throws Exception {
+        // Fail openRawSocket6(); the tun and packet socket fds are expected to be closed.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public int openRawSocket6(int mark) throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, true /* verifyPacketSockFd */,
+                false /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureAddAnycastSetsockopt() throws Exception {
+        // Fail addAnycastSetsockopt(); all three fds are expected to be closed.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public void addAnycastSetsockopt(@NonNull FileDescriptor sock, String v6,
+                    int ifindex) throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, true /* verifyPacketSockFd */,
+                true /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureTagSocketAsClat() throws Exception {
+        // Fail tagSocketAsClat(); all three fds are expected to be closed.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public long tagSocketAsClat(@NonNull FileDescriptor sock) throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, true /* verifyPacketSockFd */,
+                true /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureConfigurePacketSocket() throws Exception {
+        // Fail configurePacketSocket(); all three fds are expected to be closed.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public void configurePacketSocket(@NonNull FileDescriptor sock, String v6,
+                    int ifindex) throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, true /* verifyPacketSockFd */,
+                true /* verifyRawSockFd */);
+    }
+
+    @Test
+    public void testNotStartClatWithNativeFailureStartClatd() throws Exception {
+        // Fail startClatd(); all three fds are expected to be closed.
+        checkNotStartClat(new TestDependencies() {
+            @Override
+            public int startClatd(@NonNull FileDescriptor tunfd, @NonNull FileDescriptor readsock6,
+                    @NonNull FileDescriptor writesock6, @NonNull String iface,
+                    @NonNull String pfx96, @NonNull String v4, @NonNull String v6)
+                    throws IOException {
+                throw new IOException();
+            }
+        }, true /* verifyTunFd */, true /* verifyPacketSockFd */,
+                true /* verifyRawSockFd */);
+    }
 }