Merge "Add call to bind for ConscryptPerfTests" into main
diff --git a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientEndpoint.java b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientEndpoint.java
index 1a7258a..4c34165 100644
--- a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientEndpoint.java
+++ b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientEndpoint.java
@@ -20,6 +20,7 @@
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.lang.AutoCloseable;
 import java.net.InetAddress;
 import java.net.SocketException;
 import java.nio.channels.ClosedChannelException;
@@ -33,7 +34,7 @@
  * Client-side endpoint. Provides basic services for sending/receiving messages from the client
  * socket.
  */
-final class ClientEndpoint {
+final class ClientEndpoint implements AutoCloseable {
     private final SSLSocket socket;
     private InputStream input;
     private OutputStream output;
@@ -56,6 +57,11 @@
         }
     }
 
+    @Override
+    public void close() {
+        stop();
+    }
+
     void stop() {
         try {
             socket.close();
diff --git a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientSocketPerfTest.java b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientSocketPerfTest.java
index f20b170..9e45c4a 100644
--- a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientSocketPerfTest.java
+++ b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ClientSocketPerfTest.java
@@ -44,24 +44,21 @@
 import javax.crypto.NoSuchPaddingException;
 
 import org.junit.Rule;
+import org.junit.After;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 import junitparams.JUnitParamsRunner;
 import junitparams.Parameters;
 import android.conscrypt.ServerEndpoint.MessageProcessor;
 
-/**
- * Benchmark for comparing performance of server socket implementations.
- */
+/** Benchmark for comparing performance of client socket implementations. */
 @RunWith(JUnitParamsRunner.class)
 @LargeTest
 public final class ClientSocketPerfTest {
 
     @Rule public PerfStatusReporter mPerfStatusReporter = new PerfStatusReporter();
 
-    /**
-     * Provider for the test configuration
-     */
+    /** Provider for the test configuration */
     private class Config {
         EndpointFactory a_clientFactory;
         EndpointFactory b_serverFactory;
@@ -69,19 +66,22 @@
         String d_cipher;
         ChannelType e_channelType;
         PerfTestProtocol f_protocol;
-        Config(EndpointFactory clientFactory,
-            EndpointFactory serverFactory,
-            int messageSize,
-            String cipher,
-            ChannelType channelType,
-            PerfTestProtocol protocol) {
-          a_clientFactory = clientFactory;
-          b_serverFactory = serverFactory;
-          c_messageSize = messageSize;
-          d_cipher = cipher;
-          e_channelType = channelType;
-          f_protocol = protocol;
+
+        Config(
+                EndpointFactory clientFactory,
+                EndpointFactory serverFactory,
+                int messageSize,
+                String cipher,
+                ChannelType channelType,
+                PerfTestProtocol protocol) {
+            a_clientFactory = clientFactory;
+            b_serverFactory = serverFactory;
+            c_messageSize = messageSize;
+            d_cipher = cipher;
+            e_channelType = channelType;
+            f_protocol = protocol;
         }
+
         public EndpointFactory clientFactory() {
             return a_clientFactory;
         }
@@ -112,23 +112,43 @@
         for (EndpointFactory endpointFactory : EndpointFactory.values()) {
             for (ChannelType channelType : ChannelType.values()) {
                 for (PerfTestProtocol protocol : PerfTestProtocol.values()) {
-                    params.add(new Object[] {new Config(endpointFactory,
-                        endpointFactory, 64, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
-                        channelType, protocol)});
-                    params.add(new Object[] {new Config(endpointFactory,
-                        endpointFactory, 512, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
-                        channelType, protocol)});
-                    params.add(new Object[] {new Config(endpointFactory,
-                        endpointFactory, 4096, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
-                        channelType, protocol)});
+                    params.add(
+                            new Object[] {
+                                new Config(
+                                        endpointFactory,
+                                        endpointFactory,
+                                        64,
+                                        "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+                                        channelType,
+                                        protocol)
+                            });
+                    params.add(
+                            new Object[] {
+                                new Config(
+                                        endpointFactory,
+                                        endpointFactory,
+                                        512,
+                                        "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+                                        channelType,
+                                        protocol)
+                            });
+                    params.add(
+                            new Object[] {
+                                new Config(
+                                        endpointFactory,
+                                        endpointFactory,
+                                        4096,
+                                        "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+                                        channelType,
+                                        protocol)
+                            });
                 }
             }
         }
         return params;
     }
 
-    private ClientEndpoint client;
-    private ServerEndpoint server;
+    private SocketPair socketPair = new SocketPair();
     private byte[] message;
     private ExecutorService executor;
     private Future<?> sendingFuture;
@@ -137,46 +157,78 @@
     private static final AtomicLong bytesCounter = new AtomicLong();
     private AtomicBoolean recording = new AtomicBoolean();
 
+    /** Owns the two endpoints of one benchmark run; either may still be null. */
+    private static class SocketPair implements AutoCloseable {
+        public ClientEndpoint client;
+        public ServerEndpoint server;
+
+        /** Stops the client and then the server, even if stopping the client throws. */
+        @Override
+        public void close() {
+            try {
+                if (client != null) {
+                    client.stop();
+                }
+            } finally {
+                if (server != null) {
+                    server.stop();
+                }
+            }
+        }
+    }
+
     private void setup(Config config) throws Exception {
         message = newTextMessage(512);
 
         // Always use the same server for consistency across the benchmarks.
-        server = config.serverFactory().newServer(
-                config.messageSize(), config.protocol().getProtocols(),
-                ciphers(config));
+        socketPair.server =
+                config.serverFactory()
+                        .newServer(
+                                config.messageSize(),
+                                config.protocol().getProtocols(),
+                                ciphers(config));
+        socketPair.server.init();
 
-        server.setMessageProcessor(new ServerEndpoint.MessageProcessor() {
-            @Override
-            public void processMessage(byte[] inMessage, int numBytes, OutputStream os) {
-                if (recording.get()) {
-                    // Server received a message, increment the count.
-                    bytesCounter.addAndGet(numBytes);
-                }
-            }
-        });
-        Future<?> connectedFuture = server.start();
+        socketPair.server.setMessageProcessor(
+                new ServerEndpoint.MessageProcessor() {
+                    @Override
+                    public void processMessage(byte[] inMessage, int numBytes, OutputStream os) {
+                        if (recording.get()) {
+                            // Server received a message, increment the count.
+                            bytesCounter.addAndGet(numBytes);
+                        }
+                    }
+                });
+        Future<?> connectedFuture = socketPair.server.start();
 
-        client = config.clientFactory().newClient(
-            config.channelType(), server.port(), config.protocol().getProtocols(), ciphers(config));
-        client.start();
+        socketPair.client =
+                config.clientFactory()
+                        .newClient(
+                                config.channelType(),
+                                socketPair.server.port(),
+                                config.protocol().getProtocols(),
+                                ciphers(config));
+        socketPair.client.start();
 
         // Wait for the initial connection to complete.
         connectedFuture.get(5, TimeUnit.SECONDS);
 
         executor = Executors.newSingleThreadExecutor();
-        sendingFuture = executor.submit(new Runnable() {
-            @Override
-            public void run() {
-                try {
-                    Thread thread = Thread.currentThread();
-                    while (!stopping && !thread.isInterrupted()) {
-                        client.sendMessage(message);
-                    }
-                } finally {
-                    client.flush();
-                }
-            }
-        });
+        sendingFuture =
+                executor.submit(
+                        new Runnable() {
+                            @Override
+                            public void run() {
+                                try {
+                                    Thread thread = Thread.currentThread();
+                                    while (!stopping && !thread.isInterrupted()) {
+                                        socketPair.client.sendMessage(message);
+                                    }
+                                } finally {
+                                    socketPair.client.flush();
+                                }
+                            }
+                        });
     }
 
     void close() throws Exception {
@@ -185,29 +237,40 @@
         // Wait for the sending thread to stop.
-        sendingFuture.get(5, TimeUnit.SECONDS);
+        // sendingFuture is still null if setup() failed before the sender was submitted.
+        if (sendingFuture != null) {
+            sendingFuture.get(5, TimeUnit.SECONDS);
+        }

-        client.stop();
-        server.stop();
-        executor.shutdown();
-        executor.awaitTermination(5, TimeUnit.SECONDS);
+        if (socketPair != null) {
+            socketPair.close();
+        }
+        if (executor != null) {
+            executor.shutdown();
+            executor.awaitTermination(5, TimeUnit.SECONDS);
+        }
     }

-    /**
-     * Simple benchmark for the amount of time to send a given number of messages
-     */
+    /** Simple benchmark for the amount of time to send a given number of messages */
     @Test
     @Parameters(method = "getParams")
     public void time(Config config) throws Exception {
-        reset();
-        setup(config);
-        recording.set(true);
+        try {
+            reset();
+            setup(config);
+            recording.set(true);

-        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
-        while (state.keepRunning()) {
-          while (bytesCounter.get() < config.messageSize()) {
-          }
-          bytesCounter.set(0);
+            BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
+            while (state.keepRunning()) {
+                while (bytesCounter.get() < config.messageSize()) {}
+                bytesCounter.set(0);
+            }
+            recording.set(false);
+        } finally {
+            close();
         }
-        recording.set(false);
+    }
+
+    @After
+    public void tearDown() throws Exception {
         close();
     }

@@ -219,4 +282,4 @@
     private String[] ciphers(Config config) {
         return new String[] {config.cipher()};
     }
-}
\ No newline at end of file
+}
diff --git a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerEndpoint.java b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerEndpoint.java
index 1e4f124..83eaaa1 100644
--- a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerEndpoint.java
+++ b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerEndpoint.java
@@ -16,10 +16,14 @@
 
 package android.conscrypt;
 
+import static org.conscrypt.TestUtils.getLoopbackAddress;
+
 import java.io.EOFException;
 import java.io.IOException;
 import java.io.InputStream;
 import java.io.OutputStream;
+import java.lang.AutoCloseable;
+import java.net.InetSocketAddress;
 import java.net.ServerSocket;
 import java.net.SocketException;
 import java.nio.channels.ClosedChannelException;
@@ -37,7 +41,7 @@
 /**
  * A simple socket-based test server.
  */
-final class ServerEndpoint {
+final class ServerEndpoint implements AutoCloseable {
     /**
      * A processor for receipt of a single message.
      */
@@ -82,7 +86,11 @@
         this.messageSize = messageSize;
         this.protocols = protocols;
         this.cipherSuites = cipherSuites;
-        buffer = new byte[messageSize];
+        this.buffer = new byte[messageSize];
+    }
+
+    void init() throws IOException {
+        serverSocket.bind(new InetSocketAddress(getLoopbackAddress(), 0));
     }
 
     void setMessageProcessor(MessageProcessor messageProcessor) {
@@ -94,6 +102,11 @@
         return executor.submit(new AcceptTask());
     }
 
+    @Override
+    public void close() {
+        stop();
+    }
+
     void stop() {
         try {
             stopping = true;
diff --git a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerSocketPerfTest.java b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerSocketPerfTest.java
index af3c405..90a87ce 100644
--- a/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerSocketPerfTest.java
+++ b/apct-tests/perftests/core/src/android/conscrypt/conscrypt/ServerSocketPerfTest.java
@@ -44,6 +44,7 @@
 import junitparams.Parameters;
 
 import org.junit.Rule;
+import org.junit.After;
 import org.junit.Test;
 import org.junit.runner.RunWith;
 
@@ -115,14 +116,33 @@
         return params;
     }
 
-    private ClientEndpoint client;
-    private ServerEndpoint server;
+    private SocketPair socketPair = new SocketPair();
     private ExecutorService executor;
     private Future<?> receivingFuture;
     private volatile boolean stopping;
     private static final AtomicLong bytesCounter = new AtomicLong();
     private AtomicBoolean recording = new AtomicBoolean();
 
+    /** Owns the two endpoints of one benchmark run; either may still be null. */
+    private static class SocketPair implements AutoCloseable {
+        public ClientEndpoint client;
+        public ServerEndpoint server;
+
+        /** Stops the client and then the server, even if stopping the client throws. */
+        @Override
+        public void close() {
+            try {
+                if (client != null) {
+                    client.stop();
+                }
+            } finally {
+                if (server != null) {
+                    server.stop();
+                }
+            }
+        }
+    }
+
     private void setup(final Config config) throws Exception {
         recording.set(false);
 
@@ -130,9 +150,10 @@
 
         final ChannelType channelType = config.channelType();
 
-        server = config.serverFactory().newServer(config.messageSize(),
+        socketPair.server = config.serverFactory().newServer(config.messageSize(),
             new String[] {"TLSv1.3", "TLSv1.2"}, ciphers(config));
-        server.setMessageProcessor(new MessageProcessor() {
+        socketPair.server.init();
+        socketPair.server.setMessageProcessor(new MessageProcessor() {
             @Override
             public void processMessage(byte[] inMessage, int numBytes, OutputStream os) {
                 try {
@@ -151,20 +172,20 @@
             }
         });
 
-        Future<?> connectedFuture = server.start();
+        Future<?> connectedFuture = socketPair.server.start();
 
         // Always use the same client for consistency across the benchmarks.
-        client = config.clientFactory().newClient(
-                ChannelType.CHANNEL, server.port(),
+        socketPair.client = config.clientFactory().newClient(
+                ChannelType.CHANNEL, socketPair.server.port(),
                 new String[] {"TLSv1.3", "TLSv1.2"}, ciphers(config));
-        client.start();
+        socketPair.client.start();
 
         // Wait for the initial connection to complete.
         connectedFuture.get(5, TimeUnit.SECONDS);
 
         // Start the server-side streaming by sending a message to the server.
-        client.sendMessage(message);
-        client.flush();
+        socketPair.client.sendMessage(message);
+        socketPair.client.flush();
 
         executor = Executors.newSingleThreadExecutor();
         receivingFuture = executor.submit(new Runnable() {
@@ -173,7 +194,7 @@
                 Thread thread = Thread.currentThread();
                 byte[] buffer = new byte[config.messageSize()];
                 while (!stopping && !thread.isInterrupted()) {
-                    int numBytes = client.readMessage(buffer);
+                    int numBytes = socketPair.client.readMessage(buffer);
                     if (numBytes < 0) {
                         return;
                     }
@@ -191,25 +212,38 @@
     void close() throws Exception {
         stopping = true;
         // Stop and wait for sending to complete.
-        server.stop();
-        client.stop();
-        executor.shutdown();
-        receivingFuture.get(5, TimeUnit.SECONDS);
-        executor.awaitTermination(5, TimeUnit.SECONDS);
+        if (socketPair != null) {
+            socketPair.close();
+        }
+        if (executor != null) {
+            executor.shutdown();
+            executor.awaitTermination(5, TimeUnit.SECONDS);
+        }
+        if (receivingFuture != null) {
+            receivingFuture.get(5, TimeUnit.SECONDS);
+        }
     }
 
     @Test
     @Parameters(method = "getParams")
     public void throughput(Config config) throws Exception {
-        setup(config);
-        BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
-        while (state.keepRunning()) {
-          recording.set(true);
-          while (bytesCounter.get() < config.messageSize()) {
-          }
-          bytesCounter.set(0);
-          recording.set(false);
+        try {
+            setup(config);
+            BenchmarkState state = mPerfStatusReporter.getBenchmarkState();
+            while (state.keepRunning()) {
+                recording.set(true);
+                while (bytesCounter.get() < config.messageSize()) {
+                }
+                bytesCounter.set(0);
+                recording.set(false);
+            }
+        } finally {
+            close();
         }
+    }
+
+    @After
+    public void tearDown() throws Exception {
         close();
     }